Sunday, June 16, 2013

Image segmentation with K-means algorithm - Java implementation

The K-means algorithm is a well-known clustering algorithm, and it is described in many papers and online texts. But even after reading many of these documents, I was still confused and had many questions in my mind.

I was wondering about questions like:
  • What are the K cluster centers in the case of a digital image?
  • What is the distance function used to find the closest cluster for a pixel?
I left these questions in my mind for a few days, and finally one morning I realized what the K-means algorithm means in the context of digital image segmentation. Segmentation in this case means "finding the continuous areas where the color and intensity are similar to those of adjacent pixels".

I will try to explain the questions above in this post.

K-means algorithm for digital images in a nutshell
  1. create k clusters, k being the cluster count and each cluster "center" being a color value
  2. for each pixel, find the cluster whose center has the minimal distance to the pixel
  3. if the pixel was already in another cluster, remove it from that cluster and add it to the new cluster
  4. when a pixel is added to a cluster, adjust the cluster center color by adding the pixel's color values to it, and remove the pixel's color values from the old cluster
  5. loop back to 2 until no pixels change clusters
The algorithm works like magic! By adding and removing pixels, the clusters automatically move close to the optimal solution (but the result may be a local optimum rather than the optimal solution).
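The steps above can be sketched in a few lines. This is only a minimal sketch under simplifying assumptions, not the implementation from this post: it works on plain gray values instead of RGB colors, keeps the centers as doubles, and the names `KMeansStepsSketch`, `cluster` and `maxLoops` are hypothetical. The `maxLoops` cap is a safety limit added for the sketch.

```java
import java.util.Arrays;

class KMeansStepsSketch {

    /** Returns, for each gray value, the index of the cluster it ended up in. */
    static int[] cluster(int[] pixels, double[] initialCenters, int maxLoops) {
        int k = initialCenters.length;
        double[] centers = initialCenters.clone();   // step 1: k centers
        long[] sums = new long[k];                   // sum of values per cluster
        int[] counts = new int[k];                   // pixel count per cluster
        int[] assignment = new int[pixels.length];
        Arrays.fill(assignment, -1);                 // -1 = not in any cluster yet

        boolean changed = true;
        for (int loop = 0; changed && loop < maxLoops; loop++) {   // step 5
            changed = false;
            for (int i = 0; i < pixels.length; i++) {
                // step 2: find the cluster with the minimal distance
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (Math.abs(pixels[i] - centers[c])
                            < Math.abs(pixels[i] - centers[best])) {
                        best = c;
                    }
                }
                int old = assignment[i];
                if (old == best) {
                    continue;
                }
                // step 3 and 4: remove from the old cluster, add to the new one
                if (old != -1) {
                    sums[old] -= pixels[i];
                    counts[old]--;
                    if (counts[old] > 0) {
                        centers[old] = (double) sums[old] / counts[old];
                    }
                }
                sums[best] += pixels[i];
                counts[best]++;
                centers[best] = (double) sums[best] / counts[best];
                assignment[i] = best;
                changed = true;
            }
        }
        return assignment;
    }

    public static void main(String[] args) {
        int[] gray = {10, 12, 11, 200, 202, 198};
        int[] result = cluster(gray, new double[] {10, 200}, 100);
        System.out.println(Arrays.toString(result)); // prints [0, 0, 0, 1, 1, 1]
    }
}
```

The dark values end up in cluster 0 and the bright values in cluster 1, and the second loop finds nothing left to move, so the loop stops.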

K-means cluster centers in digital images 

In the K-means algorithm you have to decide in advance how many clusters you will have in your image. The simplest case is that the cluster count is the color count and the cluster centers are the color values. In other words, if you have an RGB image with millions of colors, after K-means clustering with k = 20 you will have the image converted to a version which has at most 20 colors.
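One way to check the "only k colors" claim on a result image is to count its distinct colors. The following is a small sketch for that check; the class and method names (`ColorCountSketch`, `countDistinctColors`) are hypothetical and not part of the implementation in this post.

```java
import java.awt.image.BufferedImage;
import java.util.HashSet;
import java.util.Set;

class ColorCountSketch {

    /** Counts the distinct RGB colors in the image, ignoring alpha. */
    static int countDistinctColors(BufferedImage image) {
        Set<Integer> colors = new HashSet<>();
        for (int y = 0; y < image.getHeight(); y++) {
            for (int x = 0; x < image.getWidth(); x++) {
                // keep only the red, green and blue bits
                colors.add(image.getRGB(x, y) & 0x00FFFFFF);
            }
        }
        return colors.size();
    }

    public static void main(String[] args) {
        BufferedImage image = new BufferedImage(2, 2, BufferedImage.TYPE_INT_RGB);
        image.setRGB(0, 0, 0xFF0000);
        image.setRGB(1, 0, 0xFF0000);
        image.setRGB(0, 1, 0x00FF00);
        image.setRGB(1, 1, 0x0000FF);
        System.out.println(countDistinctColors(image)); // prints 3
    }
}
```

Run on the output of a clustering with k = 20, the count should be 20 or less, since two clusters may end up with the same center color.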

Take a look at the histograms below. The first one is from the original image. After choosing the value 20 and running the K-means algorithm on the image, the colors are reduced to 20; the second histogram shows the result.

Histogram from original image
Histogram after K-means algorithm
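The post does not say how the histograms were produced. As an illustration only, here is a sketch that assumes a simple gray-level histogram, where the gray level of a pixel is the average of its red, green and blue values. The names `GrayHistogramSketch` and `histogram` are hypothetical.

```java
import java.awt.image.BufferedImage;

class GrayHistogramSketch {

    /** Returns 256 bins; bins[g] is the number of pixels with gray level g. */
    static int[] histogram(BufferedImage image) {
        int[] bins = new int[256];
        for (int y = 0; y < image.getHeight(); y++) {
            for (int x = 0; x < image.getWidth(); x++) {
                int rgb = image.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF;
                int g = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                // assumption: gray level = plain average of the channels
                bins[(r + g + b) / 3]++;
            }
        }
        return bins;
    }

    public static void main(String[] args) {
        BufferedImage image = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        image.setRGB(0, 0, 0x000000);
        image.setRGB(1, 0, 0xFFFFFF);
        int[] bins = histogram(image);
        System.out.println(bins[0] + " " + bins[255]); // prints 1 1
    }
}
```

After clustering, such a histogram has only a few non-empty bins, because the image contains only the cluster center colors.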

In my implementation, I always choose the initial cluster centers in the same way, so the result for the same original image is always the same. But if you choose the cluster centers randomly, you may end up with slightly different results on each run.
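If you want to try random cluster centers, one option is to pick the colors of k random pixels. The sketch below is an assumption about how that could be done, not code from my implementation; `RandomCentersSketch`, `pickCenters` and the `seed` parameter are hypothetical names. Passing a fixed seed keeps the runs repeatable.

```java
import java.awt.image.BufferedImage;
import java.util.Random;

class RandomCentersSketch {

    /** Picks the colors of k randomly chosen pixels as initial centers. */
    static int[] pickCenters(BufferedImage image, int k, long seed) {
        Random random = new Random(seed);   // same seed, same centers
        int[] centers = new int[k];
        for (int i = 0; i < k; i++) {
            int x = random.nextInt(image.getWidth());
            int y = random.nextInt(image.getHeight());
            // note: two picks may return the same color
            centers[i] = image.getRGB(x, y);
        }
        return centers;
    }

    public static void main(String[] args) {
        BufferedImage image = new BufferedImage(4, 4, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < 4; y++) {
            for (int x = 0; x < 4; x++) {
                image.setRGB(x, y, (x * 60) << 16 | (y * 60) << 8);
            }
        }
        for (int center : pickCenters(image, 3, 42L)) {
            System.out.println(Integer.toHexString(center));
        }
    }
}
```

The sketch does not prevent duplicate centers; a duplicate simply leaves one cluster competing with an identical twin.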

Distance function when comparing clusters

If you choose to continue with the simplest case, where cluster centers are simply color values, the comparison function is also simple: just calculate the distance between the color of the pixel you are working with and the cluster center color.

When working with gray level images the distance calculation is easy, but for color images you need to split the RGB value into individual red, green and blue values. Then calculate the absolute difference between the pixel and the cluster center for each channel, and finally average the three differences.
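As a small standalone sketch of the two cases described above (the class and method names `ColorDistanceSketch`, `grayDistance` and `rgbDistance` are hypothetical; the color case follows the same average-of-absolute-differences idea as the `distance` method in the listing further down):

```java
class ColorDistanceSketch {

    /** Gray level case: the distance is just the absolute difference. */
    static int grayDistance(int grayA, int grayB) {
        return Math.abs(grayA - grayB);
    }

    /** Color case: average of the absolute per-channel differences. */
    static int rgbDistance(int rgbA, int rgbB) {
        int dr = Math.abs(((rgbA >> 16) & 0xFF) - ((rgbB >> 16) & 0xFF));
        int dg = Math.abs(((rgbA >> 8) & 0xFF) - ((rgbB >> 8) & 0xFF));
        int db = Math.abs((rgbA & 0xFF) - (rgbB & 0xFF));
        return (dr + dg + db) / 3;   // integer division, as in the listing
    }

    public static void main(String[] args) {
        System.out.println(grayDistance(40, 100));            // prints 60
        System.out.println(rgbDistance(0x102030, 0x132231));  // prints 2
        System.out.println(rgbDistance(0xFF0000, 0x000000));  // prints 85
    }
}
```

In the second call the channel differences are 3, 2 and 1, so the average is 2.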

My implementation

I have implemented two different approaches for adding pixels to the clusters. I have named them "continuous" and "iterative" clustering.

In "continuous clustering" I add and remove pixels to and from clusters one pixel at a time when necessary, and I recalculate the cluster center after each added or removed pixel. With my test image this approach found a solution faster than iterative clustering (see the sample images below).

In "iterative clustering" I first assign all the pixels to clusters, and only after that do I calculate the new cluster centers.

There are small differences between the results of these two methods. I _guess_ the differences come from integer rounding.
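Besides rounding, one thing that clearly differs between the two methods is when the centers move: in continuous clustering a center changes in the middle of a pass, so later pixels are compared against different centers than in iterative clustering. The sketch below illustrates this for a single pass over gray values. It is a simplified illustration with hypothetical names (`UpdateOrderSketch`, `onePass`), not a test of my implementation, and it does not show how much of the difference in the final images comes from this effect.

```java
import java.util.Arrays;

class UpdateOrderSketch {

    static int nearest(int value, int[] centers) {
        int best = 0;
        for (int c = 1; c < centers.length; c++) {
            if (Math.abs(value - centers[c]) < Math.abs(value - centers[best])) {
                best = c;
            }
        }
        return best;
    }

    /** One pass over the pixels; centers move during the pass only if requested. */
    static int[] onePass(int[] pixels, int[] startCenters, boolean updateImmediately) {
        int[] centers = startCenters.clone();
        int[] sums = new int[centers.length];
        int[] counts = new int[centers.length];
        int[] assignment = new int[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            int c = nearest(pixels[i], centers);
            assignment[i] = c;
            sums[c] += pixels[i];
            counts[c]++;
            if (updateImmediately) {
                centers[c] = sums[c] / counts[c];   // integer mean, "continuous" style
            }
        }
        return assignment;
    }

    public static void main(String[] args) {
        int[] pixels = {0, 40, 100, 58};
        int[] centers = {0, 100};
        // "continuous": prints [0, 0, 1, 0]
        System.out.println(Arrays.toString(onePass(pixels, centers, true)));
        // "iterative": prints [0, 0, 1, 1]
        System.out.println(Arrays.toString(onePass(pixels, centers, false)));
    }
}
```

In the continuous pass the first center has already moved from 0 to 20 when the value 58 is examined, so 58 is closer to it (38) than to 100 (42). In the iterative pass the center is still 0, and 58 goes to the other cluster.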

Sample images

The original test image

Result of continuous clustering (took 6 loops, 816 milliseconds)
Result of iterative clustering (took 62 loops, 9158 milliseconds)

Java implementation


package popscan; 
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import javax.imageio.ImageIO;
public class KMeans { 
    BufferedImage original; 
    BufferedImage result; 
    Cluster[] clusters; 
    public static final int MODE_CONTINUOUS = 1; 
    public static final int MODE_ITERATIVE = 2; 
     
     
    public static void main(String[] args) { 
        if (args.length!=4) { 
            System.out.println("Usage: java popscan.KMeans"
                        + " [source image filename]"
                        + " [destination image filename]"
                        + " [clustercount 0-255]"
                        + " [mode -i (ITERATIVE)|-c (CONTINUOUS)]");
            return;
        }
        // parse arguments
        String src = args[0];
        String dst = args[1];
        int k = Integer.parseInt(args[2]);
        String m = args[3];
        // continuous mode is used if the mode flag is not recognized
        int mode = MODE_CONTINUOUS;
        if (m.equals("-i")) {
            mode = MODE_ITERATIVE;
        } else if (m.equals("-c")) {
            mode = MODE_CONTINUOUS;
        }

        // load the source image; loadImage returns null on failure
        BufferedImage srcImage = loadImage(src);
        if (srcImage == null) {
            System.out.println("Image '"+src+"' could not be loaded.");
            return;
        }
        // create new KMeans object
        KMeans kmeans = new KMeans();
        // call the function to actually start the clustering
        BufferedImage dstImage = kmeans.calculate(srcImage, k, mode);
        // save the resulting image
        saveImage(dst, dstImage);
    } 
     
    public KMeans() {    } 
     
    public BufferedImage calculate(BufferedImage image,  
                                            int k, int mode) { 
        long start = System.currentTimeMillis();
        int w = image.getWidth();
        int h = image.getHeight();
        // create clusters
        clusters = createClusters(image,k);
        // create cluster lookup table: cluster id for each pixel,
        // -1 means the pixel is not in any cluster yet
        int[] lut = new int[w*h];
        Arrays.fill(lut, -1);
         
        // at first loop all pixels will move their clusters 
        boolean pixelChangedCluster = true; 
        // loop until all clusters are stable! 
        int loops = 0; 
        while (pixelChangedCluster) { 
            pixelChangedCluster = false; 
            loops++; 
            for (int y=0;y<h;y++) { 
                for (int x=0;x<w;x++) { 
                    int pixel = image.getRGB(x, y);
                    Cluster cluster = findMinimalCluster(pixel);
                    if (lut[w*y+x]!=cluster.getId()) {
                        // cluster changed
                        if (mode==MODE_CONTINUOUS) {
                            if (lut[w*y+x]!=-1) {
                                // remove from possible previous
                                // cluster
                                clusters[lut[w*y+x]].removePixel(
                                                            pixel);
                            }
                            // add pixel to cluster
                            cluster.addPixel(pixel);
                        }
                        // continue looping
                        pixelChangedCluster = true;

                        // update lut
                        lut[w*y+x] = cluster.getId();
                    } 
                } 
            } 
            if (mode==MODE_ITERATIVE) { 
                // update clusters 
                for (int i=0;i<clusters.length;i++) { 
                    clusters[i].clear();
                }
                for (int y=0;y<h;y++) {
                    for (int x=0;x<w;x++) {
                        int clusterId = lut[w*y+x];
                        // add pixels to cluster
                        clusters[clusterId].addPixel(
                                            image.getRGB(x, y));
                    } 
                } 
            } 
             
        } 
        // create result image 
        BufferedImage result = new BufferedImage(w, h,
                                    BufferedImage.TYPE_INT_RGB);
        for (int y=0;y<h;y++) {
            for (int x=0;x<w;x++) {
                // paint each pixel with the center color of its cluster
                int clusterId = lut[w*y+x];
                result.setRGB(x, y, clusters[clusterId].getRGB());
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("Clustered to "+k
                            + " clusters in "+loops
                            +" loops in "+(end-start)+" ms.");
        return result; 
    } 
     
    public Cluster[] createClusters(BufferedImage image, int k) { 
        // Here the clusters are taken with specific steps, 
        // so the result looks always same with same image. 
        // You can randomize the cluster centers, if you like. 
        Cluster[] result = new Cluster[k];
        int x = 0; int y = 0;
        int dx = image.getWidth()/k;
        int dy = image.getHeight()/k;
        for (int i=0;i<k;i++) {
            result[i] = new Cluster(i,image.getRGB(x, y));
            x+=dx; y+=dy; 
        } 
        return result; 
    } 
     
    public Cluster findMinimalCluster(int rgb) { 
        Cluster cluster = null; 
        int min = Integer.MAX_VALUE;
        for (int i=0;i<clusters.length;i++) {
            int distance = clusters[i].distance(rgb);
            if (distance<min) {
                min = distance;
                cluster = clusters[i];
            } 
        } 
        return cluster; 
    } 
     
    public static void saveImage(String filename,  
            BufferedImage image) { 
        File file = new File(filename);
        try {
            ImageIO.write(image, "png", file);
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"+filename
                                +"' saving failed.");
        } 
    } 
     
    public static BufferedImage loadImage(String filename) { 
        BufferedImage result = null; 
        try { 
            result = ImageIO.read(new File(filename));
        } catch (Exception e) {
            System.out.println(e.toString()+" Image '"
                                +filename+"' not found.");
        } 
        return result; 
    } 
     
    class Cluster { 
        int id; 
        int pixelCount; 
        int red; 
        int green; 
        int blue; 
        int reds; 
        int greens; 
        int blues; 
         
        public Cluster(int id, int rgb) { 
            int r = rgb>>16&0x000000FF;  
            int g = rgb>> 8&0x000000FF;  
            int b = rgb>> 0&0x000000FF;  
            red = r; 
            green = g; 
            blue = b; 
            this.id = id; 
            addPixel(rgb);
        } 
         
        public void clear() { 
            red = 0; 
            green = 0; 
            blue = 0; 
            reds = 0; 
            greens = 0; 
            blues = 0; 
            pixelCount = 0; 
        } 
         
        int getId() { 
            return id; 
        } 
         
        int getRGB() { 
            int r = reds / pixelCount; 
            int g = greens / pixelCount; 
            int b = blues / pixelCount; 
            return 0xff000000|r<<16|g<<8|b; 
        } 
        void addPixel(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            reds+=r; 
            greens+=g; 
            blues+=b; 
            pixelCount++; 
            red   = reds/pixelCount; 
            green = greens/pixelCount; 
            blue  = blues/pixelCount; 
        } 
         
        void removePixel(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            reds-=r; 
            greens-=g; 
            blues-=b; 
            pixelCount--; 
            red   = reds/pixelCount; 
            green = greens/pixelCount; 
            blue  = blues/pixelCount; 
        } 
         
        int distance(int color) { 
            int r = color>>16&0x000000FF;  
            int g = color>> 8&0x000000FF;  
            int b = color>> 0&0x000000FF;  
            int rx = Math.abs(red-r);
            int gx = Math.abs(green-g);
            int bx = Math.abs(blue-b);
            int d = (rx+gx+bx) / 3; 
            return d; 
        } 
    } 
     
} 


Check out my blogpost about watershed segmentation: http://popscan.blogspot.fi/2014/04/watershed-image-segmentation-algorithm.html

20 comments:

Anonymous said...

could not find or load main class kmeans

Jussi said...

Hi! Looks like I left the package declaration in the source listing... Check that you have the file in package "popscan" or remove the package declaration from the source.

Unknown said...

hi ...i want to compare 2 image (mouse authentication system )
i want use your code ....while writing out put i am getting blank black image plz help me out

Jussi said...

Hi! I just copy-pasted the code from the post (and removed the package declaration), compiled it and ran it: it worked as it should.

Please, in your code, try to output the image before segmenting it, just to make sure that the input image is as you expect it to be.

Anonymous said...

i tried this code but am not getting the desired result..
the output is the print line in the "main"....and when i omit the if statement from main it gives arrayIndexOutOfBoundException.

Unknown said...

Your program works like a charm. Mixing this with OpenCV is really good too! Honestly, really good work!

Jussi said...

Thank you very much for the feedback! It makes me happy to find out that someone finds this useful. :)

Anonymous said...

do not remove statement you have to pass the argument at run time ...

Anonymous said...

what are the input statements required can you tell me Jussi?

Anonymous said...

Dear sir, thank you so much for posting this. I have been trying to build a k-means image compressor for days for a school project. I'd like you to know that I did not copy and paste a single snippet of your code for this assignment, but you have enabled me to understand many of the finer points especially with regard to how the buffered image in an effort to build a program of my own, which i have been able to complete thanks to you. I appreciate your work.

Anonymous said...

What should be the name of the source file? i am using Eclipse to run the code....how do i obtain the result image?

Jussi said...

The source file name needs to be "KMeans.java".

The package in the source code is set to "popscan", so you need to place the code in a "popscan" directory. But if you remove that line, you can place the code in any folder.

You can set the destination file as a parameter. When using Eclipse, you need to set up the parameters in the "run configuration".

Anonymous said...

give me a example about this. Program isn't run....
" String src = args[0];
String dst = args[1];
int k = Integer.parseInt(args[2]);
String m = args[3];
int mode = 1;
if (m.equals("-i")) {
mode = MODE_ITERATIVE;
} else if (m.equals("-c")) {
mode = MODE_CONTINUOUS;
} "

Anonymous said...

Hello For my final project I have to study to Implement the k-means algorithm , and there I try to understand yours, please tell me where I have to put the path to the image that I want to test

Peter said...

Hi Jussi, it is working fine... Thanks a lot... very nice...

Unknown said...

Dear Sir

if (args.length!=4) {
System.out.println("Usage: java popscan.KMeans"
+ " [source image filename]"
+ " [destination image filename]"
+ " [clustercount 0-255]"
+ " [mode -i (ITERATIVE)|-c (CONTINUOS)]");
return;
}

only above code is executed.

String src = args[0];
String dst = args[1];
int k = Integer.parseInt(args[2]);

these code is not executed. Because first condition is true and it never goes to the rest part of the code. Please reply how to do this

Helly Patel said...

Hello while executing this program ubuntu java , I am getting error named "could not find or load main class KMeans".Can u pl help me out.

Do I need the package popscan?

Anonymous said...

Hi, please tell me where I have to put the path to the image that I want to test ?
when I executed it , it doesnt give any result :( please i need your help

Jussi said...

Here is how to run the code in Windows.
1. Copy and paste the code into a text file.
2. Remove the package declaration line.
3. Save the file, for example to c:/temp/KMeans.java
4. Open cmd.exe
5. Go to the folder c:/temp/ with command "cd c:/temp"
6. Compile the source with command "javac KMeans.java"
7. Copy your source image to the directory c:/temp
8. Execute the application with command "java KMeans source.png output.png 25 -i"
9. You will get a result like: "Clustered to 25 clusters in 90 loops in 9449 ms."

Anonymous said...

can I use this for indexing images in a CBIR system? does it work?