Skip to content

Commit eeb4110

Browse files
author
CT
committed
Remove duplicate code that was produced by MiniBatchKMeansClusterer.
1 parent 7288cc3 commit eeb4110

File tree

1 file changed

+9
-149
lines changed

1 file changed

+9
-149
lines changed

src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java

Lines changed: 9 additions & 149 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
import org.apache.commons.math4.exception.MathIllegalArgumentException;
2727
import org.apache.commons.math4.exception.NumberIsTooSmallException;
2828
import org.apache.commons.math4.exception.util.LocalizedFormats;
29+
import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
30+
import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
2931
import org.apache.commons.math4.ml.distance.DistanceMeasure;
3032
import org.apache.commons.math4.ml.distance.EuclideanDistance;
3133
import org.apache.commons.rng.simple.RandomSource;
@@ -70,6 +72,9 @@ public enum EmptyClusterStrategy {
7072
/** Selected strategy for empty clusters. */
7173
private final EmptyClusterStrategy emptyStrategy;
7274

75+
/** Centroid initialization algorithm. */
76+
private final CentroidInitializer centroidInitializer;
77+
7378
/** Build a clusterer.
7479
* <p>
7580
* The default strategy for handling empty clusters that may appear during
@@ -148,6 +153,8 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
148153
this.maxIterations = maxIterations;
149154
this.random = random;
150155
this.emptyStrategy = emptyStrategy;
156+
// For KMeansPlusPlusClusterer, the centroidInitializer uses the KMeans++ algorithm.
157+
this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure,random);
151158
}
152159

153160
/**
@@ -205,7 +212,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
205212
}
206213

207214
// create the initial clusters
208-
List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
215+
List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);
209216

210217
// create an array containing the latest assignment of a point to a cluster
211218
// no need to initialize the array, as it will be filled with the first assignment
@@ -235,7 +242,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
235242
}
236243
emptyCluster = true;
237244
} else {
238-
newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
245+
newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
239246
}
240247
newClusters.add(new CentroidCluster<T>(newCenter));
241248
}
@@ -278,131 +285,6 @@ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
278285
return assignedDifferently;
279286
}
280287

281-
/**
282-
* Use K-means++ to choose the initial centers.
283-
*
284-
* @param points the points to choose the initial centers from
285-
* @return the initial centers
286-
*/
287-
private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
288-
289-
// Convert to list for indexed access. Make it unmodifiable, since removal of items
290-
// would screw up the logic of this method.
291-
final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));
292-
293-
// The number of points in the list.
294-
final int numPoints = pointList.size();
295-
296-
// Set the corresponding element in this array to indicate when
297-
// elements of pointList are no longer available.
298-
final boolean[] taken = new boolean[numPoints];
299-
300-
// The resulting list of initial centers.
301-
final List<CentroidCluster<T>> resultSet = new ArrayList<>();
302-
303-
// Choose one center uniformly at random from among the data points.
304-
final int firstPointIndex = random.nextInt(numPoints);
305-
306-
final T firstPoint = pointList.get(firstPointIndex);
307-
308-
resultSet.add(new CentroidCluster<T>(firstPoint));
309-
310-
// Must mark it as taken
311-
taken[firstPointIndex] = true;
312-
313-
// To keep track of the minimum distance squared of elements of
314-
// pointList to elements of resultSet.
315-
final double[] minDistSquared = new double[numPoints];
316-
317-
// Initialize the elements. Since the only point in resultSet is firstPoint,
318-
// this is very easy.
319-
for (int i = 0; i < numPoints; i++) {
320-
if (i != firstPointIndex) { // That point isn't considered
321-
double d = distance(firstPoint, pointList.get(i));
322-
minDistSquared[i] = d*d;
323-
}
324-
}
325-
326-
while (resultSet.size() < k) {
327-
328-
// Sum up the squared distances for the points in pointList not
329-
// already taken.
330-
double distSqSum = 0.0;
331-
332-
for (int i = 0; i < numPoints; i++) {
333-
if (!taken[i]) {
334-
distSqSum += minDistSquared[i];
335-
}
336-
}
337-
338-
// Add one new data point as a center. Each point x is chosen with
339-
// probability proportional to D(x)^2
340-
final double r = random.nextDouble() * distSqSum;
341-
342-
// The index of the next point to be added to the resultSet.
343-
int nextPointIndex = -1;
344-
345-
// Sum through the squared min distances again, stopping when
346-
// sum >= r.
347-
double sum = 0.0;
348-
for (int i = 0; i < numPoints; i++) {
349-
if (!taken[i]) {
350-
sum += minDistSquared[i];
351-
if (sum >= r) {
352-
nextPointIndex = i;
353-
break;
354-
}
355-
}
356-
}
357-
358-
// If it's not set to >= 0, the point wasn't found in the previous
359-
// for loop, probably because distances are extremely small. Just pick
360-
// the last available point.
361-
if (nextPointIndex == -1) {
362-
for (int i = numPoints - 1; i >= 0; i--) {
363-
if (!taken[i]) {
364-
nextPointIndex = i;
365-
break;
366-
}
367-
}
368-
}
369-
370-
// We found one.
371-
if (nextPointIndex >= 0) {
372-
373-
final T p = pointList.get(nextPointIndex);
374-
375-
resultSet.add(new CentroidCluster<T> (p));
376-
377-
// Mark it as taken.
378-
taken[nextPointIndex] = true;
379-
380-
if (resultSet.size() < k) {
381-
// Now update elements of minDistSquared. We only have to compute
382-
// the distance to the new center to do this.
383-
for (int j = 0; j < numPoints; j++) {
384-
// Only have to worry about the points still not taken.
385-
if (!taken[j]) {
386-
double d = distance(p, pointList.get(j));
387-
double d2 = d * d;
388-
if (d2 < minDistSquared[j]) {
389-
minDistSquared[j] = d2;
390-
}
391-
}
392-
}
393-
}
394-
395-
} else {
396-
// None found --
397-
// Break from the while loop to prevent
398-
// an infinite loop.
399-
break;
400-
}
401-
}
402-
403-
return resultSet;
404-
}
405-
406288
/**
407289
* Get a random point from the {@link Cluster} with the largest distance variance.
408290
*
@@ -540,26 +422,4 @@ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, fin
540422
}
541423
return minCluster;
542424
}
543-
544-
/**
545-
* Computes the centroid for a set of points.
546-
*
547-
* @param points the set of points
548-
* @param dimension the point dimension
549-
* @return the computed centroid for the set of points
550-
*/
551-
private Clusterable centroidOf(final Collection<T> points, final int dimension) {
552-
final double[] centroid = new double[dimension];
553-
for (final T p : points) {
554-
final double[] point = p.getPoint();
555-
for (int i = 0; i < centroid.length; i++) {
556-
centroid[i] += point[i];
557-
}
558-
}
559-
for (int i = 0; i < centroid.length; i++) {
560-
centroid[i] /= points.size();
561-
}
562-
return new DoublePoint(centroid);
563-
}
564-
565425
}

0 commit comments

Comments
 (0)