Sunday, April 20, 2014

Watershed image segmentation algorithm with Java

I am very interested in image segmentation, which is why the watershed segmentation caught my attention this time. The watershed algorithm is used at least in medical imaging applications, and F. Meyer's flooding algorithm is mentioned as "one of the most common" ones [1]. It also looked quite easy to implement.

My favorite way to study and understand an algorithm is to actually implement it: during the implementation you really have to understand what happens in each step. With image algorithms, the images produced during development are also helpful; it is fun to analyze the errors and figure out what causes the artifacts in the image. So here is the result of my studies: the watershed image segmentation algorithm implemented in Java.

Watershed algorithm explained (the quoted passages are from Wikipedia [1]):


1. "A set of markers, pixels where the flooding shall start, are chosen. Each is given a different label."
  • Find the minimum points of the image. This means looping through the pixels and finding the darkest pixel that has not already been selected and that is not too close to an already selected minimum. You can find as many minimums as you like, or only a selected count of them. These minimums are the lowest points of the "basins" that will be flooded.
  • The label can be, for example, a running index or a color value from a known premade palette.

2. "The neighboring pixels of each marked area are inserted into a priority queue with a priority level corresponding to the gray level of the pixel."
  • For each found minimum point, add the surrounding pixels to the priority queue and set the initial labels in the label table (the result image). The priority queue keeps the pixels sorted from dark to light. Now you can start the flooding.

3. "The pixel with the highest priority level is extracted from the priority queue. If the neighbors of the extracted pixel that have already been labeled all have the same label, then the pixel is labeled with their label. All non-marked neighbors that are not yet in the priority queue are put into the priority queue." 
  • Take the first pixel (the darkest one) out of the priority queue and check its neighboring pixels. If all the labeled neighbors have the same label, give the current pixel that label; the current pixel is "inside" the basin. Remember, there has to be at least one labeled neighbor: the one that earlier caused the current pixel to be added to the queue.
  • If at least two neighbors have different labels, we have found a boundary between two different flood basins, and we can mark the current pixel as an "edge" pixel.
  • Finally, add all the neighboring pixels that do not yet have a label (and are not already in the queue) to the priority queue.

4. "Redo step 3 until the priority queue is empty."
  • Loop until there are no pixels left in the priority queue. When the queue is empty, all the pixels of the image have been processed. (A simplified sketch of this flooding loop follows this list.)
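
To make the steps concrete, here is a minimal, simplified sketch of the same flooding idea using java.util.PriorityQueue instead of a custom queue. The FloodSketch class, its neighbors() helper and the flat-array representation are my own simplifications for illustration only; the full implementation further down uses its own SortedVector and FloodPoint classes.

    import java.util.Arrays; 
    import java.util.PriorityQueue; 

    // Minimal sketch of steps 2-4 for a grayscale image stored as a flat array. 
    // 'labels' starts as -1 everywhere except at the chosen marker (minimum) 
    // pixels, which already carry their labels (step 1). 
    class FloodSketch { 
        static void flood(final int[] gray, int[] labels, int w, int h) { 
            // priority queue ordered from dark to light (lowest gray value first) 
            PriorityQueue<Integer> queue = 
                    new PriorityQueue<>((a, b) -> Integer.compare(gray[a], gray[b])); 
            boolean[] queued = new boolean[w*h]; 
            // Step 2: queue the neighbors of every marker pixel 
            for (int i=0;i<labels.length;i++) { 
                if (labels[i]>=0) { 
                    queued[i] = true; // markers themselves are never re-processed 
                    for (int n : neighbors(i, w, h)) { 
                        if (labels[n]<0 && !queued[n]) { 
                            queued[n] = true; 
                            queue.add(n); 
                        } 
                    } 
                } 
            } 
            // Steps 3-4: flood until the queue is empty 
            while (!queue.isEmpty()) { 
                int p = queue.poll(); 
                int label = -1; 
                boolean edge = false; 
                for (int n : neighbors(p, w, h)) { 
                    if (labels[n]>=0) { 
                        if (label==-1) { 
                            label = labels[n];       // first labeled neighbor 
                        } else if (labels[n]!=label) { 
                            edge = true;             // two different basins meet here 
                        } 
                    } else if (!queued[n]) { 
                        queued[n] = true; 
                        queue.add(n);                // queue unlabeled neighbors 
                    } 
                } 
                if (!edge) { 
                    labels[p] = label;               // edge pixels stay unlabeled 
                } 
            } 
        } 
        // 4-connected neighbors of pixel index p, clipped to the image borders 
        static int[] neighbors(int p, int w, int h) { 
            int x = p%w, y = p/w; 
            int[] tmp = new int[4]; 
            int count = 0; 
            if (y>0)   tmp[count++] = p-w; 
            if (x<w-1) tmp[count++] = p+1; 
            if (y<h-1) tmp[count++] = p+w; 
            if (x>0)   tmp[count++] = p-1; 
            return Arrays.copyOf(tmp, count); 
        } 
    } 

The PriorityQueue keeps the darkest queued pixel at the head automatically, which is exactly the ordering the flooding needs.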

 

Sample images

 

Image 1. This is the sample image I have been using for testing. I created it myself with GIMP.

Image 2. This is the result after running the watershed algorithm on the sample image. The red squares are the found minimums and the white borders are the basin edges.

Image 3. Here the sample image and the result image are combined into one image. The boundaries have been colored green.

Notes:

  • In this implementation, the minimums are found by "sweeping" the pixels from top-left to bottom-right. This way the found minimum is always the "first" one with the lowest value. That is why the minimum centers (the red squares) are not in the exact center of the darkest area. 
  • If you rotate the source image, the result will be different. That is because the minimum locations change (see the note above); see also this discussion [2].
  • After segmentation, you probably need to remove the "background" pixels. What exactly needs to be done varies per application; one possible approach is sketched right after these notes. See more examples of fine-tuning the watershed algorithm at http://cmm.ensmp.fr/~beucher/wtshed.html [3].
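
What counts as "background" is application specific, so the following is only one possible approach and is not part of the source below: if you know a position that is certainly background (for example a corner of the image), take its label color from the result image and blank out every pixel carrying that color. The clearBackground method and the choice of black as the fill color are my own assumptions for this sketch.

    import java.awt.image.BufferedImage; 

    class BackgroundRemoval { 
        // One possible post-processing step: treat the segment that contains 
        // a known background position (e.g. the top-left corner) as background 
        // and paint all of its pixels black in the result image. 
        static void clearBackground(BufferedImage result, int bgX, int bgY) { 
            int backgroundLabel = result.getRGB(bgX, bgY); 
            for (int y=0;y<result.getHeight();y++) { 
                for (int x=0;x<result.getWidth();x++) { 
                    if (result.getRGB(x, y)==backgroundLabel) { 
                        result.setRGB(x, y, 0xFF000000); // opaque black 
                    } 
                } 
            } 
        } 
    } 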

 

The Java source


Implementation notes:
  • To speed up the pixel sorting, I implemented a crude SortedVector, which uses an insertion-sort-like algorithm (binary search for the insertion point) to keep the inserted pixels (FloodPoints) in sorted order.
  • The pixels around a found minimum are disabled in a rectangle-shaped area. A circle shape might work better in some cases; a sketch of a circular variant follows this list.
  • The source contains commented code for saving the image in different steps of the process. Just uncomment the lines to get the images saved on your drive.
  • The source contains commented code for real time view of the process. Just uncomment the parts where the JFrame is accessed to see the final image being built. You should try that, it looks nice! :)
  • You can run the code in 4-connected pixels or 8-connected pixels mode. The results are slightly different. You should try them both.
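
As a rough illustration of the circle idea mentioned above, the rectangular fill() call around each found minimum could be replaced with something like the following method, meant to live inside the Watershed class (it uses the same g_w and g_h fields). The name fillCircle is not part of the source below:

    // Hypothetical circular alternative to the rectangular fill() in Watershed.java: 
    // disables (sets to 'value') all LUT entries within 'radius' pixels of (cx, cy). 
    private void fillCircle(int cx, int cy, int radius, 
            int[] array, int value) { 
        for (int y=cy-radius;y<=cy+radius;y++) { 
            for (int x=cx-radius;x<=cx+radius;x++) { 
                // clip to image boundaries 
                if (y>=0&&x>=0&&y<g_h&&x<g_w) { 
                    int dx = x-cx; 
                    int dy = y-cy; 
                    // keep only the points inside the circle 
                    if (dx*dx+dy*dy<=radius*radius) { 
                        array[x+g_w*y] = value; 
                    } 
                } 
            } 
        } 
    } 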
Please let me know if you spot errors or have any other comments. Thanks!

Usage


java Watershed [source file] [result file] [number of segments, 1-256] [minimum window width, 8-256] [connected pixels, 4|8] 

For example: 
java Watershed sample_image.png result.png 22 60 8

SortedVector.java


import java.util.Vector; 

public class SortedVector { 
    Vector<Comparable<Object>> data; 

    public SortedVector() {  
        data = new Vector<Comparable<Object>>(); 
    } 

    public void add(Object o) { 
        if (o instanceof Comparable) { 
            insert((Comparable<Object>)o); 
        } else { 
            throw new IllegalArgumentException("Object " + 
                    "must implement Comparable interface."); 
        } 
    } 
     
    private void insert(Comparable<Object> o) { 
        if (data.size()==0) { 
            data.add(o); 
            return; 
        } 
        // binary search for the insertion position 
        int middle = 0; 
        int left  = 0; 
        int right = data.size()-1; 
        while (left <= right) { 
            middle  = (left+right)/2; 
            if (data.elementAt(middle).compareTo(o)<0) { 
                left = middle + 1; 
            } else if (data.elementAt(middle).compareTo(o)>0) { 
                right = middle - 1; 
            } else { 
                // equal element found, insert here 
                left = middle; 
                break; 
            } 
        } 
        // 'left' is now the first position whose element is not smaller than o 
        data.add(left, o); 
    } 

    public int size() { 
        return data.size(); 
    } 
     
    public Object elementAt(int index) { 
        return data.elementAt(index); 
    } 

    public Object remove(int position) { 
        return data.remove(position); 
    } 
} 

Watershed.java


import java.awt.Color; 
import java.awt.Graphics; 
import java.awt.image.BufferedImage; 
import java.io.File; 
import java.util.Arrays; 
import java.util.Vector; 
import javax.imageio.ImageIO; 
import javax.swing.JFrame; 
/** 
 * @author tejopa / http://popscan.blogspot.com 
 * @date 2014-04-20 
 */ 
public class Watershed { 
    //JFrame frame; 
    int g_w; 
    int g_h; 
     
    public static void main(String[] args) { 
        if (args.length!=5) { 
            System.out.println("Usage: java popscan.Watershed
                            + " [source image filename]
                            + " [destination image filename]
                            + " [flood point count (1-256)]
                            + " [minimums window width (8-256)]
                            + " [connected pixels (4 or 8)]
                            )
            return
        } 
        String src = args[0]
        String dst = args[1]
        int floodPoints = Integer.parseInt(args[2])
        int windowWidth = Integer.parseInt(args[3])
        int connectedPixels = Integer.parseInt(args[4])
         
        Watershed watershed = new Watershed()
         
        long start = System.currentTimeMillis()
        BufferedImage dstImage = watershed.calculate(loadImage(src)
                floodPoints,windowWidth,connectedPixels)
        long end = System.currentTimeMillis()
         
        // save the resulting image 
        long totalms = (end-start)
        System.out.println("Took: "+totalms+" milliseconds")
        saveImage(dst, dstImage)
    } 
     
    private BufferedImage calculate(BufferedImage image,  
            int floodPoints, int windowWidth,  
            int connectedPixels) { 
        /* 
        // frame for real time view for the process 
        frame = new JFrame(); 
        frame.setSize(image.getWidth(),image.getHeight()); 
        frame.setVisible(true); 
        */ 
         
        g_w = image.getWidth(); 
        g_h = image.getHeight(); 
        // height map is the gray color image 
        int[] map = image.getRGB(0, 0, g_w, g_h, null, 0, g_w); 
        // LUT is the lookup table for the processed pixels 
        int[] lut = new int[g_w*g_h]; 
        // fill LUT with ones 
        Arrays.fill(lut, 1); 
        Vector<FloodPoint> minimums = new Vector<FloodPoint>(); 
        // loop all the pixels of the image until 
        // a) all the required minimums have been found 
        // OR 
        // b) there are no more unprocessed pixels left  
        int foundMinimums = 0; 
        while (foundMinimums<floodPoints) { 
            int minimumValue = 256; 
            int minimumPosition = -1; 
            for (int i=0;i<lut.length;i++) { 
                if ((lut[i]==1)&&(map[i]<minimumValue)) { 
                    minimumPosition = i; 
                    minimumValue = map[i]; 
                } 
            } 
            // check if minimum was found 
            if (minimumPosition!=-1) { 
                // add minimum to found minimum vector 
                int x = minimumPosition%g_w; 
                int y = minimumPosition/g_w;  
                int grey = map[x+g_w*y]&0xff; 
                int label = foundMinimums; 
                minimums.add(new FloodPoint(x,y, 
                        label,grey)); 
                // remove pixels around so that the next minimum 
                // must be at least windowWidth/2 distance from 
                // this minimum (using square, could be circle...) 
                int half = windowWidth/2; 
                fill(x-half,y-half,x+half,y+half,lut,0); 
                lut[minimumPosition] = 0; 
                foundMinimums++; 
            } else { 
                // stop while loop 
                System.out.println("Out of pixels. Found " 
                                    + minimums.size() 
                                    + " minimums of requested " 
                                    + floodPoints+".")
                break
            } 
        } 
        /* 
        // create image with minimums only 
        for (int i=0;i<minimums.size();i++) { 
            FloodPoint p = minimums.elementAt(i); 
            Graphics g = image.getGraphics(); 
            g.setColor(Color.red); 
            g.drawRect(p.x, p.y, 2, 2); 
        } 
        saveImage("minimums.png", image); 
        */ 
         
        // start flooding from minimums 
        lut = flood(map,minimums,connectedPixels); 
         
        // return flooded image 
        image.setRGB(0, 0, g_w, g_h, lut, 0, g_w); 
        /*// create image with boundaries also 
        for (int i=0;i<minimums.size();i++) { 
            FloodPoint p = minimums.elementAt(i); 
            Graphics g = image.getGraphics(); 
            g.setColor(Color.red); 
            g.drawRect(p.x, p.y, 2, 2); 
        } 
        saveImage("minimums_and_boundaries.png", image); 
        */ 
        return image; 
    } 
    private int[] flood(int[] map, Vector<FloodPoint> minimums,  
            int connectedPixels) { 
        SortedVector queue = new SortedVector(); 
        //BufferedImage result = new BufferedImage(g_w, g_h, 
        //        BufferedImage.TYPE_INT_RGB); 
        int[] lut = new int[g_w*g_h]; 
        int[] inqueue = new int[g_w*g_h]; 
        // not processed = -1, processed >= 0 
        Arrays.fill(lut, -1); 
        Arrays.fill(inqueue, 0); 
        // Initialize queue with each found minimum 
        for (int i=0;i<minimums.size();i++) { 
            FloodPoint p = minimums.elementAt(i); 
            int label = p.label; 
            // insert starting pixels around minimums 
            addPoint(queue, inqueue, map, p.x,   p.y-1, label); 
            addPoint(queue, inqueue, map, p.x+1, p.y,   label); 
            addPoint(queue, inqueue, map, p.x,   p.y+1, label); 
            addPoint(queue, inqueue, map, p.x-1, p.y,   label); 
            if (connectedPixels==8) { 
                addPoint(queue, inqueue, map, p.x-1, p.y-1, label); 
                addPoint(queue, inqueue, map, p.x+1, p.y-1, label); 
                addPoint(queue, inqueue, map, p.x+1, p.y+1, label); 
                addPoint(queue, inqueue, map, p.x-1, p.y+1, label); 
            } 
            int pos = p.x+p.y*g_w; 
            lut[pos] = label; 
            inqueue[pos] = 1; 
        } 
         
        // start flooding 
        while (queue.size()>0) { 
            // find minimum 
            FloodPoint extracted = null; 
            // remove the minimum from the queue 
            extracted = (FloodPoint)queue.remove(0); 
            int x = extracted.x; 
            int y = extracted.y; 
            int label = extracted.label; 
            // check pixels around extracted pixel 
            int[] labels = new int[connectedPixels]; 
            labels[0] = getLabel(lut,x,y-1); 
            labels[1] = getLabel(lut,x+1,y); 
            labels[2] = getLabel(lut,x,y+1); 
            labels[3] = getLabel(lut,x-1,y); 
            if (connectedPixels==8) { 
                labels[4] = getLabel(lut,x-1,y-1); 
                labels[5] = getLabel(lut,x+1,y-1); 
                labels[6] = getLabel(lut,x+1,y+1); 
                labels[7] = getLabel(lut,x-1,y+1); 
            } 
            boolean onEdge = isEdge(labels,extracted); 
            if (onEdge) { 
                // leave edges without label 
            } else { 
                // set pixel with label 
                lut[x+g_w*y] = extracted.label; 
            } 
            if (!inQueue(inqueue,x,y-1)) { 
                addPoint(queue, inqueue, map, x, y-1, label); 
            } 
            if (!inQueue(inqueue,x+1,y)) { 
                addPoint(queue, inqueue, map, x+1, y, label); 
            } 
            if (!inQueue(inqueue,x,y+1)) { 
                addPoint(queue, inqueue, map, x, y+1, label); 
            } 
            if (!inQueue(inqueue,x-1,y)) { 
                addPoint(queue, inqueue, map, x-1, y, label); 
            } 
            if (connectedPixels==8) { 
                if (!inQueue(inqueue,x-1,y-1)) { 
                    addPoint(queue, inqueue, map, x-1, y-1, label); 
                } 
                if (!inQueue(inqueue,x+1,y-1)) { 
                    addPoint(queue, inqueue, map, x+1, y-1, label); 
                } 
                if (!inQueue(inqueue,x+1,y+1)) { 
                    addPoint(queue, inqueue, map, x+1, y+1, label); 
                } 
                if (!inQueue(inqueue,x-1,y+1)) { 
                    addPoint(queue, inqueue, map, x-1, y+1, label); 
                } 
            } 
            // draw the current pixel set to frame, WARNING: slow... 
            //result.setRGB(0, 0, g_w, g_h, lut, 0, g_w); 
            //frame.getGraphics().drawImage(result,  
            //     0, 0, g_w, g_h, null); 
        } 
        return lut; 
    } 
     
    private boolean inQueue(int[] inqueue, int x, int y) { 
        if (x<0||x>=g_w||y<0||y>=g_h) { 
            return false; 
        } 
        if (inqueue[x+g_w*y] == 1) { 
            return true; 
        } 
        return false; 
    } 
    private boolean isEdge(int[] labels, FloodPoint extracted) { 
        for (int i=0;i<labels.length;i++) { 
            if (labels[i]!=extracted.label&&labels[i]!=-1) { 
                return true; 
            } 
        } 
        return false; 
    } 
    private int getLabel(int[] lut, int x, int y) { 
        if (x<0||x>=g_w||y<0||y>=g_h) { 
            return -2; 
        } 
        return lut[x+g_w*y]; 
    } 
     
    private void addPoint(SortedVector queue,  
            int[] inqueue, int[] map,  
            int x, int y, int label) { 
        if (x<0||x>=g_w||y<0||y>=g_h) { 
            return; 
        } 
        queue.add(new FloodPoint(x,y,label,map[x+g_w*y]&0xff)); 
        inqueue[x+g_w*y] = 1; 
    } 
     
    private void fill(int x1, int y1, int x2, int y2,  
            int[] array, int value) { 
        for (int y=y1;y<y2;y++) { 
            for (int x=x1;x<x2;x++) { 
                // clip to boundaries 
                if (y>=0&&x>=0&&y<g_h&&x<g_w) { 
                    array[x+g_w*y] = value; 
                } 
            } 
        } 
    } 
     
    public static void saveImage(String filename,  
            BufferedImage image) { 
        File file = new File(filename); 
        try { 
            ImageIO.write(image, "png", file); 
        } catch (Exception e) { 
            System.out.println(e.toString()+" Image '"+filename 
                                +"' saving failed."); 
        } 
    } 
    public static BufferedImage loadImage(String filename) { 
        BufferedImage result = null; 
        try { 
            result = ImageIO.read(new File(filename)); 
        } catch (Exception e) { 
            System.out.println(e.toString()+" Image '" 
                                +filename+"' not found."); 
        } 
        return result; 
    } 
     
    class FloodPoint implements Comparable<Object> { 
        int x; 
        int y; 
        int label; 
        int grey; 
         
        public FloodPoint(int x, int y, int label, int grey) { 
            this.x = x; 
            this.y = y; 
            this.label = label; 
            this.grey = grey; 
        } 
        @Override 
        public int compareTo(Object o) { 
            FloodPoint other = (FloodPoint)o; 
            if (this.grey < other.grey ) { 
                return -1; 
            } else if (this.grey > other.grey ) { 
                return 1; 
            } 
            return 0; 
        } 
    } 
     
} 

Sources:
[1] http://en.wikipedia.org/wiki/Watershed_%28image_processing%29
[2] http://imagej.1557.x6.nabble.com/Watershed-Algorithm-source-bug-td3685192.html
[3] http://cmm.ensmp.fr/~beucher/wtshed.html