| Line | Hits | Source |
|---|---|---|
| 1 | /* | |
| 2 | * Copyright (c) 2003, the JUNG Project and the Regents of the University | |
| 3 | * of California | |
| 4 | * All rights reserved. | |
| 5 | * | |
| 6 | * This software is open-source under the BSD license; see either | |
| 7 | * "license.txt" or | |
| 8 | * http://jung.sourceforge.net/license.txt for a description. | |
| 9 | */ | |
| 10 | /* | |
| 11 | * Created on Aug 9, 2004 | |
| 12 | * | |
| 13 | */ | |
| 14 | package edu.uci.ics.jung.algorithms.cluster; | |
| 15 | ||
| 16 | import java.util.Arrays; | |
| 17 | import java.util.Collection; | |
| 18 | import java.util.HashMap; | |
| 19 | import java.util.HashSet; | |
| 20 | import java.util.Iterator; | |
| 21 | import java.util.Map; | |
| 22 | import java.util.Set; | |
| 23 | ||
| 24 | import cern.jet.random.engine.DRand; | |
| 25 | import cern.jet.random.engine.RandomEngine; | |
| 26 | import edu.uci.ics.jung.statistics.DiscreteDistribution; | |
| 27 | ||
| 28 | ||
| 29 | ||
| 30 | /** | |
| 31 | * Groups Objects into a specified number of clusters, based on their | |
| 32 | * proximity in d-dimensional space, using the k-means algorithm. | |
| 33 | * | |
| 34 | * @author Joshua O'Madadhain | |
| 35 | */ | |
| 36 | public class KMeansClusterer | |
| 37 | { | |
| 38 | protected int max_iterations; | |
| 39 | protected double convergence_threshold; | |
| 40 | 1 | protected RandomEngine rand = new DRand(); |
| 41 | ||
| 42 | /** | |
| 43 | * Creates an instance for which calls to <code>cluster</code> will terminate | |
| 44 | * when either of the two following conditions is true: | |
| 45 | * <ul> | |
| 46 | * <li/>the number of iterations is > <code>max_iterations</code> | |
| 47 | * <li/>none of the centroids has moved as much as <code>convergence_threshold</code> | |
| 48 | * since the previous iteration | |
| 49 | * </ul> | |
| 50 | * @param max_iterations | |
| 51 | * @param convergence_threshold | |
| 52 | */ | |
| 53 | public KMeansClusterer(int max_iterations, double convergence_threshold) | |
| 54 | 1 | { |
| 55 | 1 | if (max_iterations < 0) |
| 56 | 0 | throw new IllegalArgumentException("max iterations must be >= 0"); |
| 57 | ||
| 58 | 1 | if (convergence_threshold <= 0) |
| 59 | 0 | throw new IllegalArgumentException("convergence threshold " + |
| 60 | "must be > 0"); | |
| 61 | ||
| 62 | 1 | this.max_iterations = max_iterations; |
| 63 | 1 | this.convergence_threshold = convergence_threshold; |
| 64 | 1 | } |
| 65 | ||
| 66 | /** | |
| 67 | * Returns a <code>Collection</code> of clusters, where each cluster is | |
| 68 | * represented as a <code>Map</code> of <code>Objects</code> to locations | |
| 69 | * in d-dimensional space. | |
| 70 | * @param object_locations a map of the Objects to cluster, to | |
| 71 | * <code>double</code> arrays that specify their locations in d-dimensional space. | |
| 72 | * @param num_clusters the number of clusters to create | |
| 73 | * @throws NotEnoughClustersException | |
| 74 | */ | |
| 75 | public Collection cluster(Map object_locations, int num_clusters) | |
| 76 | { | |
| 77 | 3 | if (num_clusters < 2 || num_clusters > object_locations.size()) |
| 78 | 1 | throw new IllegalArgumentException("number of clusters " + |
| 79 | "must be >= 2 and <= number of objects (" + | |
| 80 | object_locations.size() + ")"); | |
| 81 | ||
| 82 | 2 | if (object_locations == null || object_locations.isEmpty()) |
| 83 | 0 | throw new IllegalArgumentException("'objects' must be non-empty"); |
| 84 | ||
| 85 | 2 | Set centroids = new HashSet(); |
| 86 | 2 | Object[] obj_array = object_locations.keySet().toArray(); |
| 87 | 2 | Set tried = new HashSet(); |
| 88 | ||
| 89 | // create the specified number of clusters | |
| 90 | 12 | while (centroids.size() < num_clusters && tried.size() < object_locations.size()) |
| 91 | { | |
| 92 | 10 | Object o = obj_array[(int)(rand.nextDouble() * obj_array.length)]; |
| 93 | 10 | tried.add(o); |
| 94 | 10 | double[] mean_value = (double[])object_locations.get(o); |
| 95 | 10 | boolean duplicate = false; |
| 96 | 10 | for (Iterator iter = centroids.iterator(); iter.hasNext(); ) |
| 97 | { | |
| 98 | 9 | double[] cur = (double[])iter.next(); |
| 99 | 9 | if (Arrays.equals(mean_value, cur)) |
| 100 | 6 | duplicate = true; |
| 101 | } | |
| 102 | 10 | if (!duplicate) |
| 103 | 4 | centroids.add(mean_value); |
| 104 | } | |
| 105 | ||
| 106 | 2 | if (tried.size() >= object_locations.size()) |
| 107 | 1 | throw new NotEnoughClustersException(); |
| 108 | ||
| 109 | // put items in their initial clusters | |
| 110 | 1 | Map clusterMap = assignToClusters(object_locations, centroids); |
| 111 | ||
| 112 | // keep reconstituting clusters until either | |
| 113 | // (a) membership is stable, or | |
| 114 | // (b) number of iterations passes max_iterations, or | |
| 115 | // (c) max movement of any centroid is <= convergence_threshold | |
| 116 | 1 | int iterations = 0; |
| 117 | 1 | double max_movement = Double.POSITIVE_INFINITY; |
| 118 | 3 | while (iterations++ < max_iterations && max_movement > convergence_threshold) |
| 119 | { | |
| 120 | 2 | max_movement = 0; |
| 121 | 2 | Set new_centroids = new HashSet(); |
| 122 | // calculate new mean for each cluster | |
| 123 | 2 | for (Iterator iter = clusterMap.keySet().iterator(); iter.hasNext(); ) |
| 124 | { | |
| 125 | 4 | double[] centroid = (double[])iter.next(); |
| 126 | 4 | Map elements = (Map)clusterMap.get(centroid); |
| 127 | 4 | double[][] locations = new double[elements.size()][]; |
| 128 | 4 | int i = 0; |
| 129 | 4 | for (Iterator e_iter = elements.keySet().iterator(); e_iter.hasNext(); ) |
| 130 | 10 | locations[i++] = (double[])object_locations.get(e_iter.next()); |
| 131 | ||
| 132 | 4 | double[] mean = DiscreteDistribution.mean(locations); |
| 133 | 4 | max_movement = Math.max(max_movement, |
| 134 | Math.sqrt(DiscreteDistribution.squaredError(centroid, mean))); | |
| 135 | 4 | new_centroids.add(mean); |
| 136 | } | |
| 137 | ||
| 138 | // TODO: check membership of clusters: have they changed? | |
| 139 | ||
| 140 | // regenerate cluster membership based on means | |
| 141 | 2 | clusterMap = assignToClusters(object_locations, new_centroids); |
| 142 | } | |
| 143 | 1 | return (Collection)clusterMap.values(); |
| 144 | } | |
| 145 | ||
| 146 | /** | |
| 147 | * Assigns each object to the cluster whose centroid is closest to the | |
| 148 | * object. | |
| 149 | * @param object_locations a map of objects to locations | |
| 150 | * @param centroids the centroids of the clusters to be formed | |
| 151 | * @return a map of objects to assigned clusters | |
| 152 | */ | |
| 153 | protected Map assignToClusters(Map object_locations, Set centroids) | |
| 154 | { | |
| 155 | 3 | Map clusterMap = new HashMap(); |
| 156 | 3 | for (Iterator c_iter = centroids.iterator(); c_iter.hasNext(); ) |
| 157 | 6 | clusterMap.put(c_iter.next(), new HashMap()); |
| 158 | ||
| 159 | 3 | for (Iterator o_iter = object_locations.keySet().iterator(); o_iter.hasNext(); ) |
| 160 | { | |
| 161 | 15 | Object o = o_iter.next(); |
| 162 | 15 | double[] location = (double[])object_locations.get(o); |
| 163 | ||
| 164 | // find the cluster with the closest centroid | |
| 165 | 15 | Iterator c_iter = centroids.iterator(); |
| 166 | 15 | double[] closest = (double[])c_iter.next(); |
| 167 | 15 | double distance = DiscreteDistribution.squaredError(location, closest); |
| 168 | ||
| 169 | 30 | while (c_iter.hasNext()) |
| 170 | { | |
| 171 | 15 | double[] centroid = (double[])c_iter.next(); |
| 172 | 15 | double dist_cur = DiscreteDistribution.squaredError(location, centroid); |
| 173 | 15 | if (dist_cur < distance) |
| 174 | { | |
| 175 | 7 | distance = dist_cur; |
| 176 | 7 | closest = centroid; |
| 177 | } | |
| 178 | } | |
| 179 | 15 | Map elements = (Map)clusterMap.get(closest); |
| 180 | 15 | elements.put(o, location); |
| 181 | } | |
| 182 | ||
| 183 | 3 | return clusterMap; |
| 184 | } | |
| 185 | ||
| 186 | public void setSeed(int random_seed) | |
| 187 | { | |
| 188 | 0 | this.rand = new DRand(random_seed); |
| 189 | 0 | } |
| 190 | ||
| 191 | /** | |
| 192 | * An exception that indicates that the specified data points cannot be | |
| 193 | * clustered into the number of clusters requested by the user. | |
| 194 | * This will happen if and only if there are fewer distinct points than | |
| 195 | * requested clusters. (If there are fewer total data points than | |
| 196 | * requested clusters, <code>IllegalArgumentException</code> will be thrown.) | |
| 197 | * | |
| 198 | * @author Joshua O'Madadhain | |
| 199 | */ | |
| 200 | public static class NotEnoughClustersException extends RuntimeException | |
| 201 | { | |
| 202 | public String getMessage() | |
| 203 | { | |
| 204 | return "Not enough distinct points in the input data set to form " + | |
| 205 | "the requested number of clusters"; | |
| 206 | } | |
| 207 | } | |
| 208 | } |
|
this report was generated by version 1.0.5 of jcoverage. |
copyright © 2003, jcoverage ltd. all rights reserved. |