My favorite method for studying and understanding algorithms is to actually implement them. During the implementation, you really have to understand what is happening in each step. With image algorithms, the intermediate images produced during development are also helpful; it is fun to analyze the errors and figure out what causes the artifacts in the image. So here is the result of my studies: the watershed image segmentation algorithm implemented in Java.
Watershed algorithm explained (the quoted steps are from Wikipedia [1]):
1. "A set of markers, pixels where the flooding shall start, are chosen. Each is given a different label."
- Find all the minimum points in the image. This means that you need to loop through the pixels and find the darkest pixel that has not already been selected and that is not close to an already selected minimum. You can find as many minimums as you like, or only a selected count of them. These minimums are the lowest points of the "basins" which will be flooded.
- The label can be, for example, an index number or a color value from a known premade palette.
2. "The neighboring pixels of each marked area are inserted into a priority queue with a priority level corresponding to the gray level of the pixel."
- For each found minimum point, add the surrounding pixels to the priority queue and set the initial labels in the label table (the result image). Keep the priority queue sorted from dark to light pixels, darkest first. Now you can start the flooding.
3. "The pixel with the highest priority level is extracted from the priority queue. If the neighbors of the extracted pixel that have already been labeled all have the same label, then the pixel is labeled with their label. All non-marked neighbors that are not yet in the priority queue are put into the priority queue."
- Take the first pixel (the darkest one) out of the priority queue and check its neighboring pixels. If all the labeled neighbors have the same label, give the current pixel that same label; the current pixel is "inside" the basin. Remember, there has to be at least one labeled neighbor: the one that was processed previously next to the current pixel.
- If there are at least two neighbors with different labels, we have found a boundary between two flood basins, and we can mark the current pixel as an "edge" pixel.
- Finally, add all the neighboring pixels that are not yet in the priority queue to the priority queue.
4. "Redo step 3 until the priority queue is empty."
- Loop until there are no pixels left in the priority queue. When the priority queue is empty, all the pixels of the image have been processed. (A minimal sketch of this flooding loop follows below.)
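To make steps 2-4 concrete, here is a minimal sketch of the flooding loop. This is not the full implementation given later in this post: it uses java.util.PriorityQueue instead of the custom SortedVector class, assumes a plain int[] grey height map, a labels array pre-seeded with -1 everywhere except the marker pixels, and 4-connected neighbors.

// Minimal sketch of the flooding loop in steps 2-4. Assumes labels[] holds
// -1 for unlabeled pixels and the marker label at each marker position.
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class FloodSketch {
    static class Pixel {
        final int x, y, grey;
        Pixel(int x, int y, int grey) { this.x = x; this.y = y; this.grey = grey; }
    }

    static void flood(int[] grey, int[] labels, int w, int h, List<Pixel> markers) {
        boolean[] queued = new boolean[w * h];
        // the darkest queued pixel (lowest grey value) is extracted first
        PriorityQueue<Pixel> queue =
                new PriorityQueue<Pixel>(Comparator.comparingInt((Pixel p) -> p.grey));
        int[][] offsets = {{0, -1}, {1, 0}, {0, 1}, {-1, 0}}; // 4-connected
        for (Pixel m : markers) {                     // step 2: queue marker neighbors
            queued[m.x + w * m.y] = true;
            for (int[] o : offsets) {
                enqueue(m.x + o[0], m.y + o[1], grey, queued, queue, w, h);
            }
        }
        while (!queue.isEmpty()) {                    // steps 3 and 4
            Pixel p = queue.poll();
            int found = -1;
            boolean edge = false;
            for (int[] o : offsets) {                 // inspect labeled neighbors
                int nx = p.x + o[0], ny = p.y + o[1];
                if (nx < 0 || ny < 0 || nx >= w || ny >= h) continue;
                int l = labels[nx + w * ny];
                if (l < 0) continue;                  // not labeled yet
                if (found == -1) found = l;
                else if (found != l) edge = true;     // two basins meet: boundary pixel
            }
            if (!edge && found != -1) {
                labels[p.x + w * p.y] = found;        // pixel joins its neighbors' basin
            }
            for (int[] o : offsets) {                 // queue unvisited neighbors
                enqueue(p.x + o[0], p.y + o[1], grey, queued, queue, w, h);
            }
        }
    }

    private static void enqueue(int x, int y, int[] grey, boolean[] queued,
                                PriorityQueue<Pixel> queue, int w, int h) {
        if (x < 0 || y < 0 || x >= w || y >= h || queued[x + w * y]) return;
        queued[x + w * y] = true;
        queue.add(new Pixel(x, y, grey[x + w * y]));
    }
}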
Sample images
Image 1. This is the sample image I have been using for testing. I created it myself with GIMP.
Image 2. This is the result after running the watershed algorithm on the sample image. Red squares are the found minimums and the white borders are the basin edges.
Image 3. The sample image and the result image combined into a single image. The boundaries have been colored green.
Notes:
- In this implementation, the minimums are found by "sweeping" the pixels from top-left to bottom-right. This way the found minimum is always the "first" one with the lowest value. That is why the minimum centers (the red squares) are not in the exact center of the darkest area.
- If you rotate the source image, the result will be different. That is because the minimum locations change (see the note above); see also this discussion [2].
- After segmentation, you probably need to remove the "background" pixels. What you need to do might change per application; a sketch of one simple heuristic follows these notes. See more examples of fine-tuning the watershed algorithm at http://cmm.ensmp.fr/~beucher/wtshed.html [3].
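As a hedged illustration only (this helper is not part of the source below), one simple heuristic is to treat the basin with the most pixels as background and clear its label. It could live as a static helper in the Watershed class and be run on the label array before it is written into the result image.

// Hypothetical helper, not in the original source: clears the label of the
// largest basin, assuming the biggest segment is the background.
// labels[] holds one label per pixel (0..labelCount-1), -1 for boundary pixels.
static void clearLargestBasin(int[] labels, int labelCount) {
    int[] counts = new int[labelCount];
    for (int l : labels) {
        if (l >= 0 && l < labelCount) {
            counts[l]++;                      // count pixels per basin
        }
    }
    int background = 0;
    for (int l = 1; l < labelCount; l++) {
        if (counts[l] > counts[background]) {
            background = l;
        }
    }
    for (int i = 0; i < labels.length; i++) {
        if (labels[i] == background) {
            labels[i] = -1;                   // drop the background basin
        }
    }
}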
The Java source
Implementation notes:
- To speed up the pixel sorting, I implemented a crude SortedVector, which uses an insertion-sort-like algorithm to keep the inserted pixels (FloodPoints) in sorted order.
- I disable the pixels in a rectangle-shaped area around each found minimum. A circle shape should work better in some cases (a sketch of a circular variant follows this list).
- The source contains commented code for saving the image in different steps of the process. Just uncomment the lines to get the images saved on your drive.
- The source contains commented code for real time view of the process. Just uncomment the parts where the JFrame is accessed to see the final image being built. You should try that, it looks nice! :)
- You can run the code in 4-connected pixels or 8-connected pixels mode. The results are slightly different. You should try them both.
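Here is a hedged sketch of that circular variant, not part of the source below. It is meant as a method inside the Watershed class (it uses the g_w and g_h fields), masking all pixels within a radius of the found minimum instead of a square window.

// Hypothetical circular variant of fill(): disables all pixels within
// radius r of the found minimum (cx, cy) instead of a square window.
private void fillCircle(int cx, int cy, int r, int[] array, int value) {
    for (int y = cy - r; y <= cy + r; y++) {
        for (int x = cx - r; x <= cx + r; x++) {
            // clip to image boundaries
            if (x < 0 || y < 0 || x >= g_w || y >= g_h) {
                continue;
            }
            int dx = x - cx;
            int dy = y - cy;
            if (dx * dx + dy * dy <= r * r) {
                array[x + g_w * y] = value;   // inside the circle: mask out
            }
        }
    }
}

In calculate(), the call fill(x-half, y-half, x+half, y+half, lut, 0) would then become fillCircle(x, y, half, lut, 0).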
Usage
java Watershed [source file] [result file] [number of segments, 1-256] [minimum window width, 8-256] [connected pixels, 4|8]
For example:
java Watershed sample_image.png result.png 22 60 8
SortedVector.java
import java.util.Vector;

public class SortedVector {
    Vector<Comparable<Object>> data;

    public SortedVector() {
        data = new Vector<Comparable<Object>>();
    }

    public void add(Object o) {
        if (o instanceof Comparable) {
            insert((Comparable<Object>)o);
        } else {
            throw new IllegalArgumentException("Object " +
                    "must implement Comparable interface.");
        }
    }
    private void insert(Comparable<Object> o) {
        // binary search for the insert position so that
        // the vector stays sorted from smallest to largest
        int left = 0;
        int right = data.size();
        while (left < right) {
            int middle = (left + right) / 2;
            if (data.elementAt(middle).compareTo(o) <= 0) {
                left = middle + 1;
            } else {
                right = middle;
            }
        }
        data.add(left, o);
    }
    public int size() {
        return data.size();
    }

    public Object elementAt(int index) {
        return data.elementAt(index);
    }

    public Object remove(int position) {
        return data.remove(position);
    }
}
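To sanity check SortedVector on its own (this example is not part of the original post), any Comparable type will do; Integers, for example, come out in ascending order:

// Hypothetical quick check of SortedVector, using Integers instead of FloodPoints.
public class SortedVectorDemo {
    public static void main(String[] args) {
        SortedVector v = new SortedVector();
        v.add(30);
        v.add(10);
        v.add(20);
        // remove(0) always returns the smallest remaining element
        while (v.size() > 0) {
            System.out.println(v.remove(0)); // prints 10, 20, 30
        }
    }
}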
Watershed.java
import java.awt.Color;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import java.util.Vector;
import javax.imageio.ImageIO;
import javax.swing.JFrame;

/**
 * @author tejopa / http://popscan.blogspot.com
 * @date 2014-04-20
 */
public class Watershed {
    //JFrame frame;
    int g_w;
    int g_h;

    public static void main(String[] args) {
        if (args.length!=5) {
            System.out.println("Usage: java Watershed"
                    + " [source image filename]"
                    + " [destination image filename]"
                    + " [flood point count (1-256)]"
                    + " [minimums window width (8-256)]"
                    + " [connected pixels (4 or 8)]"
                    );
            return;
        }
        String src = args[0];
        String dst = args[1];
        int floodPoints = Integer.parseInt(args[2]);
        int windowWidth = Integer.parseInt(args[3]);
        int connectedPixels = Integer.parseInt(args[4]);
        Watershed watershed = new Watershed();
        long start = System.currentTimeMillis();
        BufferedImage dstImage = watershed.calculate(loadImage(src),
                floodPoints, windowWidth, connectedPixels);
        long end = System.currentTimeMillis();
        long totalms = (end-start);
        System.out.println("Took: "+totalms+" milliseconds");
        // save the resulting image
        saveImage(dst, dstImage);
    }
    private BufferedImage calculate(BufferedImage image,
            int floodPoints, int windowWidth,
            int connectedPixels) {
        /*
        // frame for real time view for the process
        frame = new JFrame();
        frame.setSize(image.getWidth(), image.getHeight());
        frame.setVisible(true);
        */
        g_w = image.getWidth();
        g_h = image.getHeight();
        // height map is the gray color image
        int[] map = image.getRGB(0, 0, g_w, g_h, null, 0, g_w);
        // LUT is the lookup table for the processed pixels
        int[] lut = new int[g_w*g_h];
        // fill LUT with ones
        Arrays.fill(lut, 1);
        Vector<FloodPoint> minimums = new Vector<FloodPoint>();
        // loop all the pixels of the image until
        // a) all the required minimums have been found
        // OR
        // b) there are no more unprocessed pixels left
        int foundMinimums = 0;
        while (foundMinimums<floodPoints) {
            int minimumValue = 256;
            int minimumPosition = -1;
            for (int i=0;i<lut.length;i++) {
                if ((lut[i]==1)&&(map[i]<minimumValue)) {
                    minimumPosition = i;
                    minimumValue = map[i];
                }
            }
            // check if minimum was found
            if (minimumPosition!=-1) {
                // add minimum to found minimum vector
                int x = minimumPosition%g_w;
                int y = minimumPosition/g_w;
                int grey = map[x+g_w*y]&0xff;
                int label = foundMinimums;
                minimums.add(new FloodPoint(x, y, label, grey));
                // remove pixels around so that the next minimum
                // must be at least windowWidth/2 distance from
                // this minimum (using square, could be circle...)
                int half = windowWidth/2;
                fill(x-half, y-half, x+half, y+half, lut, 0);
                lut[minimumPosition] = 0;
                foundMinimums++;
            } else {
                // stop while loop
                System.out.println("Out of pixels. Found "
                        + minimums.size()
                        + " minimums of requested "
                        + floodPoints+".");
                break;
            }
        }
        /*
        // create image with minimums only
        for (int i=0;i<minimums.size();i++) {
            FloodPoint p = minimums.elementAt(i);
            Graphics g = image.getGraphics();
            g.setColor(Color.red);
            g.drawRect(p.x, p.y, 2, 2);
        }
        saveImage("minimums.png", image);
        */
        // start flooding from minimums
        lut = flood(map, minimums, connectedPixels);
        // return flooded image
        image.setRGB(0, 0, g_w, g_h, lut, 0, g_w);
        /*
        // create image with boundaries also
        for (int i=0;i<minimums.size();i++) {
            FloodPoint p = minimums.elementAt(i);
            Graphics g = image.getGraphics();
            g.setColor(Color.red);
            g.drawRect(p.x, p.y, 2, 2);
        }
        saveImage("minimums_and_boundaries.png", image);
        */
        return image;
    }
    private int[] flood(int[] map, Vector<FloodPoint> minimums,
            int connectedPixels) {
        SortedVector queue = new SortedVector();
        //BufferedImage result = new BufferedImage(g_w, g_h,
        //        BufferedImage.TYPE_INT_RGB);
        int[] lut = new int[g_w*g_h];
        int[] inqueue = new int[g_w*g_h];
        // not processed = -1, processed >= 0
        Arrays.fill(lut, -1);
        Arrays.fill(inqueue, 0);
        // Initialize queue with each found minimum
        for (int i=0;i<minimums.size();i++) {
            FloodPoint p = minimums.elementAt(i);
            int label = p.label;
            // insert starting pixels around minimums
            addPoint(queue, inqueue, map, p.x, p.y-1, label);
            addPoint(queue, inqueue, map, p.x+1, p.y, label);
            addPoint(queue, inqueue, map, p.x, p.y+1, label);
            addPoint(queue, inqueue, map, p.x-1, p.y, label);
            if (connectedPixels==8) {
                addPoint(queue, inqueue, map, p.x-1, p.y-1, label);
                addPoint(queue, inqueue, map, p.x+1, p.y-1, label);
                addPoint(queue, inqueue, map, p.x+1, p.y+1, label);
                addPoint(queue, inqueue, map, p.x-1, p.y+1, label);
            }
            int pos = p.x+p.y*g_w;
            lut[pos] = label;
            inqueue[pos] = 1;
        }
        // start flooding
        while (queue.size()>0) {
            // remove the minimum (darkest pixel) from the queue
            FloodPoint extracted = (FloodPoint)queue.remove(0);
            int x = extracted.x;
            int y = extracted.y;
            int label = extracted.label;
            // check pixels around extracted pixel
            int[] labels = new int[connectedPixels];
            labels[0] = getLabel(lut, x, y-1);
            labels[1] = getLabel(lut, x+1, y);
            labels[2] = getLabel(lut, x, y+1);
            labels[3] = getLabel(lut, x-1, y);
            if (connectedPixels==8) {
                labels[4] = getLabel(lut, x-1, y-1);
                labels[5] = getLabel(lut, x+1, y-1);
                labels[6] = getLabel(lut, x+1, y+1);
                labels[7] = getLabel(lut, x-1, y+1);
            }
            boolean onEdge = isEdge(labels, extracted);
            if (onEdge) {
                // leave edges without label
            } else {
                // set pixel with label
                lut[x+g_w*y] = extracted.label;
            }
            if (!inQueue(inqueue, x, y-1)) {
                addPoint(queue, inqueue, map, x, y-1, label);
            }
            if (!inQueue(inqueue, x+1, y)) {
                addPoint(queue, inqueue, map, x+1, y, label);
            }
            if (!inQueue(inqueue, x, y+1)) {
                addPoint(queue, inqueue, map, x, y+1, label);
            }
            if (!inQueue(inqueue, x-1, y)) {
                addPoint(queue, inqueue, map, x-1, y, label);
            }
            if (connectedPixels==8) {
                if (!inQueue(inqueue, x-1, y-1)) {
                    addPoint(queue, inqueue, map, x-1, y-1, label);
                }
                if (!inQueue(inqueue, x+1, y-1)) {
                    addPoint(queue, inqueue, map, x+1, y-1, label);
                }
                if (!inQueue(inqueue, x+1, y+1)) {
                    addPoint(queue, inqueue, map, x+1, y+1, label);
                }
                if (!inQueue(inqueue, x-1, y+1)) {
                    addPoint(queue, inqueue, map, x-1, y+1, label);
                }
            }
            // draw the current pixel set to frame, WARNING: slow...
            //result.setRGB(0, 0, g_w, g_h, lut, 0, g_w);
            //frame.getGraphics().drawImage(result,
            //        0, 0, g_w, g_h, null);
        }
        return lut;
    }
    private boolean inQueue(int[] inqueue, int x, int y) {
        if (x<0||x>=g_w||y<0||y>=g_h) {
            return false;
        }
        if (inqueue[x+g_w*y] == 1) {
            return true;
        }
        return false;
    }

    private boolean isEdge(int[] labels, FloodPoint extracted) {
        for (int i=0;i<labels.length;i++) {
            if (labels[i]!=extracted.label&&labels[i]!=-1) {
                return true;
            }
        }
        return false;
    }

    private int getLabel(int[] lut, int x, int y) {
        if (x<0||x>=g_w||y<0||y>=g_h) {
            return -2;
        }
        return lut[x+g_w*y];
    }

    private void addPoint(SortedVector queue,
            int[] inqueue, int[] map,
            int x, int y, int label) {
        if (x<0||x>=g_w||y<0||y>=g_h) {
            return;
        }
        queue.add(new FloodPoint(x, y, label, map[x+g_w*y]&0xff));
        inqueue[x+g_w*y] = 1;
    }

    private void fill(int x1, int y1, int x2, int y2,
            int[] array, int value) {
        for (int y=y1;y<y2;y++) {
            for (int x=x1;x<x2;x++) {
                // clip to boundaries
                if (y>=0&&x>=0&&y<g_h&&x<g_w) {
                    array[x+g_w*y] = value;
                }
            }
        }
    }
    public static void saveImage(String filename,
            BufferedImage image) {
        File file = new File(filename);
        try {
            ImageIO.write(image, "png", file);
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"+filename
                    +"' saving failed.");
        }
    }

    public static BufferedImage loadImage(String filename) {
        BufferedImage result = null;
        try {
            result = ImageIO.read(new File(filename));
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"
                    +filename+"' not found.");
        }
        return result;
    }
    class FloodPoint implements Comparable<Object> {
        int x;
        int y;
        int label;
        int grey;

        public FloodPoint(int x, int y, int label, int grey) {
            this.x = x;
            this.y = y;
            this.label = label;
            this.grey = grey;
        }

        @Override
        public int compareTo(Object o) {
            FloodPoint other = (FloodPoint)o;
            if (this.grey < other.grey) {
                return -1;
            } else if (this.grey > other.grey) {
                return 1;
            }
            return 0;
        }
    }
}
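For reference, the combined view from Image 3 can be produced with a small overlay program like the one below. This is a hedged sketch, not part of the original source: it assumes the boundary pixels in the result image are the white (unlabeled, -1) pixels that flood() leaves behind, and simply paints them green on top of the source image.

// Hypothetical overlay tool: paints the white boundary pixels of the
// watershed result in green on top of the original source image.
import java.awt.image.BufferedImage;
import java.io.File;
import javax.imageio.ImageIO;

public class Overlay {
    public static void main(String[] args) throws Exception {
        BufferedImage source = ImageIO.read(new File(args[0])); // e.g. sample_image.png
        BufferedImage result = ImageIO.read(new File(args[1])); // e.g. result.png
        // work on an RGB copy so the green actually shows up
        BufferedImage combined = new BufferedImage(source.getWidth(),
                source.getHeight(), BufferedImage.TYPE_INT_RGB);
        combined.getGraphics().drawImage(source, 0, 0, null);
        for (int y = 0; y < combined.getHeight(); y++) {
            for (int x = 0; x < combined.getWidth(); x++) {
                if (result.getRGB(x, y) == 0xFFFFFFFF) {        // boundary pixel (white)
                    combined.setRGB(x, y, 0xFF00FF00);          // paint it green
                }
            }
        }
        ImageIO.write(combined, "png", new File(args[2]));      // e.g. combined.png
    }
}

Run it as: java Overlay sample_image.png result.png combined.png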
Sources:
[1] http://en.wikipedia.org/wiki/Watershed_%28image_processing%29
[2] http://imagej.1557.x6.nabble.com/Watershed-Algorithm-source-bug-td3685192.html
[3] http://cmm.ensmp.fr/~beucher/wtshed.html