Skip to content
141 changes: 141 additions & 0 deletions src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.commons.math4.ml.clustering;

import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.math4.stat.descriptive.moment.Variance;
import org.apache.commons.rng.UniformRandomProvider;

import java.util.Collection;
import java.util.List;

/**
 * Static helper functions shared by the clustering algorithms.
 */
public final class ClusterUtils {
    /** Distance measure used by the overloads that do not take an explicit measure. */
    public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();

    /**
     * Not instantiable: this class only exposes static members.
     */
    private ClusterUtils() {
    }

    /**
     * Predict which cluster is best for the point.
     * <p>
     * The clusters are scanned in list order and the first one whose center has
     * the strictly smallest distance to the point is kept.
     *
     * @param clusters clusters to predict into
     * @param point point to predict
     * @param measure distance measure used to compare the point with the cluster centers
     * @param <T> type of cluster point
     * @return the cluster whose center is nearest to the point, or {@code null}
     * if {@code clusters} is empty or no computed distance is smaller than
     * {@link Double#POSITIVE_INFINITY}
     */
    public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters,
                                                                     Clusterable point,
                                                                     DistanceMeasure measure) {
        double minDistance = Double.POSITIVE_INFINITY;
        CentroidCluster<T> nearestCluster = null;
        for (final CentroidCluster<T> cluster : clusters) {
            final double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
            // Strict comparison: on ties the earlier cluster wins, NaN distances never win.
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = cluster;
            }
        }
        return nearestCluster;
    }

    /**
     * Predict which cluster is best for the point, using {@link #DEFAULT_MEASURE}.
     *
     * @param clusters clusters to predict into
     * @param point point to predict
     * @param <T> type of cluster point
     * @return the cluster whose center is nearest to the point, or {@code null}
     * if {@code clusters} is empty
     */
    public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters,
                                                                     Clusterable point) {
        return predict(clusters, point, DEFAULT_MEASURE);
    }

    /**
     * Computes the centroid for a set of points.
     * <p>
     * Only the first {@code dimension} coordinates of every point are read.
     * For an empty collection every coordinate of the result is {@code NaN}
     * (zero divided by zero).
     *
     * @param points the set of points
     * @param dimension the point dimension
     * @param <T> type of the points
     * @return the computed centroid for the set of points
     */
    public static <T extends Clusterable> Clusterable centroidOf(final Collection<T> points, final int dimension) {
        final double[] centroid = new double[dimension];
        for (final T p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += point[i];
            }
        }
        // The size is loop invariant, so it is read once instead of once per coordinate.
        final int count = points.size();
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= count;
        }
        return new DoublePoint(centroid);
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * For every non-empty cluster the variance of the distances between its
     * points and its center is computed; empty clusters are ignored. The
     * returned point is <em>removed</em> from the point list of the selected
     * cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @param measure distance measure used between the points and their cluster center
     * @param random random generator used to pick the point inside the selected cluster
     * @param <T> type of the points
     * @return a random point from the selected cluster
     * @throws ConvergenceException if no cluster could be selected, in
     * particular if the clusters are all empty
     */
    public static <T extends Clusterable> T getPointFromLargestVarianceCluster(
            final Collection<CentroidCluster<T>> clusters,
            final DistanceMeasure measure,
            final UniformRandomProvider random)
        throws ConvergenceException {
        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final CentroidCluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {
                // compute the distance variance of the current cluster
                final Clusterable center = cluster.getCenter();
                final Variance stat = new Variance();
                for (final T point : cluster.getPoints()) {
                    stat.increment(measure.compute(point.getPoint(), center.getPoint()));
                }
                final double variance = stat.getResult();

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster (the point leaves the cluster)
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.simple.RandomSource;
Expand Down Expand Up @@ -70,6 +72,9 @@ public enum EmptyClusterStrategy {
/** Selected strategy for empty clusters. */
private final EmptyClusterStrategy emptyStrategy;

/** Centroid initialization algorithm. */
private final CentroidInitializer centroidInitializer;

/** Build a clusterer.
* <p>
* The default strategy for handling empty clusters that may appear during
Expand Down Expand Up @@ -148,6 +153,8 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
this.maxIterations = maxIterations;
this.random = random;
this.emptyStrategy = emptyStrategy;
// KMeansPlusPlusClusterer always seeds its centroids with the K-means++ algorithm.
this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure,random);
}

/**
Expand Down Expand Up @@ -205,7 +212,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
}

// create the initial clusters
List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);

// create an array containing the latest assignment of a point to a cluster
// no need to initialize the array, as it will be filled with the first assignment
Expand Down Expand Up @@ -235,7 +242,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
}
emptyCluster = true;
} else {
newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
}
newClusters.add(new CentroidCluster<T>(newCenter));
}
Expand Down Expand Up @@ -278,131 +285,6 @@ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
return assignedDifferently;
}

/**
* Use K-means++ to choose the initial centers.
* <p>
* The first center is drawn at random from {@code points}. Every further
* center is drawn from the points not yet taken, weighted by the squared
* distance between a point and its nearest already chosen center. The
* selection stops once {@code k} centers were chosen or when no untaken
* point is left, so the result may hold fewer than {@code k} centers.
*
* @param points the points to choose the initial centers from
* @return the initial centers, each wrapped in a new {@link CentroidCluster}
*/
private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

// Convert to list for indexed access. Make it unmodifiable, since removal of items
// would break the index bookkeeping of this method.
final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));

// The number of points in the list.
final int numPoints = pointList.size();

// Set the corresponding element in this array to indicate when
// elements of pointList are no longer available.
final boolean[] taken = new boolean[numPoints];

// The resulting list of initial centers.
final List<CentroidCluster<T>> resultSet = new ArrayList<>();

// Choose one center uniformly at random from among the data points.
final int firstPointIndex = random.nextInt(numPoints);

final T firstPoint = pointList.get(firstPointIndex);

resultSet.add(new CentroidCluster<T>(firstPoint));

// Must mark it as taken
taken[firstPointIndex] = true;

// To keep track of the minimum distance squared of elements of
// pointList to elements of resultSet.
final double[] minDistSquared = new double[numPoints];

// Initialize the elements. Since the only point in resultSet is firstPoint,
// this is very easy.
for (int i = 0; i < numPoints; i++) {
if (i != firstPointIndex) { // That point isn't considered
double d = distance(firstPoint, pointList.get(i));
minDistSquared[i] = d*d;
}
}

while (resultSet.size() < k) {

// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;

for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
distSqSum += minDistSquared[i];
}
}

// Add one new data point as a center. Each point x is chosen with
// probability proportional to D(x)^2, where D(x) is the distance
// to the nearest center chosen so far.
final double r = random.nextDouble() * distSqSum;

// The index of the next point to be added to the resultSet.
int nextPointIndex = -1;

// Sum through the squared min distances again, stopping when
// sum >= r.
double sum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
sum += minDistSquared[i];
if (sum >= r) {
nextPointIndex = i;
break;
}
}
}

// If it's not set to >= 0, the point wasn't found in the previous
// for loop, probably because distances are extremely small. Just pick
// the last available point.
if (nextPointIndex == -1) {
for (int i = numPoints - 1; i >= 0; i--) {
if (!taken[i]) {
nextPointIndex = i;
break;
}
}
}

// We found one.
if (nextPointIndex >= 0) {

final T p = pointList.get(nextPointIndex);

resultSet.add(new CentroidCluster<T> (p));

// Mark it as taken.
taken[nextPointIndex] = true;

// The distances are only needed if yet another center will be drawn.
if (resultSet.size() < k) {
// Now update elements of minDistSquared. We only have to compute
// the distance to the new center to do this.
for (int j = 0; j < numPoints; j++) {
// Only have to worry about the points still not taken.
if (!taken[j]) {
double d = distance(p, pointList.get(j));
double d2 = d * d;
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2;
}
}
}
}

} else {
// None found, i.e. every point is already taken --
// Break from the while loop to prevent
// an infinite loop.
break;
}
}

return resultSet;
}

/**
* Get a random point from the {@link Cluster} with the largest distance variance.
*
Expand Down Expand Up @@ -540,26 +422,4 @@ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, fin
}
return minCluster;
}

/**
 * Builds the centroid (coordinate-wise arithmetic mean) of a set of points.
 *
 * @param points the points to average
 * @param dimension the number of coordinates read from every point
 * @return a new {@link DoublePoint} holding the mean of each coordinate
 */
private Clusterable centroidOf(final Collection<T> points, final int dimension) {
    final double[] mean = new double[dimension];

    // First pass: accumulate the coordinate sums, point by point.
    for (final T member : points) {
        final double[] coordinates = member.getPoint();
        int axis = 0;
        while (axis < dimension) {
            mean[axis] += coordinates[axis];
            axis++;
        }
    }

    // Second pass: turn the sums into averages.
    final int count = points.size();
    for (int axis = 0; axis < dimension; axis++) {
        mean[axis] = mean[axis] / count;
    }

    return new DoublePoint(mean);
}

}
Loading