Sunday, June 16, 2013

Image segmentation with K-means algorithm - Java implementation

K-means is a well-known clustering algorithm, and it is described in many papers and online texts. But even after reading many of these documents, I was confused and still had many questions in my mind.

I was wondering about questions like:
  • What are the K cluster centers in the case of a digital image?
  • What is the distance function when comparing a pixel to the clusters?
I left these questions in my mind for a few days, and finally one morning I realized what the K-means algorithm means in the context of digital image segmentation. Segmentation in this case means "finding the continuous areas where the color and intensity are similar to those of adjacent pixels".

I will try to answer the questions above in this post.

K-means algorithm for digital images in a nutshell
  1. create k clusters, k being the cluster count and each cluster "center" being a color value
  2. for each pixel, find the cluster whose center has the minimal distance to the pixel
  3. if the pixel was already in another cluster, remove it from that cluster and add it to the new cluster
  4. when a pixel is added to a cluster, adjust the cluster center color value by adding the pixel's color values to it, and remove the pixel's color values from the old cluster
  5. loop back to 2. until no pixels change clusters
The algorithm works like magic! By adding and removing pixels, the clusters automatically move close to the optimal solution (but the result may be only a local optimum, not the optimal solution).
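
The loop above can be sketched on plain gray levels (0-255) instead of RGB pixels, which keeps the example short. This is only a minimal sketch under my own assumptions: the class name GrayKMeansSketch and the method cluster are hypothetical, the initial centers are passed in by the caller, and the centers are recomputed once per pass rather than after every single move.

```java
import java.util.Arrays;

// Minimal sketch of the K-means loop on gray levels (0-255).
// Class and method names are hypothetical, chosen for this example only.
class GrayKMeansSketch {

    // Returns the final cluster centers; 'centers' holds the initial guesses.
    static int[] cluster(int[] pixels, int[] centers) {
        int[] c = centers.clone();
        int[] assigned = new int[pixels.length];
        Arrays.fill(assigned, -1);
        boolean changed = true;
        while (changed) {
            changed = false;
            // step 2: find the nearest center for each pixel
            for (int i = 0; i < pixels.length; i++) {
                int best = 0;
                for (int j = 1; j < c.length; j++) {
                    if (Math.abs(pixels[i] - c[j]) < Math.abs(pixels[i] - c[best])) {
                        best = j;
                    }
                }
                // step 3: remember whether any pixel moved to another cluster
                if (assigned[i] != best) {
                    assigned[i] = best;
                    changed = true;
                }
            }
            // step 4: recompute each center as the mean of its pixels
            long[] sum = new long[c.length];
            int[] count = new int[c.length];
            for (int i = 0; i < pixels.length; i++) {
                sum[assigned[i]] += pixels[i];
                count[assigned[i]]++;
            }
            for (int j = 0; j < c.length; j++) {
                if (count[j] > 0) {
                    c[j] = (int) (sum[j] / count[j]);
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[] pixels = {10, 12, 14, 200, 210, 220};
        // prints [12, 210]
        System.out.println(Arrays.toString(cluster(pixels, new int[] {0, 255})));
    }
}
```

With the two starting centers 0 and 255, the dark pixels gather around 12 and the bright ones around 210, and the second pass changes nothing, so the loop stops.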

K-means cluster centers in digital images 

In the K-means algorithm you have to decide how many clusters you will have in your image. In the simplest case the cluster count is the color count and the cluster centers are the color values. In other words, if you have an RGB image with millions of colors, after K-means clustering with the value 20 you will have the image converted to a version that has only 20 colors.

Take a look at the histograms below. After choosing the value 20 and running the K-means algorithm on the image, the colors are reduced to 20. The second histogram shows the result.

Histogram from original image
Histogram after K-means algorithm

In my implementation, I always choose the cluster centers in the same way, so the result for the same original image is always the same. But if you choose the cluster centers randomly, you will end up with slightly different results from run to run.
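
If you want to experiment with random cluster centers, one possible approach is to pick the colors of k random pixels. The sketch below is not part of my implementation: the class RandomCentersSketch, the method pickCenters and the seed parameter are assumptions made for this example. Passing a fixed seed keeps the runs repeatable.

```java
import java.awt.image.BufferedImage;
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper: choose initial cluster centers from random pixels.
class RandomCentersSketch {

    // Returns k packed RGB values sampled from random positions in the image.
    // The same seed always yields the same centers for the same image.
    static int[] pickCenters(BufferedImage image, int k, long seed) {
        Random random = new Random(seed);
        int[] centers = new int[k];
        for (int i = 0; i < k; i++) {
            int x = random.nextInt(image.getWidth());
            int y = random.nextInt(image.getHeight());
            centers[i] = image.getRGB(x, y);
        }
        return centers;
    }

    public static void main(String[] args) {
        BufferedImage image = new BufferedImage(4, 4, BufferedImage.TYPE_INT_RGB);
        image.setRGB(1, 1, 0x112233);
        int[] first = pickCenters(image, 3, 42L);
        int[] second = pickCenters(image, 3, 42L);
        // same seed, same centers
        System.out.println(Arrays.equals(first, second));
    }
}
```

Note that two random positions may hold the same color, in which case two clusters start with identical centers.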

Distance function when comparing clusters

If you choose to continue with the simplest case, where the cluster centers are simply color values, the comparison function is also simple. Just calculate the distance between the color of the pixel you are working with and the cluster center color.

When working with gray level images the distance calculation is easy, but for color images you need to split the RGB values into individual red, green and blue values. Then calculate the absolute difference between the pixel and the cluster center for each channel, and finally average the three differences.
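
As a small worked example of this distance, the sketch below compares the pixel (200, 100, 50) with the cluster center (180, 120, 60). The per-channel differences are 20, 20 and 10, so the averaged distance is 50 / 3 = 16 with integer division. The class name ColorDistanceSketch is hypothetical and used only for this illustration.

```java
// Hypothetical stand-alone version of the averaged per-channel distance.
class ColorDistanceSketch {

    // Average of the absolute red, green and blue differences
    // between two packed RGB colors (integer division).
    static int distance(int rgbA, int rgbB) {
        int dr = Math.abs(((rgbA >> 16) & 0xFF) - ((rgbB >> 16) & 0xFF));
        int dg = Math.abs(((rgbA >> 8) & 0xFF) - ((rgbB >> 8) & 0xFF));
        int db = Math.abs((rgbA & 0xFF) - (rgbB & 0xFF));
        return (dr + dg + db) / 3;
    }

    public static void main(String[] args) {
        int pixel = 0xC86432;  // red 200, green 100, blue 50
        int center = 0xB4783C; // red 180, green 120, blue 60
        // prints 16
        System.out.println(distance(pixel, center));
    }
}
```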

My implementation

I have implemented two different approaches for adding pixels to the clusters. I have named them "continuous" and "iterative" clustering.

In "continuous clustering" I add and remove pixels to and from clusters as each pixel is processed, and I also recalculate the cluster center after each added or removed pixel. This approach finds a solution faster than iterative clustering.

In "iterative clustering" I first assign all the pixels to clusters, and only after that do I calculate the new cluster centers.

There are small differences between the results of these two methods. I _guess_ the differences come from integer rounding. Another likely contributor is the update order: in continuous clustering the centers move during a pass, so later pixels are compared against already adjusted centers.
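
To illustrate the bookkeeping behind continuous clustering, here is a minimal sketch of a center that is kept up to date from a running sum and a pixel count, so adding or removing one pixel does not require revisiting the others. It works on a single gray channel; the class RunningCenterSketch is an assumption for this example, and the truncating integer division mirrors the kind of rounding mentioned above.

```java
// Hypothetical single-channel cluster center maintained incrementally.
class RunningCenterSketch {
    private long sum;
    private int count;

    // Adds one pixel value to the cluster.
    void add(int value) {
        sum += value;
        count++;
    }

    // Removes one previously added pixel value from the cluster.
    void remove(int value) {
        sum -= value;
        count--;
    }

    // Integer mean of the current members; 0 for an empty cluster.
    int center() {
        return count == 0 ? 0 : (int) (sum / count);
    }

    public static void main(String[] args) {
        RunningCenterSketch cluster = new RunningCenterSketch();
        cluster.add(10);
        cluster.add(20);
        cluster.add(40);
        // prints 23 (70 / 3, truncated)
        System.out.println(cluster.center());
        cluster.remove(10);
        // prints 30 (60 / 2)
        System.out.println(cluster.center());
    }
}
```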

Sample images

The original test image

Result of continuous clustering (took 6 loops, 816 milliseconds)
Result of iterative clustering (took 62 loops, 9158 milliseconds)

Java implementation


package popscan;

import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import javax.imageio.ImageIO;

public class KMeans {
    BufferedImage original; 
    BufferedImage result; 
    Cluster[] clusters; 
    public static final int MODE_CONTINUOUS = 1; 
    public static final int MODE_ITERATIVE = 2; 
     
     
    public static void main(String[] args) {
        if (args.length!=4) {
            System.out.println("Usage: java popscan.KMeans"
                        + " [source image filename]"
                        + " [destination image filename]"
                        + " [clustercount 1-255]"
                        + " [mode -i (ITERATIVE)|-c (CONTINUOUS)]");
            return;
        }
        // parse arguments
        String src = args[0];
        String dst = args[1];
        int k = Integer.parseInt(args[2]);
        String m = args[3];
        // continuous mode is the default if the flag is not recognized
        int mode = MODE_CONTINUOUS;
        if (m.equals("-i")) {
            mode = MODE_ITERATIVE;
        } else if (m.equals("-c")) {
            mode = MODE_CONTINUOUS;
        }

        // create new KMeans object
        KMeans kmeans = new KMeans();
        // call the function to actually start the clustering
        BufferedImage dstImage = kmeans.calculate(loadImage(src),
                                                    k,mode);
        // save the resulting image
        saveImage(dst, dstImage);
    }
     
    public KMeans() {    } 
     
    public BufferedImage calculate(BufferedImage image,
                                            int k, int mode) {
        long start = System.currentTimeMillis();
        int w = image.getWidth();
        int h = image.getHeight();
        // create clusters
        clusters = createClusters(image,k);
        // create cluster lookup table (pixel index -> cluster id)
        int[] lut = new int[w*h];
        Arrays.fill(lut, -1);
         
        // at first loop all pixels will move their clusters 
        boolean pixelChangedCluster = true; 
        // loop until all clusters are stable! 
        int loops = 0; 
        while (pixelChangedCluster) { 
            pixelChangedCluster = false; 
            loops++; 
            for (int y=0;y<h;y++) {
                for (int x=0;x<w;x++) {
                    int pixel = image.getRGB(x, y);
                    Cluster cluster = findMinimalCluster(pixel);
                    if (lut[w*y+x]!=cluster.getId()) {
                        // cluster changed
                        if (mode==MODE_CONTINUOUS) {
                            if (lut[w*y+x]!=-1) {
                                // remove from possible previous
                                // cluster
                                clusters[lut[w*y+x]].removePixel(
                                                            pixel);
                            }
                            // add pixel to cluster
                            cluster.addPixel(pixel);
                        }
                        // continue looping
                        pixelChangedCluster = true;

                        // update lut
                        lut[w*y+x] = cluster.getId();
                    }
                }
            }
            if (mode==MODE_ITERATIVE) {
                // update clusters: reset them and rebuild the
                // centers from the current assignments
                for (int i=0;i<clusters.length;i++) {
                    clusters[i].clear();
                }
                for (int y=0;y<h;y++) {
                    for (int x=0;x<w;x++) {
                        int clusterId = lut[w*y+x];
                        // add pixels to cluster
                        clusters[clusterId].addPixel(
                                            image.getRGB(x, y));
                    }
                }
            }
             
        } 
        // create result image 
        BufferedImage result = new BufferedImage(w, h,
                                    BufferedImage.TYPE_INT_RGB);
        for (int y=0;y<h;y++) {
            for (int x=0;x<w;x++) {
                int clusterId = lut[w*y+x];
                result.setRGB(x, y, clusters[clusterId].getRGB());
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("Clustered to "+k
                            + " clusters in "+loops
                            +" loops in "+(end-start)+" ms.");
        return result;
    }
     
    public Cluster[] createClusters(BufferedImage image, int k) {
        // Here the initial centers are sampled at fixed steps along
        // the image diagonal, so the result is always the same for
        // the same image.
        // You can randomize the cluster centers, if you like.
        Cluster[] result = new Cluster[k];
        int x = 0; int y = 0;
        int dx = image.getWidth()/k;
        int dy = image.getHeight()/k;
        for (int i=0;i<k;i++) {
            result[i] = new Cluster(i,image.getRGB(x, y));
            x+=dx; y+=dy;
        }
        return result;
    }
     
    public Cluster findMinimalCluster(int rgb) {
        // return the cluster whose center is closest to the color
        Cluster cluster = null;
        int min = Integer.MAX_VALUE;
        for (int i=0;i<clusters.length;i++) {
            int distance = clusters[i].distance(rgb);
            if (distance<min) {
                min = distance;
                cluster = clusters[i];
            }
        }
        return cluster;
    }
     
    public static void saveImage(String filename,
            BufferedImage image) {
        File file = new File(filename);
        try {
            ImageIO.write(image, "png", file);
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"+filename
                                +"' saving failed.");
        }
    }

    public static BufferedImage loadImage(String filename) {
        BufferedImage result = null;
        try {
            result = ImageIO.read(new File(filename));
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"
                                +filename+"' not found.");
        }
        return result;
    }
     
    class Cluster { 
        int id; 
        int pixelCount; 
        int red; 
        int green; 
        int blue; 
        int reds; 
        int greens; 
        int blues; 
         
        public Cluster(int id, int rgb) { 
            int r = rgb>>16&0x000000FF;  
            int g = rgb>> 8&0x000000FF;  
            int b = rgb>> 0&0x000000FF;  
            red = r; 
            green = g; 
            blue = b; 
            this.id = id; 
            addPixel(rgb);
        } 
         
        public void clear() { 
            red = 0; 
            green = 0; 
            blue = 0; 
            reds = 0; 
            greens = 0; 
            blues = 0; 
            pixelCount = 0; 
        } 
         
        int getId() { 
            return id; 
        } 
         
        int getRGB() { 
            int r = reds / pixelCount; 
            int g = greens / pixelCount; 
            int b = blues / pixelCount; 
            return 0xff000000|r<<16|g<<8|b; 
        } 
        void addPixel(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            reds+=r; 
            greens+=g; 
            blues+=b; 
            pixelCount++; 
            red   = reds/pixelCount; 
            green = greens/pixelCount; 
            blue  = blues/pixelCount; 
        } 
         
        void removePixel(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            reds-=r; 
            greens-=g; 
            blues-=b; 
            pixelCount--; 
            red   = reds/pixelCount; 
            green = greens/pixelCount; 
            blue  = blues/pixelCount; 
        } 
         
        int distance(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            int rx = Math.abs(red-r);
            int gx = Math.abs(green-g);
            int bx = Math.abs(blue-b);
            int d = (rx+gx+bx) / 3; 
            return d; 
        } 
    } 
     
} 


Check out my blog post about watershed segmentation: http://popscan.blogspot.fi/2014/04/watershed-image-segmentation-algorithm.html

Saturday, June 8, 2013

Color image segmentation by thresholding with Java

You may need an image segmentation tool for various image analysis tasks. Different tasks may also need a specialized segmentation algorithm: for example, the segmentation process may discard all green pixels and keep the yellow and red ones, or keep only blue pixels if you know that the object you are trying to detect contains only blue.

It may be difficult to adjust one general segmentation algorithm for special cases, so you may need to tweak the segmentation code itself, not just the parameters.

I will present here one variation of the simplest segmentation algorithm, segmentation by thresholding, implemented in pure Java.

More sophisticated methods exist, and you should use them if your problem requires it.

In my opinion, this thresholding function gives you a nice starting point for experimenting with segmentation. From the command line, you can only adjust the threshold level (an int value 0 - 255), but in the source you can change the connectivity testing and balance the color thresholding if needed.

Example images:
The original photo
Segmented with threshold value 20
Segmented with threshold value 40
Segmented with threshold value 60
Segmented with threshold value 80
Segmented with threshold value 100
Segmented with threshold value 120


Feel free to copy the code and adapt it to your own source and frameworks.

Improvements:
I will make a version of the algorithm where the segment color is averaged from the pixels in the segment. In this version, the segment color is the color of the "seed" pixel.
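
As a rough idea of what the averaging could look like, the sketch below computes the mean color of a list of packed RGB pixels. It is not the planned implementation, only an assumption of one possible approach; the class AverageColorSketch and the method average are hypothetical names.

```java
// Hypothetical helper: average color of the pixels collected for a segment.
class AverageColorSketch {

    // Returns the per-channel mean of the given packed RGB pixels
    // as an opaque ARGB value. The array must not be empty.
    static int average(int[] rgbPixels) {
        if (rgbPixels.length == 0) {
            throw new IllegalArgumentException("segment has no pixels");
        }
        long r = 0, g = 0, b = 0;
        for (int p : rgbPixels) {
            r += (p >> 16) & 0xFF;
            g += (p >> 8) & 0xFF;
            b += p & 0xFF;
        }
        int n = rgbPixels.length;
        return 0xFF000000
                | ((int) (r / n) << 16)
                | ((int) (g / n) << 8)
                | (int) (b / n);
    }

    public static void main(String[] args) {
        int[] segment = {0x102030, 0x304050};
        // prints ff203040
        System.out.println(Integer.toHexString(average(segment)));
    }
}
```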

The algorithm in pseudo code:

for each pixel in image
    if pixel is not in a segment
        create new segment
        add the pixel as new "seed" point to the list of candidates
        while the list has candidate points
            remove first point from the list of candidates
            if the candidate is not yet in a segment and is within threshold limit of the seed
                add the candidate point to the segment
                add neighbor pixel above to the candidate list
                add neighbor pixel below to the candidate list
                add neighbor pixel right to the candidate list
                add neighbor pixel left  to the candidate list
                (the Java implementation also adds the four diagonal neighbors)


This kind of algorithm must use a candidate list and a while loop instead of recursion, because in the worst case all the pixels of the image belong to one segment. In that case, when using recursion, you will most likely see a "StackOverflowError" with big images.
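
The candidate-list idea can be shown in isolation with a small gray-level grid. The sketch below is a simplified illustration, not the implementation in this post: it uses 4-connectivity, a java.util.ArrayDeque as the candidate queue, and the hypothetical names FloodFillSketch and flood.

```java
import java.util.ArrayDeque;

// Hypothetical queue-based flood fill on a gray-level grid (no recursion).
class FloodFillSketch {

    // Marks every pixel that is 4-connected to (sx, sy) and whose gray value
    // is within 'threshold' of the seed value. Returns the segment size.
    static int flood(int[][] gray, boolean[][] visited,
                     int sx, int sy, int threshold) {
        int h = gray.length;
        int w = gray[0].length;
        int seed = gray[sy][sx];
        ArrayDeque<int[]> candidates = new ArrayDeque<>();
        candidates.add(new int[] {sx, sy});
        int size = 0;
        while (!candidates.isEmpty()) {
            int[] p = candidates.poll();
            int x = p[0];
            int y = p[1];
            if (x < 0 || x >= w || y < 0 || y >= h) continue;      // outside image
            if (visited[y][x]) continue;                           // already in a segment
            if (Math.abs(gray[y][x] - seed) > threshold) continue; // too different
            visited[y][x] = true;
            size++;
            candidates.add(new int[] {x, y - 1}); // above
            candidates.add(new int[] {x, y + 1}); // below
            candidates.add(new int[] {x + 1, y}); // right
            candidates.add(new int[] {x - 1, y}); // left
        }
        return size;
    }

    public static void main(String[] args) {
        int[][] gray = {
            {10, 12, 90},
            {11, 95, 92},
            {13, 14, 91}
        };
        boolean[][] visited = new boolean[3][3];
        // prints 5 (the dark L-shaped region)
        System.out.println(flood(gray, visited, 0, 0, 5));
        // prints 4 (the bright region)
        System.out.println(flood(gray, visited, 2, 0, 5));
    }
}
```

The queue grows on the heap instead of the call stack, so even an image where every pixel ends up in one segment does not overflow the stack.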

The Java implementation:


package popscan;

import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import java.util.Vector;
import javax.imageio.ImageIO;

public class Segmentize {
    // value for visited pixel 
    int VISITED = 0x00FF0000;         
    // value for not visited pixel 
    int NOT_VISITED = 0x00000000;    
    // source image filename 
    String _srcFilename;     
    // destination image filename 
    String _dstFilename;     
     
    // the source image 
    BufferedImage _srcImage;     
    // the destination image 
    BufferedImage _dstImage;     
     
    // threshold value 0 - 255 
    int _threshold;         
    // image width 
    int _width;             
    // image height 
    int _height;         
    // "seed" color / segment color 
    int _color;             
    // red value from seed 
    int _red;             
    // green value from seed 
    int _green;             
    // blue value from seed 
    int _blue;             
     
    // pixels from source image 
    int[] _pixels;         
    // table for keeping track of visits
    int[] _visited;
    // list of candidate points
    Vector<SPoint> _points;
     
    class SPoint { 
        int x; 
        int y; 
        public SPoint(int x, int y) { 
            this.x = x; 
            this.y = y; 
        } 
    } 
     
    public static void main(String[] args) {
        if (args.length!=3) {
            System.out.println("Usage: java popscan.Segmentize"
                                + " [source image filename]"
                                + " [destination image filename]"
                                + " [threshold 0-255]");
            return;
        }
        // parse arguments
        String src = args[0];
        String dst = args[1];
        int threshold = Integer.parseInt(args[2]);

        // create new Segmentize object
        Segmentize s = new Segmentize(loadImage(src),threshold);
        // call the function to actually start the segmentation
        BufferedImage dstImage = s.segmentize();
        // save the resulting image
        saveImage(dst, dstImage);
    }
     
    public Segmentize(BufferedImage _srcImage, int threshold) {
        _threshold = threshold;
        _width       = _srcImage.getWidth();
        _height       = _srcImage.getHeight();
        // extract pixels from source image
        _pixels       = _srcImage.getRGB(0, 0, _width, _height,
                                        null, 0, _width);
        // create empty destination image
        _dstImage  = new BufferedImage(_width,
                                        _height,
                                        BufferedImage.TYPE_INT_RGB);
        _visited   = new int[_pixels.length];
        _points       = new Vector<SPoint>();
    }
     
    private BufferedImage segmentize() {
        // initialize points
        _points.clear();
        // clear table with NOT_VISITED value
        Arrays.fill(_visited, NOT_VISITED);
        // loop through all pixels
        for (int x=0;x<_width;x++) {
            for (int y=0;y<_height;y++) {
                // if not visited, start new segment
                if (_visited[_width*y+x]==NOT_VISITED) {
                    // extract segment color info from pixel
                    _color = _pixels[_width*y+x];
                    _red   = _color>>16&0xff;
                    _green = _color>>8&0xff;
                    _blue  = _color&0xff;
                    // add "seed"
                    _points.add(new SPoint(x, y));
                    // start finding neighboring pixels
                    flood();
                }
            }
        }
        // save the result image
        _dstImage.setRGB(0, 0, _width, _height, _pixels, 0, _width);
        return _dstImage;
    }
     
    public void flood() {
        // while there are candidates in points vector
        while (_points.size()>0) {
            // remove the first candidate
            SPoint current = _points.remove(0);
            int x = current.x;
            int y = current.y;
            // skip candidates that fall outside the image
            if ((x>=0)&&(x<_width)&&(y>=0)&&(y<_height)) {
                // check if the candidate is NOT_VISITED yet
                if (_visited[_width*y+x]==NOT_VISITED) {
                    // extract color info from candidate pixel
                    int _c = _pixels[_width*y+x];
                    int red   = _c>>16&0xff;
                    int green = _c>>8&0xff;
                    int blue  = _c>>0&0xff;
                    // calculate difference between
                    // seed's and candidate's
                    // red, green and blue values
                    int rx = Math.abs(red - _red);
                    int gx = Math.abs(green - _green);
                    int bx = Math.abs(blue - _blue);
                    // if all color differences are within threshold
                    if (rx<=_threshold
                            &&gx<=_threshold
                                &&bx<=_threshold) {
                        // add the candidate to the segment (image)
                        _pixels[_width*y+x] = _color;
                        // mark the candidate as visited
                        _visited[_width*y+x] = VISITED;
                        // add neighboring pixels as candidate
                        // (8-connected here)
                        _points.add(new SPoint(x-1,y-1));
                        _points.add(new SPoint(x  ,y-1));
                        _points.add(new SPoint(x+1,y-1));
                        _points.add(new SPoint(x-1,y));
                        _points.add(new SPoint(x+1,y));
                        _points.add(new SPoint(x-1,y+1));
                        _points.add(new SPoint(x  ,y+1));
                        _points.add(new SPoint(x+1,y+1));
                    }
                }
            }
        }
    }
     
    public static void saveImage(String filename,
            BufferedImage image) {
        File file = new File(filename);
        try {
            ImageIO.write(image, "png", file);
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"+filename
                                +"' saving failed.");
        }
    }

    public static BufferedImage loadImage(String filename) {
        BufferedImage result = null;
        try {
            result = ImageIO.read(new File(filename));
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"
                                +filename+"' not found.");
        }
        return result;
    }
     
}