1919
2020import org .apache .commons .math4 .TestUtils ;
2121import org .apache .commons .math4 .exception .NumberIsTooSmallException ;
22+ import org .apache .commons .math4 .ml .clustering .evaluation .ClusterEvaluator ;
23+ import org .apache .commons .math4 .ml .clustering .evaluation .SumOfClusterVariances ;
2224import org .apache .commons .math4 .ml .distance .DistanceMeasure ;
2325import org .apache .commons .math4 .ml .distance .EuclideanDistance ;
2426import org .apache .commons .rng .simple .RandomSource ;
2830import java .util .ArrayList ;
2931import java .util .List ;
3032import java .util .Random ;
31- import java .util .function .Function ;
3233
3334public class MiniBatchKMeansClustererTest {
3435 private final DistanceMeasure measure = new EuclideanDistance ();
@@ -64,17 +65,19 @@ public void testCompareToKMeans() {
6465 Assert .assertEquals (4 , kMeansClusters .size ());
6566 Assert .assertEquals (kMeansClusters .size (), miniBatchKMeansClusters .size ());
6667 int totalDiffCount = 0 ;
67- double totalCenterDistance = 0.0 ;
6868 for (CentroidCluster <DoublePoint > kMeanCluster : kMeansClusters ) {
6969 CentroidCluster <DoublePoint > miniBatchCluster = ClusterUtils .predict (miniBatchKMeansClusters , kMeanCluster .getCenter ());
7070 totalDiffCount += Math .abs (kMeanCluster .getPoints ().size () - miniBatchCluster .getPoints ().size ());
71- totalCenterDistance += measure .compute (kMeanCluster .getCenter ().getPoint (), miniBatchCluster .getCenter ().getPoint ());
7271 }
73- double diffRatio = totalDiffCount * 1.0 / data .size ();
74- System .out .println (String .format ("Centers total distance: %f, clusters total diff points: %d, diff ratio: %f%%" ,
75- totalCenterDistance , totalDiffCount , diffRatio * 100 ));
76- // Sometimes the
77- // Assert.assertTrue(String.format("Different points ratio %f%%!", diffRatio * 100), diffRatio < 0.03);
72+ ClusterEvaluator <DoublePoint > clusterEvaluator = new SumOfClusterVariances <>(measure );
73+ double kMeansScore = clusterEvaluator .score (kMeansClusters );
74+ double miniBatchKMeansScore = clusterEvaluator .score (miniBatchKMeansClusters );
75+ double diffPointsRatio = totalDiffCount * 1.0 / data .size ();
76+ double scoreDiffRatio = (miniBatchKMeansScore - kMeansScore ) /
77+ kMeansScore ;
78+ // MiniBatchKMeansClusterer has few score differences between KMeansClusterer
79+ Assert .assertTrue (String .format ("Different score ratio %f%%!, diff points ratio: %f%%\" " , scoreDiffRatio * 100 , diffPointsRatio * 100 ),
80+ scoreDiffRatio < 0.1 );
7881 }
7982 }
8083
@@ -91,7 +94,7 @@ private List<DoublePoint> generateCircles(int randomSeed) {
9194 List <DoublePoint > generateCircle (int count , double [] center , double radius , Random random ) {
9295 double x0 = center [0 ];
9396 double y0 = center [1 ];
94- ArrayList <DoublePoint > list = new ArrayList <DoublePoint >(count );
97+ ArrayList <DoublePoint > list = new ArrayList <>(count );
9598 for (int i = 0 ; i < count ; i ++) {
9699 double ao = random .nextDouble () * 720 - 360 ;
97100 double r = random .nextDouble () * radius * 2 - radius ;
0 commit comments