Skip to content

Commit 2fe2e88

Browse files
authored
Add multigpu kmeans fit function (#348)
Changes to support using kmeans clustering inside of cuml, so we can transition cuml off of the RAFT kmeans code * Add a multigpu kmeans fit function * Adds instantiations for kmeans on int64_t indices, which unfortunately also requires int64_t indices for the PW distance functions * Add support for `double` precision kmeans Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #348
1 parent ce01a0b commit 2fe2e88

25 files changed

+2059
-43
lines changed

cpp/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,22 @@ target_compile_options(
290290
add_library(
291291
cuvs SHARED
292292
src/cluster/kmeans_balanced_fit_float.cu
293+
src/cluster/kmeans_fit_mg_float.cu
294+
src/cluster/kmeans_fit_mg_double.cu
295+
src/cluster/kmeans_fit_double.cu
293296
src/cluster/kmeans_fit_float.cu
294297
src/cluster/kmeans_auto_find_k_float.cu
298+
src/cluster/kmeans_fit_predict_double.cu
295299
src/cluster/kmeans_fit_predict_float.cu
300+
src/cluster/kmeans_predict_double.cu
296301
src/cluster/kmeans_predict_float.cu
297302
src/cluster/kmeans_balanced_fit_float.cu
298303
src/cluster/kmeans_balanced_fit_predict_float.cu
299304
src/cluster/kmeans_balanced_predict_float.cu
300305
src/cluster/kmeans_balanced_fit_int8.cu
301306
src/cluster/kmeans_balanced_fit_predict_int8.cu
302307
src/cluster/kmeans_balanced_predict_int8.cu
308+
src/cluster/kmeans_transform_double.cu
303309
src/cluster/kmeans_transform_float.cu
304310
src/cluster/single_linkage_float.cu
305311
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
@@ -342,6 +348,8 @@ add_library(
342348
src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu
343349
src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu
344350
src/distance/detail/pairwise_matrix/dispatch_rbf.cu
351+
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu
352+
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu
345353
src/distance/detail/fused_distance_nn.cu
346354
src/distance/distance.cu
347355
src/distance/pairwise_distance.cu

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 485 additions & 21 deletions
Large diffs are not rendered by default.

cpp/src/cluster/detail/connectivities.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#pragma once
1818

19-
#include "../../distance/distance.cuh"
19+
#include "./kmeans_common.cuh"
2020
#include <cuvs/cluster/agglomerative.hpp>
2121
#include <cuvs/distance/distance.hpp>
2222
#include <raft/core/resource/cuda_stream.hpp>
@@ -153,7 +153,11 @@ void pairwise_distances(const raft::resources& handle,
153153
// TODO: It would ultimately be nice if the MST could accept
154154
// dense inputs directly so we don't need to double the memory
155155
// usage to hand it a sparse array here.
156-
distance::pairwise_distance<value_t, value_idx>(handle, X, X, data, m, m, n, metric);
156+
auto X_view = raft::make_device_matrix_view<const value_t, value_idx>(X, m, n);
157+
158+
cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<value_t, value_idx>(
159+
handle, X_view, X_view, raft::make_device_matrix_view<value_t, value_idx>(data, m, m), metric);
160+
157161
// self-loops get max distance
158162
auto transform_in =
159163
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data));

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ void kmeansPlusPlus(raft::resources const& handle,
198198
// Output - pwd [n_trials x n_samples]
199199
auto pwd = distBuffer.view();
200200
cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<DataT, IndexT>(
201-
handle, centroidCandidates.view(), X, pwd, workspace, metric);
201+
handle, centroidCandidates.view(), X, pwd, metric);
202202

203203
// Update nearest cluster distance for each centroid candidate
204204
// Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values.
@@ -1247,7 +1247,7 @@ void kmeans_transform(raft::resources const& handle,
12471247
// calculate pairwise distance between cluster centroids and current batch
12481248
// of input dataset
12491249
pairwise_distance_kmeans<DataT, IndexT>(
1250-
handle, datasetView, centroids, pairwiseDistanceView, workspace, metric);
1250+
handle, datasetView, centroids, pairwiseDistanceView, metric);
12511251
}
12521252
}
12531253

cpp/src/cluster/detail/kmeans_common.cuh

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ void pairwise_distance_kmeans(raft::resources const& handle,
293293
raft::device_matrix_view<const DataT, IndexT> X,
294294
raft::device_matrix_view<const DataT, IndexT> centroids,
295295
raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
296-
rmm::device_uvector<char>& workspace,
297296
cuvs::distance::DistanceType metric)
298297
{
299298
auto n_samples = X.extent(0);
@@ -303,15 +302,23 @@ void pairwise_distance_kmeans(raft::resources const& handle,
303302
ASSERT(X.extent(1) == centroids.extent(1),
304303
"# features in dataset and centroids are different (must be same)");
305304

306-
cuvs::distance::pairwise_distance(handle,
307-
X.data_handle(),
308-
centroids.data_handle(),
309-
pairwiseDistance.data_handle(),
310-
n_samples,
311-
n_clusters,
312-
n_features,
313-
workspace,
314-
metric);
305+
if (metric == cuvs::distance::DistanceType::L2Expanded) {
306+
cuvs::distance::distance<cuvs::distance::DistanceType::L2Expanded,
307+
DataT,
308+
DataT,
309+
DataT,
310+
raft::layout_c_contiguous,
311+
IndexT>(handle, X, centroids, pairwiseDistance);
312+
} else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
313+
cuvs::distance::distance<cuvs::distance::DistanceType::L2SqrtExpanded,
314+
DataT,
315+
DataT,
316+
DataT,
317+
raft::layout_c_contiguous,
318+
IndexT>(handle, X, centroids, pairwiseDistance);
319+
} else {
320+
RAFT_FAIL("kmeans requires L2Expanded or L2SqrtExpanded distance, have %i", metric);
321+
}
315322
}
316323

317324
// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores
@@ -461,7 +468,7 @@ void minClusterAndDistanceCompute(
461468
// calculate pairwise distance between current tile of cluster centroids
462469
// and input dataset
463470
pairwise_distance_kmeans<DataT, IndexT>(
464-
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
471+
handle, datasetView, centroidsView, pairwiseDistanceView, metric);
465472

466473
// argmin reduction returning <index, value> pair
467474
// calculates the closest centroid and the distance to the closest
@@ -591,7 +598,7 @@ void minClusterDistanceCompute(raft::resources const& handle,
591598
// calculate pairwise distance between current tile of cluster centroids
592599
// and input dataset
593600
pairwise_distance_kmeans<DataT, IndexT>(
594-
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
601+
handle, datasetView, centroidsView, pairwiseDistanceView, metric);
595602

596603
raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(),
597604
pairwiseDistanceView.data_handle(),

0 commit comments

Comments
 (0)