Skip to content

Commit 7288cc3

Browse files
author
CT
committed
Improvement the tests of MiniBatchKMeansClusterer, compare with KMeansPlusPlusClusterer with score.
1 parent adc62f6 commit 7288cc3

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/test/java/org/apache/commons/math4/ml/clustering/MiniBatchKMeansClustererTest.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import org.apache.commons.math4.TestUtils;
2121
import org.apache.commons.math4.exception.NumberIsTooSmallException;
22+
import org.apache.commons.math4.ml.clustering.evaluation.ClusterEvaluator;
23+
import org.apache.commons.math4.ml.clustering.evaluation.SumOfClusterVariances;
2224
import org.apache.commons.math4.ml.distance.DistanceMeasure;
2325
import org.apache.commons.math4.ml.distance.EuclideanDistance;
2426
import org.apache.commons.rng.simple.RandomSource;
@@ -28,7 +30,6 @@
2830
import java.util.ArrayList;
2931
import java.util.List;
3032
import java.util.Random;
31-
import java.util.function.Function;
3233

3334
public class MiniBatchKMeansClustererTest {
3435
private final DistanceMeasure measure = new EuclideanDistance();
@@ -64,17 +65,19 @@ public void testCompareToKMeans() {
6465
Assert.assertEquals(4, kMeansClusters.size());
6566
Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
6667
int totalDiffCount = 0;
67-
double totalCenterDistance = 0.0;
6868
for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
6969
CentroidCluster<DoublePoint> miniBatchCluster = ClusterUtils.predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
7070
totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
71-
totalCenterDistance += measure.compute(kMeanCluster.getCenter().getPoint(), miniBatchCluster.getCenter().getPoint());
7271
}
73-
double diffRatio = totalDiffCount * 1.0 / data.size();
74-
System.out.println(String.format("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%",
75-
totalCenterDistance, totalDiffCount, diffRatio * 100));
76-
// Sometimes the
77-
// Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.03);
72+
ClusterEvaluator<DoublePoint> clusterEvaluator = new SumOfClusterVariances<>(measure);
73+
double kMeansScore = clusterEvaluator.score(kMeansClusters);
74+
double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
75+
double diffPointsRatio = totalDiffCount * 1.0 / data.size();
76+
double scoreDiffRatio = (miniBatchKMeansScore - kMeansScore) /
77+
kMeansScore;
78+
// MiniBatchKMeansClusterer has few score differences between KMeansClusterer
79+
Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%\"", scoreDiffRatio * 100, diffPointsRatio * 100),
80+
scoreDiffRatio < 0.1);
7881
}
7982
}
8083

@@ -91,7 +94,7 @@ private List<DoublePoint> generateCircles(int randomSeed) {
9194
List<DoublePoint> generateCircle(int count, double[] center, double radius, Random random) {
9295
double x0 = center[0];
9396
double y0 = center[1];
94-
ArrayList<DoublePoint> list = new ArrayList<DoublePoint>(count);
97+
ArrayList<DoublePoint> list = new ArrayList<>(count);
9598
for (int i = 0; i < count; i++) {
9699
double ao = random.nextDouble() * 720 - 360;
97100
double r = random.nextDouble() * radius * 2 - radius;

0 commit comments

Comments
 (0)