|
26 | 26 | import org.apache.commons.math4.exception.MathIllegalArgumentException; |
27 | 27 | import org.apache.commons.math4.exception.NumberIsTooSmallException; |
28 | 28 | import org.apache.commons.math4.exception.util.LocalizedFormats; |
| 29 | +import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer; |
| 30 | +import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer; |
29 | 31 | import org.apache.commons.math4.ml.distance.DistanceMeasure; |
30 | 32 | import org.apache.commons.math4.ml.distance.EuclideanDistance; |
31 | 33 | import org.apache.commons.rng.simple.RandomSource; |
@@ -70,6 +72,9 @@ public enum EmptyClusterStrategy { |
70 | 72 | /** Selected strategy for empty clusters. */ |
71 | 73 | private final EmptyClusterStrategy emptyStrategy; |
72 | 74 |
|
| 75 | + /** Centroid initialization algorithm. */ |
| 76 | + private final CentroidInitializer centroidInitializer; |
| 77 | + |
73 | 78 | /** Build a clusterer. |
74 | 79 | * <p> |
75 | 80 | * The default strategy for handling empty clusters that may appear during |
@@ -148,6 +153,8 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations, |
148 | 153 | this.maxIterations = maxIterations; |
149 | 154 | this.random = random; |
150 | 155 | this.emptyStrategy = emptyStrategy; |
| 156 | + // KMeansPlusPlusClusterer always seeds its centroids with the k-means++ algorithm. |
| 157 | + this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure,random); |
151 | 158 | } |
152 | 159 |
|
153 | 160 | /** |
@@ -205,7 +212,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points) |
205 | 212 | } |
206 | 213 |
|
207 | 214 | // create the initial clusters |
208 | | - List<CentroidCluster<T>> clusters = chooseInitialCenters(points); |
| 215 | + List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k); |
209 | 216 |
|
210 | 217 | // create an array containing the latest assignment of a point to a cluster |
211 | 218 | // no need to initialize the array, as it will be filled with the first assignment |
@@ -235,7 +242,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points) |
235 | 242 | } |
236 | 243 | emptyCluster = true; |
237 | 244 | } else { |
238 | | - newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); |
| 245 | + newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length); |
239 | 246 | } |
240 | 247 | newClusters.add(new CentroidCluster<T>(newCenter)); |
241 | 248 | } |
@@ -278,131 +285,6 @@ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters, |
278 | 285 | return assignedDifferently; |
279 | 286 | } |
280 | 287 |
|
281 | | - /** |
282 | | - * Use K-means++ to choose the initial centers. |
283 | | - * |
284 | | - * @param points the points to choose the initial centers from |
285 | | - * @return the initial centers |
286 | | - */ |
287 | | - private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) { |
288 | | - |
289 | | - // Convert to list for indexed access. Make it unmodifiable, since removal of items |
290 | | - // would screw up the logic of this method. |
291 | | - final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points)); |
292 | | - |
293 | | - // The number of points in the list. |
294 | | - final int numPoints = pointList.size(); |
295 | | - |
296 | | - // Set the corresponding element in this array to indicate when |
297 | | - // elements of pointList are no longer available. |
298 | | - final boolean[] taken = new boolean[numPoints]; |
299 | | - |
300 | | - // The resulting list of initial centers. |
301 | | - final List<CentroidCluster<T>> resultSet = new ArrayList<>(); |
302 | | - |
303 | | - // Choose one center uniformly at random from among the data points. |
304 | | - final int firstPointIndex = random.nextInt(numPoints); |
305 | | - |
306 | | - final T firstPoint = pointList.get(firstPointIndex); |
307 | | - |
308 | | - resultSet.add(new CentroidCluster<T>(firstPoint)); |
309 | | - |
310 | | - // Must mark it as taken |
311 | | - taken[firstPointIndex] = true; |
312 | | - |
313 | | - // To keep track of the minimum distance squared of elements of |
314 | | - // pointList to elements of resultSet. |
315 | | - final double[] minDistSquared = new double[numPoints]; |
316 | | - |
317 | | - // Initialize the elements. Since the only point in resultSet is firstPoint, |
318 | | - // this is very easy. |
319 | | - for (int i = 0; i < numPoints; i++) { |
320 | | - if (i != firstPointIndex) { // That point isn't considered |
321 | | - double d = distance(firstPoint, pointList.get(i)); |
322 | | - minDistSquared[i] = d*d; |
323 | | - } |
324 | | - } |
325 | | - |
326 | | - while (resultSet.size() < k) { |
327 | | - |
328 | | - // Sum up the squared distances for the points in pointList not |
329 | | - // already taken. |
330 | | - double distSqSum = 0.0; |
331 | | - |
332 | | - for (int i = 0; i < numPoints; i++) { |
333 | | - if (!taken[i]) { |
334 | | - distSqSum += minDistSquared[i]; |
335 | | - } |
336 | | - } |
337 | | - |
338 | | - // Add one new data point as a center. Each point x is chosen with |
339 | | - // probability proportional to D(x)^2 (squared distance to the nearest chosen center) |
340 | | - final double r = random.nextDouble() * distSqSum; |
341 | | - |
342 | | - // The index of the next point to be added to the resultSet. |
343 | | - int nextPointIndex = -1; |
344 | | - |
345 | | - // Sum through the squared min distances again, stopping when |
346 | | - // sum >= r. |
347 | | - double sum = 0.0; |
348 | | - for (int i = 0; i < numPoints; i++) { |
349 | | - if (!taken[i]) { |
350 | | - sum += minDistSquared[i]; |
351 | | - if (sum >= r) { |
352 | | - nextPointIndex = i; |
353 | | - break; |
354 | | - } |
355 | | - } |
356 | | - } |
357 | | - |
358 | | - // If it's not set to >= 0, the point wasn't found in the previous |
359 | | - // for loop, probably because distances are extremely small. Just pick |
360 | | - // the last available point. |
361 | | - if (nextPointIndex == -1) { |
362 | | - for (int i = numPoints - 1; i >= 0; i--) { |
363 | | - if (!taken[i]) { |
364 | | - nextPointIndex = i; |
365 | | - break; |
366 | | - } |
367 | | - } |
368 | | - } |
369 | | - |
370 | | - // We found one. |
371 | | - if (nextPointIndex >= 0) { |
372 | | - |
373 | | - final T p = pointList.get(nextPointIndex); |
374 | | - |
375 | | - resultSet.add(new CentroidCluster<T> (p)); |
376 | | - |
377 | | - // Mark it as taken. |
378 | | - taken[nextPointIndex] = true; |
379 | | - |
380 | | - if (resultSet.size() < k) { |
381 | | - // Now update elements of minDistSquared. We only have to compute |
382 | | - // the distance to the new center to do this. |
383 | | - for (int j = 0; j < numPoints; j++) { |
384 | | - // Only have to worry about the points still not taken. |
385 | | - if (!taken[j]) { |
386 | | - double d = distance(p, pointList.get(j)); |
387 | | - double d2 = d * d; |
388 | | - if (d2 < minDistSquared[j]) { |
389 | | - minDistSquared[j] = d2; |
390 | | - } |
391 | | - } |
392 | | - } |
393 | | - } |
394 | | - |
395 | | - } else { |
396 | | - // None found -- |
397 | | - // Break from the while loop to prevent |
398 | | - // an infinite loop. |
399 | | - break; |
400 | | - } |
401 | | - } |
402 | | - |
403 | | - return resultSet; |
404 | | - } |
405 | | - |
406 | 288 | /** |
407 | 289 | * Get a random point from the {@link Cluster} with the largest distance variance. |
408 | 290 | * |
@@ -540,26 +422,4 @@ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, fin |
540 | 422 | } |
541 | 423 | return minCluster; |
542 | 424 | } |
543 | | - |
544 | | - /** |
545 | | - * Computes the centroid for a set of points. |
546 | | - * |
547 | | - * @param points the set of points |
548 | | - * @param dimension the point dimension |
549 | | - * @return the computed centroid for the set of points |
550 | | - */ |
551 | | - private Clusterable centroidOf(final Collection<T> points, final int dimension) { |
552 | | - final double[] centroid = new double[dimension]; |
553 | | - for (final T p : points) { |
554 | | - final double[] point = p.getPoint(); |
555 | | - for (int i = 0; i < centroid.length; i++) { |
556 | | - centroid[i] += point[i]; |
557 | | - } |
558 | | - } |
559 | | - for (int i = 0; i < centroid.length; i++) { |
560 | | - centroid[i] /= points.size(); |
561 | | - } |
562 | | - return new DoublePoint(centroid); |
563 | | - } |
564 | | - |
565 | 425 | } |
0 commit comments