Skip to content
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
39ce964
first commit
tarang-jain Apr 4, 2025
faad365
s
tarang-jain Apr 7, 2025
372de78
compiles
Apr 12, 2025
4bc668c
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
Apr 12, 2025
49c43fc
mdspan sig
Apr 17, 2025
e41ae23
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain Apr 17, 2025
f99d508
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 17, 2025
0ecb0c8
indptr
tarang-jain Apr 17, 2025
c5e9a32
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
Apr 18, 2025
7ab512a
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
Apr 18, 2025
041ff95
rm duplicate definition
Apr 18, 2025
2738db7
docs
Apr 18, 2025
a58e72d
rm template
Apr 19, 2025
291ef55
Update cpp/src/cluster/detail/single_linkage.cuh
tarang-jain Apr 19, 2025
a35ad32
Update cpp/src/cluster/single_linkage.cuh
tarang-jain Apr 19, 2025
7effdbc
detail namespace
Apr 22, 2025
e0c2171
Merge branch 'branch-25.06' into build-linkage
tarang-jain Apr 22, 2025
93f5866
restructure
tarang-jain Apr 29, 2025
6c15e03
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain Apr 29, 2025
074153a
correct types
tarang-jain Apr 30, 2025
a7c754a
docs
tarang-jain Apr 30, 2025
658c12a
docs
tarang-jain May 2, 2025
ee1d5b2
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain May 2, 2025
6d63e6a
reachability api
tarang-jain May 6, 2025
621448e
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain May 6, 2025
cffc1ec
rev
tarang-jain May 6, 2025
471bdba
rm tparam docs
tarang-jain May 6, 2025
e0de74f
Merge branch 'branch-25.06' into build-linkage
cjnolet May 14, 2025
cbed900
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 1, 2025
e8fa16c
build-linkage-api
tarang-jain Jun 1, 2025
b43d3e6
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 1, 2025
71c3621
new API
tarang-jain Jun 1, 2025
0c06116
docs
tarang-jain Jun 1, 2025
b34309e
param struct;fix compilation errors
tarang-jain Jun 2, 2025
f293d90
unused headers
tarang-jain Jun 2, 2025
9a4d68e
docs,copyright
tarang-jain Jun 3, 2025
62ba6fb
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 3, 2025
f28cd9d
copyright
tarang-jain Jun 3, 2025
b122b61
fix failing LinkageTest
tarang-jain Jun 4, 2025
b624433
docs
tarang-jain Jun 4, 2025
1e9df56
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 5, 2025
0614806
correct size
tarang-jain Jun 5, 2025
8ccac54
fix failing hdbscan test
tarang-jain Jun 5, 2025
3067be2
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 5, 2025
90969c3
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 5, 2025
1dd9039
signature;log statements
tarang-jain Jun 5, 2025
5426b35
rm debug statements
tarang-jain Jun 5, 2025
3b0737d
fix ci warning
tarang-jain Jun 5, 2025
4dc2b6d
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 6, 2025
0be915c
rm log statements
tarang-jain Jun 8, 2025
6afb096
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 8, 2025
b23f6ae
skip mst alloc
tarang-jain Jun 9, 2025
63cebc8
reference issue
tarang-jain Jun 9, 2025
55da78d
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 10, 2025
7090f31
update comment
tarang-jain Jun 10, 2025
6595a38
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 10, 2025
b86b6bd
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 10, 2025
dcd495f
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 16, 2025
6434fee
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 30, 2025
63e65a4
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion cpp/include/cuvs/cluster/agglomerative.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,9 @@

#include <cuvs/distance/distance.hpp>
#include <optional>
#include <variant>

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

Expand Down Expand Up @@ -119,6 +121,58 @@ void single_linkage(
cuvs::cluster::agglomerative::Linkage linkage = cuvs::cluster::agglomerative::Linkage::KNN_GRAPH,
std::optional<int> c = std::make_optional<int>(DEFAULT_CONST_C));

namespace helpers {

namespace linkage_graph_params {
/** Specialized parameters to build the KNN graph with regular distances */
struct distance_params {
/** a constant used when constructing linkage from knn graph. Allows the indirect control of k.
* The algorithm will set `k = log(n) + c` */
int c = DEFAULT_CONST_C;

/** strategy for constructing the linkage. PAIRWISE uses more memory but can be faster for smaller
* datasets. KNN_GRAPH allows the memory usage to be controlled (using parameter c) */
cuvs::cluster::agglomerative::Linkage dist_type =
cuvs::cluster::agglomerative::Linkage::KNN_GRAPH;
};

/** Specialized parameters to build the Mutual Reachability graph */
struct mutual_reachability_params {
/** this neighborhood will be selected for core distances. */
int min_samples;

/** weight applied when internal distance is chosen for mutual reachability (value of 1.0 disables
* the weighting) */
float alpha = 1.0;
};
} // namespace linkage_graph_params
/**
* Given a dataset, builds the KNN graph, connects graph components and builds a linkage
* (dendrogram). Returns the Minimum Spanning Tree edges sorted by weight and the dendrogram.
* @param[in] handle raft handle for resource reuse
* @param[in] X data points (size n_rows * d)
* @param[in] linkage_graph_params linkage params or mutual reachability params for building the KNN
* graph
* @param[in] metric distance metric to use
* @param[out] out_mst output MST sorted by edge weights (size n_rows - 1)
* @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2)
* @param[out] out_distances distances for output
* @param[out] out_sizes cluster sizes of output
* @param[out] core_dists (optional) core distances (size m). Must be supplied in the Mutual
* Reachability space
*/
void build_linkage(
raft::resources const& handle,
raft::device_matrix_view<const float, int, raft::row_major> X,
std::variant<linkage_graph_params::distance_params,
linkage_graph_params::mutual_reachability_params> linkage_graph_params,
cuvs::distance::DistanceType metric,
raft::device_coo_matrix_view<float, int, int, size_t> out_mst,
raft::device_matrix_view<int, int> dendrogram,
raft::device_vector_view<float, int> out_distances,
raft::device_vector_view<int, int> out_sizes,
std::optional<raft::device_vector_view<float, int>> core_dists);
} // namespace helpers
/**
* @}
*/
Expand Down
247 changes: 197 additions & 50 deletions cpp/src/cluster/detail/single_linkage.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,51 +16,136 @@

#pragma once

#include "../../neighbors/detail/reachability.cuh"
#include "agglomerative.cuh"
#include "connectivities.cuh"
#include "mst.cuh"
#include <cuvs/cluster/agglomerative.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/sparse/coo.hpp>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>

namespace cuvs::cluster::agglomerative::detail {

/**
* Constructs a linkage by computing the minimum spanning tree and dendrogram in the Mutual
* Reachability space. Returns mst edges sorted by weight and the dendrogram.
* @tparam value_t
* @tparam value_idx
* @tparam nnz_t
* @param[in] handle raft handle for resource reuse
* @param[in] X data points (size m * n)
* @param[in] metric distance metric to use
* @param[in] min_samples this neighborhood will be selected for core distances
* @param[in] alpha weight applied when internal distance is chosen for mutual reachability (value
* of 1.0 disables the weighting)
* @param[out] core_dists core distances (size m)
* @param[out] out_mst output MST sorted by edge weights (size m - 1)
* @param[out] out_dendrogram output dendrogram
* @param[out] out_distances distances of output
* @param[out] out_sizes cluster sizes of output
*/
template <typename value_t = float, typename value_idx = int, typename nnz_t = size_t>
void build_mr_linkage(raft::resources const& handle,
raft::device_matrix_view<const value_t, value_idx, raft::row_major> X,
value_idx min_samples,
float alpha,
cuvs::distance::DistanceType metric,
raft::device_vector_view<value_t, value_idx> core_dists,
raft::device_coo_matrix_view<value_t, value_idx, value_idx, nnz_t> out_mst,
raft::device_matrix_view<value_idx, value_idx> out_dendrogram,
raft::device_vector_view<value_t, value_idx> out_distances,
raft::device_vector_view<value_idx, value_idx> out_sizes)
{
size_t m = X.extent(0);
size_t n = X.extent(1);
auto mutual_reachability_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1);
raft::sparse::COO<value_t, value_idx, nnz_t> mutual_reachability_coo(
raft::resource::get_cuda_stream(handle), min_samples * m * 2);

cuvs::neighbors::detail::reachability::mutual_reachability_graph<value_idx, value_t, nnz_t>(
handle,
X.data_handle(),
m,
n,
metric,
min_samples,
alpha,
mutual_reachability_indptr.data_handle(),
core_dists.data_handle(),
mutual_reachability_coo);

// auto color = raft::make_device_vector<value_idx, value_idx>(handle, static_cast<value_idx>(m));
rmm::device_uvector<value_idx> color(m, raft::resource::get_cuda_stream(handle));
cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>
reduction_op(core_dists.data_handle(), m);

size_t nnz = m * min_samples;

detail::build_sorted_mst<value_idx, value_t>(handle,
X.data_handle(),
mutual_reachability_indptr.data_handle(),
mutual_reachability_coo.cols(),
mutual_reachability_coo.vals(),
m,
n,
out_mst.structure_view().get_rows().data(),
out_mst.structure_view().get_cols().data(),
out_mst.get_elements().data(),
color.data(),
mutual_reachability_coo.nnz,
reduction_op,
metric,
10);

/**
* Perform hierarchical labeling
*/
size_t n_edges = m - 1;

detail::build_dendrogram_host<value_idx, value_t>(handle,
out_mst.structure_view().get_rows().data(),
out_mst.structure_view().get_cols().data(),
out_mst.get_elements().data(),
n_edges,
out_dendrogram.data_handle(),
out_distances.data_handle(),
out_sizes.data_handle());
}

static const size_t EMPTY = 0;

/**
* Single-linkage clustering, capable of constructing a KNN graph to
* scale the algorithm beyond the n^2 memory consumption of implementations
* that use the fully-connected graph of pairwise distances by connecting
* a knn graph when k is not large enough to connect it.

* @tparam value_idx
* Constructs a linkage by computing the minimum spanning tree and dendrogram in the Mutual
* Reachability space. Returns mst edges sorted by weight and the dendrogram.
* @tparam value_t
* @tparam value_idx
* @tparam nnz_t
* @tparam dist_type method to use for constructing connectivities graph
* @param[in] handle raft handle
* @param[in] X dense input matrix in row-major layout
* @param[in] m number of rows in X
* @param[in] n number of columns in X
* @param[in] metric distance metrix to use when constructing connectivities graph
* @param[out] out struct containing output dendrogram and cluster assignments
* @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
control
* of k. The algorithm will set `k = log(n) + c`
* @param[in] n_clusters number of clusters to assign data samples
* @param[in] handle raft handle for resource reuse
* @param[in] X data points (size m * n)
* @param[in] c a constant used when constructing linkage from knn graph. Allows the indirect
* control of k. The algorithm will set `k = log(n) + c`
* @param[in] metric distance metric to use
* @param[out] out_mst output MST sorted by edge weights (size m - 1)
* @param[out] out_dendrogram output dendrogram
* @param[out] out_distances distances of output
* @param[out] out_sizes cluster sizes of output
*/
template <typename value_idx, typename value_t, Linkage dist_type>
void single_linkage(raft::resources const& handle,
const value_t* X,
size_t m,
size_t n,
cuvs::distance::DistanceType metric,
single_linkage_output<value_idx>* out,
int c,
size_t n_clusters)
template <typename value_t, typename value_idx, typename nnz_t, Linkage dist_type>
void build_dist_linkage(raft::resources const& handle,
raft::device_matrix_view<const value_t, value_idx, raft::row_major> X,
int c,
cuvs::distance::DistanceType metric,
raft::device_coo_matrix_view<value_t, value_idx, value_idx, nnz_t> out_mst,
raft::device_matrix_view<value_idx, value_idx> out_dendrogram,
raft::device_vector_view<value_t, value_idx> out_distances,
raft::device_vector_view<value_idx, value_idx> out_sizes)
{
ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points");

size_t m = X.extent(0);
size_t n = X.extent(1);
auto stream = raft::resource::get_cuda_stream(handle);

rmm::device_uvector<value_idx> indptr(EMPTY, stream);
Expand All @@ -70,51 +155,113 @@ void single_linkage(raft::resources const& handle,
/**
* 1. Construct distance graph
*/
detail::get_distance_graph<value_idx, value_t, dist_type>(
handle, X, m, n, metric, indptr, indices, pw_dists, c);

rmm::device_uvector<value_idx> mst_rows(m - 1, stream);
rmm::device_uvector<value_idx> mst_cols(m - 1, stream);
rmm::device_uvector<value_t> mst_data(m - 1, stream);
detail::get_distance_graph<value_idx, value_t, dist_type>(handle,
X.data_handle(),
static_cast<value_idx>(m),
static_cast<value_idx>(n),
metric,
indptr,
indices,
pw_dists,
c);

/**
* 2. Construct MST, sorted by weights
*/
rmm::device_uvector<value_idx> color(m, stream);
cuvs::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);

size_t n_edges = m - 1;

rmm::device_uvector<value_idx> mst_rows(n_edges, stream);
rmm::device_uvector<value_idx> mst_cols(n_edges, stream);
rmm::device_uvector<value_t> mst_data(n_edges, stream);

detail::build_sorted_mst<value_idx, value_t>(handle,
X,
X.data_handle(),
indptr.data(),
indices.data(),
pw_dists.data(),
m,
n,
mst_rows.data(),
mst_cols.data(),
mst_data.data(),
out_mst.structure_view().get_rows().data(),
out_mst.structure_view().get_cols().data(),
out_mst.get_elements().data(),
color.data(),
indices.size(),
op,
metric);

pw_dists.release();

/**
* Perform hierarchical labeling
*/
size_t n_edges = mst_rows.size();

rmm::device_uvector<value_t> out_delta(n_edges, stream);
rmm::device_uvector<value_idx> out_size(n_edges, stream);
// Create dendrogram
detail::build_dendrogram_host<value_idx, value_t>(handle,
mst_rows.data(),
mst_cols.data(),
mst_data.data(),
out_mst.structure_view().get_rows().data(),
out_mst.structure_view().get_cols().data(),
out_mst.get_elements().data(),
n_edges,
out->children,
out_delta.data(),
out_size.data());
out_dendrogram.data_handle(),
out_distances.data_handle(),
out_sizes.data_handle());
}

/**
* Single-linkage clustering, capable of constructing a KNN graph to
* scale the algorithm beyond the n^2 memory consumption of implementations
* that use the fully-connected graph of pairwise distances by connecting
* a knn graph when k is not large enough to connect it.

* @tparam value_idx
* @tparam value_t
* @tparam dist_type method to use for constructing connectivities graph
* @param[in] handle raft handle
* @param[in] X dense input matrix in row-major layout
* @param[in] m number of rows in X
* @param[in] n number of columns in X
* @param[in] metric distance metrix to use when constructing connectivities graph
* @param[out] out struct containing output dendrogram and cluster assignments
* @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
control
* of k. The algorithm will set `k = log(n) + c`
* @param[in] n_clusters number of clusters to assign data samples
*/
template <typename value_idx, typename value_t, Linkage dist_type>
void single_linkage(raft::resources const& handle,
const value_t* X,
size_t m,
size_t n,
cuvs::distance::DistanceType metric,
single_linkage_output<value_idx>* out,
int c,
size_t n_clusters)
{
ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points");

value_idx n_edges = m - 1;
auto mst_rows = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
auto mst_cols = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);
auto mst_weights = raft::make_device_vector<value_t, value_idx>(handle, n_edges);
auto structure_view =
raft::make_device_coordinate_structure_view<value_idx, value_idx, value_idx>(
mst_rows.data_handle(), mst_cols.data_handle(), m, m, n_edges);
auto mst_view = raft::make_device_coo_matrix_view<value_t, value_idx, value_idx, value_idx>(
mst_weights.data_handle(), structure_view);

auto out_delta = raft::make_device_vector<value_t, value_idx>(handle, n_edges);
auto out_sizes = raft::make_device_vector<value_idx, value_idx>(handle, n_edges);

build_dist_linkage<value_t, value_idx, value_idx, dist_type>(
handle,
raft::make_device_matrix_view<const value_t, value_idx, raft::row_major>(
X, static_cast<value_idx>(m), static_cast<value_idx>(n)),
c,
metric,
mst_view,
raft::make_device_matrix_view<value_idx, value_idx, raft::row_major>(out->children, n_edges, 2),
out_delta.view(),
out_sizes.view());

detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m);

out->m = m;
Expand Down
Loading