Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
39ce964
first commit
tarang-jain Apr 4, 2025
faad365
s
tarang-jain Apr 7, 2025
372de78
compiles
Apr 12, 2025
4bc668c
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
Apr 12, 2025
49c43fc
mdspan sig
Apr 17, 2025
e41ae23
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain Apr 17, 2025
f99d508
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 17, 2025
0ecb0c8
indptr
tarang-jain Apr 17, 2025
c5e9a32
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
Apr 18, 2025
7ab512a
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
Apr 18, 2025
041ff95
rm duplicate definition
Apr 18, 2025
2738db7
docs
Apr 18, 2025
a58e72d
rm template
Apr 19, 2025
291ef55
Update cpp/src/cluster/detail/single_linkage.cuh
tarang-jain Apr 19, 2025
a35ad32
Update cpp/src/cluster/single_linkage.cuh
tarang-jain Apr 19, 2025
7effdbc
detail namespace
Apr 22, 2025
e0c2171
Merge branch 'branch-25.06' into build-linkage
tarang-jain Apr 22, 2025
93f5866
restructure
tarang-jain Apr 29, 2025
6c15e03
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain Apr 29, 2025
074153a
correct types
tarang-jain Apr 30, 2025
a7c754a
docs
tarang-jain Apr 30, 2025
658c12a
docs
tarang-jain May 2, 2025
ee1d5b2
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain May 2, 2025
6d63e6a
reachability api
tarang-jain May 6, 2025
621448e
Merge branch 'branch-25.06' of https://github.com/rapidsai/cuvs into …
tarang-jain May 6, 2025
cffc1ec
rev
tarang-jain May 6, 2025
471bdba
rm tparam docs
tarang-jain May 6, 2025
e0de74f
Merge branch 'branch-25.06' into build-linkage
cjnolet May 14, 2025
cbed900
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 1, 2025
e8fa16c
build-linkage-api
tarang-jain Jun 1, 2025
b43d3e6
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 1, 2025
71c3621
new API
tarang-jain Jun 1, 2025
0c06116
docs
tarang-jain Jun 1, 2025
b34309e
param struct;fix compilation errors
tarang-jain Jun 2, 2025
f293d90
unused headers
tarang-jain Jun 2, 2025
9a4d68e
docs,copyright
tarang-jain Jun 3, 2025
62ba6fb
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 3, 2025
f28cd9d
copyright
tarang-jain Jun 3, 2025
b122b61
fix failing LinkageTest
tarang-jain Jun 4, 2025
b624433
docs
tarang-jain Jun 4, 2025
1e9df56
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 5, 2025
0614806
correct size
tarang-jain Jun 5, 2025
8ccac54
fix failing hdbscan test
tarang-jain Jun 5, 2025
3067be2
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 5, 2025
90969c3
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 5, 2025
1dd9039
signature;log statements
tarang-jain Jun 5, 2025
5426b35
rm debug statements
tarang-jain Jun 5, 2025
3b0737d
fix ci warning
tarang-jain Jun 5, 2025
4dc2b6d
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 6, 2025
0be915c
rm log statements
tarang-jain Jun 8, 2025
6afb096
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 8, 2025
b23f6ae
skip mst alloc
tarang-jain Jun 9, 2025
63cebc8
reference issue
tarang-jain Jun 9, 2025
55da78d
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 10, 2025
7090f31
update comment
tarang-jain Jun 10, 2025
6595a38
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 10, 2025
b86b6bd
Merge branch 'build-linkage' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jun 10, 2025
dcd495f
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 16, 2025
6434fee
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 30, 2025
63e65a4
Merge branch 'branch-25.08' into build-linkage
tarang-jain Jun 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions cpp/include/cuvs/cluster/agglomerative.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,4 @@ void single_linkage(
cuvs::cluster::agglomerative::Linkage linkage = cuvs::cluster::agglomerative::Linkage::KNN_GRAPH,
std::optional<int> c = std::make_optional<int>(DEFAULT_CONST_C));

/**
* @}
*/
}; // end namespace cuvs::cluster::agglomerative
30 changes: 30 additions & 0 deletions cpp/include/cuvs/neighbors/reachability.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/sparse/coo.hpp>
Expand Down Expand Up @@ -73,6 +74,35 @@ void mutual_reachability_graph(
raft::sparse::COO<float, int>& out,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded,
float alpha = 1.0);

namespace helpers {
/**
* Given a mutual reachability graph, connects graph components and builds the dendrogram.
* Returns the MST edges sorted by weight and the linkage through the output views.
*
* This overload is not a template: values are float and indices are int, as fixed by the
* view types in the signature.
*
* @param[in] handle raft handle for resource reuse
* @param[in] X data points (size m * n)
* @param[in] metric distance metric to use
* @param[in] graph_indptr CSR indices of graph nodes (size m + 1)
* @param[in] graph input graph
* @param core_dists core distances (size m). NOTE(review): the reduction op built from this
* buffer has gather/scatter members that rewrite it in place; confirm whether the contents
* are preserved on return.
* @param[out] out_mst output MST sorted by edge weights (size m - 1)
* @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2)
* @param[out] out_distances distances for output
* @param[out] out_sizes cluster sizes of output
*/
void build_single_linkage_dendrogram(raft::resources const& handle,
raft::device_matrix_view<const float, int, raft::row_major> X,
cuvs::distance::DistanceType metric,
raft::device_vector_view<int, int> graph_indptr,
raft::device_coo_matrix_view<float, int, int, size_t> graph,
raft::device_vector_view<float, int> core_dists,
raft::device_coo_matrix_view<float, int, int, int> out_mst,
raft::device_matrix_view<int, int> dendrogram,
raft::device_vector_view<float, int> out_distances,
raft::device_vector_view<int, int> out_sizes);
} // namespace helpers

/**
* @}
*/
Expand Down
1 change: 1 addition & 0 deletions cpp/src/cluster/detail/single_linkage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,5 @@ void single_linkage(raft::resources const& handle,
out->n_leaves = m;
out->n_connected_components = 1;
}

}; // namespace cuvs::cluster::agglomerative::detail
1 change: 1 addition & 0 deletions cpp/src/cluster/single_linkage_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ void single_linkage(raft::resources const& handle,
handle, X, dendrogram, labels, metric, n_clusters, c);
}
}

} // namespace cuvs::cluster::agglomerative
76 changes: 76 additions & 0 deletions cpp/src/neighbors/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/

#pragma once
#include "../../cluster/detail/agglomerative.cuh"
#include "../../cluster/detail/mst.cuh"
#include "../../sparse/neighbors/cross_component_nn.cuh"
#include "./knn_brute_force.cuh"

#include <raft/linalg/unary_op.cuh>
Expand Down Expand Up @@ -254,4 +257,77 @@ void mutual_reachability_graph(const raft::resources& handle,
});
}

/**
 * Builds a single-linkage hierarchy over a mutual reachability graph in two steps:
 * a weight-sorted minimum spanning tree is produced by build_sorted_mst, and the
 * dendrogram is then derived from those MST edges by build_dendrogram_host.
 *
 * @tparam value_idx index type (defaults to int)
 * @tparam value_t value type (defaults to float)
 * @tparam nnz_t type of the non-zero count carried by graph_coo
 * @param[in] handle raft handle for resource reuse
 * @param[in] X data points (size m * n)
 * @param[in] m number of rows
 * @param[in] n number of columns
 * @param[in] metric distance metric to use
 * @param[in] indptr CSR indices of mutual reachability knn graph (size m + 1)
 * @param[in] graph_coo input graph; its column indices, values and nnz are forwarded
 * @param core_dists core distances (m entries), wrapped in the reduction op that is
 *                   handed to build_sorted_mst
 * @param[out] out_mst_src src vertex of MST edges (size m - 1)
 * @param[out] out_mst_dst dst vertex of MST edges (size m - 1)
 * @param[out] out_mst_weights weights of MST edges (size m - 1)
 * @param[out] out_dendrogram output dendrogram
 * @param[out] out_deltas distances of output
 * @param[out] out_sizes cluster sizes of output
 */
template <typename value_idx = int, typename value_t = float, typename nnz_t>
void build_single_linkage_dendrogram(
  raft::resources const& handle,
  const value_t* X,
  size_t m,
  size_t n,
  cuvs::distance::DistanceType metric,
  value_idx* indptr,
  raft::device_coo_matrix_view<value_t, value_idx, value_idx, nnz_t> graph_coo,
  value_t* core_dists,
  value_idx* out_mst_src,
  value_idx* out_mst_dst,
  value_t* out_mst_weights,
  value_idx* out_dendrogram,
  value_t* out_deltas,
  value_idx* out_sizes)
{
  // The color buffer and the reduction op both take the row count as value_idx.
  const auto n_rows = static_cast<value_idx>(m);

  // Scratch buffer with one entry per row, passed to the MST builder as `color`.
  auto color = raft::make_device_vector<value_idx, value_idx>(handle, n_rows);

  using red_op_t =
    cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>;
  red_op_t reduction_op(core_dists, n_rows);

  // Raw pieces of the COO graph that the MST builder consumes.
  auto coo_structure = graph_coo.structure_view();
  auto* graph_cols   = coo_structure.get_cols().data();
  auto* graph_vals   = graph_coo.get_elements().data();
  auto graph_nnz     = coo_structure.get_nnz();

  // Step 1: MST sorted by weights.
  cuvs::cluster::agglomerative::detail::build_sorted_mst(handle,
                                                         X,
                                                         indptr,
                                                         graph_cols,
                                                         graph_vals,
                                                         m,
                                                         n,
                                                         out_mst_src,
                                                         out_mst_dst,
                                                         out_mst_weights,
                                                         color.data_handle(),
                                                         graph_nnz,
                                                         reduction_op,
                                                         metric);

  // Step 2: hierarchical labeling over the m - 1 MST edges.
  const auto n_mst_edges = m - 1;
  cuvs::cluster::agglomerative::detail::build_dendrogram_host(
    handle, out_mst_src, out_mst_dst, out_mst_weights, n_mst_edges, out_dendrogram, out_deltas, out_sizes);
}

} // namespace cuvs::neighbors::detail::reachability
30 changes: 30 additions & 0 deletions cpp/src/neighbors/reachability.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,34 @@ void mutual_reachability_graph(const raft::resources& handle,
core_dists.data_handle(),
out);
}

namespace helpers {
/**
 * Non-template entry point declared in the public reachability header. It unpacks the
 * mdspan / COO views into raw pointers and sizes and forwards them, in the same order,
 * to cuvs::neighbors::detail::reachability::build_single_linkage_dendrogram.
 */
void build_single_linkage_dendrogram(raft::resources const& handle,
                                     raft::device_matrix_view<const float, int, raft::row_major> X,
                                     cuvs::distance::DistanceType metric,
                                     raft::device_vector_view<int, int> graph_indptr,
                                     raft::device_coo_matrix_view<float, int, int, size_t> graph,
                                     raft::device_vector_view<float, int> core_dists,
                                     raft::device_coo_matrix_view<float, int, int, int> out_mst,
                                     raft::device_matrix_view<int, int> dendrogram,
                                     raft::device_vector_view<float, int> out_distances,
                                     raft::device_vector_view<int, int> out_sizes)
{
  // Shape of X: extent(0) rows by extent(1) columns, widened to size_t for the detail API.
  const auto n_rows = static_cast<size_t>(X.extent(0));
  const auto n_cols = static_cast<size_t>(X.extent(1));

  // Destination arrays of the MST, taken from the COO output view.
  auto mst_structure = out_mst.structure_view();
  auto* mst_src      = mst_structure.get_rows().data();
  auto* mst_dst      = mst_structure.get_cols().data();
  auto* mst_weights  = out_mst.get_elements().data();

  cuvs::neighbors::detail::reachability::build_single_linkage_dendrogram(
    handle,
    X.data_handle(),
    n_rows,
    n_cols,
    metric,
    graph_indptr.data_handle(),
    graph,
    core_dists.data_handle(),
    mst_src,
    mst_dst,
    mst_weights,
    dendrogram.data_handle(),
    out_distances.data_handle(),
    out_sizes.data_handle());
}
} // namespace helpers
} // namespace cuvs::neighbors::reachability
4 changes: 4 additions & 0 deletions cpp/src/sparse/neighbors/cross_component_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ namespace cuvs::sparse::neighbors {
// Re-exports detail::FixConnectivitiesRedOp in the enclosing cuvs::sparse::neighbors namespace.
template <typename value_idx, typename value_t>
using FixConnectivitiesRedOp = detail::FixConnectivitiesRedOp<value_idx, value_t>;

// Re-exports detail::MutualReachabilityFixConnectivitiesRedOp in the enclosing
// cuvs::sparse::neighbors namespace, so callers do not need to name the detail namespace.
template <typename value_idx, typename value_t>
using MutualReachabilityFixConnectivitiesRedOp =
detail::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>;

/**
* Gets the number of unique components from array of
* colors or labels. This does not assume the components are
Expand Down
87 changes: 87 additions & 0 deletions cpp/src/sparse/neighbors/detail/cross_component_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,93 @@ struct FixConnectivitiesRedOp {
void scatter(const raft::resources& handle, value_idx* map) {}
};

/**
 * Reduction functor that compares neighbor candidates after adjusting their values
 * with core distances: a candidate (key, value) seen from row `rit` is scored as
 * max(core_dists[rit], core_dists[key], value), and the candidate with the smaller
 * score wins.
 *
 * @tparam value_idx index type used for rows and candidate keys
 * @tparam value_t value type of the distances
 */
template <typename value_idx, typename value_t>
struct MutualReachabilityFixConnectivitiesRedOp {
  // Core distance array, indexed by row / candidate key. This struct never frees it.
  value_t* core_dists;
  // Number of rows; rows with index >= m are left untouched by both call operators.
  value_idx m;

  // Default construction yields an empty functor. core_dists is set to nullptr so that
  // copying the functor by value never copies an indeterminate pointer; with m == 0 the
  // call operators never dereference it.
  DI MutualReachabilityFixConnectivitiesRedOp() : core_dists(nullptr), m(0) {}

  MutualReachabilityFixConnectivitiesRedOp(value_t* core_dists_, value_idx m_)
    : core_dists(core_dists_), m(m_)
  {
  }

  using KVP = raft::KeyValuePair<value_idx, value_t>;

  /**
   * In-place reduction: merges candidate `other` into `*out` for row `rit`.
   *
   * Does nothing when rit >= m or when other.value is not below the maximum finite
   * value_t. Otherwise both candidates are scored (an `out` whose key is not > -1 keeps
   * its raw value) and `*out` receives the key of the smaller score. out->value is
   * always overwritten with a score, even when `other` does not win.
   */
  DI void operator()(value_idx rit, KVP* out, const KVP& other) const
  {
    if (rit < m && other.value < std::numeric_limits<value_t>::max()) {
      value_t core_dist_rit   = core_dists[rit];
      value_t core_dist_other = max(core_dist_rit, max(core_dists[other.key], other.value));

      value_t core_dist_out;
      if (out->key > -1) {
        core_dist_out = max(core_dist_rit, max(core_dists[out->key], out->value));
      } else {
        // No candidate stored yet: keep the raw value and skip the core_dists lookup.
        core_dist_out = out->value;
      }

      bool smaller = core_dist_other < core_dist_out;
      out->key     = smaller ? other.key : out->key;
      out->value   = smaller ? core_dist_other : core_dist_out;
    }
  }

  /**
   * Value-returning reduction: picks between candidates `a` and `b` for row `rit`.
   *
   * Returns `b` unchanged when rit >= m or a.key is not > -1. Otherwise returns the
   * candidate with the smaller score, carrying that score as its value (`b` wins ties;
   * a `b` whose key is not > -1 is scored by its raw value).
   */
  DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const
  {
    if (rit < m && a.key > -1) {
      value_t core_dist_rit = core_dists[rit];
      value_t core_dist_a   = max(core_dist_rit, max(core_dists[a.key], a.value));

      value_t core_dist_b;
      if (b.key > -1) {
        core_dist_b = max(core_dist_rit, max(core_dists[b.key], b.value));
      } else {
        core_dist_b = b.value;
      }

      return core_dist_a < core_dist_b ? KVP(a.key, core_dist_a) : KVP(b.key, core_dist_b);
    }

    return b;
  }

  // Initializes a plain value accumulator to maxVal.
  DI void init(value_t* out, value_t maxVal) const { *out = maxVal; }
  // Initializes a key/value accumulator to "no candidate": key -1, value maxVal.
  DI void init(KVP* out, value_t maxVal) const
  {
    out->key   = -1;
    out->value = maxVal;
  }

  // A plain value carries no key, so there is nothing to set.
  DI void init_key(value_t& out, value_idx idx) const { return; }
  // Stores idx as the key of a key/value accumulator.
  DI void init_key(KVP& out, value_idx idx) const { out.key = idx; }

  // Accessors returning the value part of either accumulator form.
  DI value_t get_value(KVP& out) const { return out.value; }
  DI value_t get_value(value_t& out) const { return out; }

  /**
   * Permutes core_dists in place through `map` (m entries): element i becomes the old
   * core_dists[map[i]]. The permutation is built in a temporary vector and copied back
   * with an asynchronous copy on the handle's CUDA stream.
   */
  void gather(const raft::resources& handle, value_idx* map)
  {
    auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
    thrust::gather(raft::resource::get_thrust_policy(handle),
                   map,
                   map + m,
                   core_dists,
                   tmp_core_dists.data_handle());
    raft::copy_async(
      core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
  }

  /**
   * Inverse-direction permutation of core_dists through `map` (m entries): the old
   * core_dists[i] is written to position map[i]. The result is built in a temporary
   * vector and copied back with an asynchronous copy on the handle's CUDA stream.
   */
  void scatter(const raft::resources& handle, value_idx* map)
  {
    auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
    thrust::scatter(raft::resource::get_thrust_policy(handle),
                    core_dists,
                    core_dists + m,
                    map,
                    tmp_core_dists.data_handle());
    raft::copy_async(
      core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
  }
};

/**
* Assumes 3-iterator tuple containing COO rows, cols, and
* a cub keyvalue pair object. Sorts the 3 arrays in
Expand Down
91 changes: 2 additions & 89 deletions cpp/tests/sparse/neighbors/cross_component_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -384,93 +384,6 @@ INSTANTIATE_TEST_CASE_P(ConnectComponentsTest,
ConnectComponentsTestF_Int,
::testing::ValuesIn(fix_conn_inputsf2));

// Reduction functor that scores a neighbor candidate (key, value) seen from row `rit` as
// max(core_dists[rit], core_dists[key], value) and keeps the candidate with the smaller score.
template <typename value_idx, typename value_t>
struct MutualReachabilityFixConnectivitiesRedOp {
// Core distance array, indexed by row / candidate key; never freed by this struct.
value_t* core_dists;
// Number of rows; both call operators skip rows with index >= m.
value_idx m;

// NOTE(review): this constructor leaves core_dists uninitialized; only m is set.
DI MutualReachabilityFixConnectivitiesRedOp() : m(0) {}

MutualReachabilityFixConnectivitiesRedOp(value_t* core_dists_, value_idx m_)
: core_dists(core_dists_), m(m_){};

typedef typename raft::KeyValuePair<value_idx, value_t> KVP;
// In-place reduction: merges `other` into `*out` for row `rit`. No-op when rit >= m or
// other.value is not below the maximum finite value_t.
DI void operator()(value_idx rit, KVP* out, const KVP& other) const
{
if (rit < m && other.value < std::numeric_limits<value_t>::max()) {
value_t core_dist_rit = core_dists[rit];
value_t core_dist_other = max(core_dist_rit, max(core_dists[other.key], other.value));

value_t core_dist_out;
if (out->key > -1) {
core_dist_out = max(core_dist_rit, max(core_dists[out->key], out->value));
} else {
// No candidate stored yet: keep the raw value and skip the core_dists lookup.
core_dist_out = out->value;
}

// out->value is overwritten with a score even when `other` does not win.
bool smaller = core_dist_other < core_dist_out;
out->key = smaller ? other.key : out->key;
out->value = smaller ? core_dist_other : core_dist_out;
}
}

// Value-returning reduction: returns `b` unchanged when rit >= m or a.key is not > -1,
// otherwise the candidate with the smaller score (b wins ties).
DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const
{
if (rit < m && a.key > -1) {
value_t core_dist_rit = core_dists[rit];
value_t core_dist_a = max(core_dist_rit, max(core_dists[a.key], a.value));

value_t core_dist_b;
if (b.key > -1) {
core_dist_b = max(core_dist_rit, max(core_dists[b.key], b.value));
} else {
core_dist_b = b.value;
}

return core_dist_a < core_dist_b ? KVP(a.key, core_dist_a) : KVP(b.key, core_dist_b);
}

return b;
}

// Accumulator initialization: a plain value becomes maxVal; a key/value pair becomes
// (key -1, value maxVal).
DI void init(value_t* out, value_t maxVal) const { *out = maxVal; }
DI void init(KVP* out, value_t maxVal) const
{
out->key = -1;
out->value = maxVal;
}

// Key assignment: a no-op for a plain value, stores idx for a key/value pair.
DI void init_key(value_t& out, value_idx idx) const { return; }
DI void init_key(KVP& out, value_idx idx) const { out.key = idx; }

// Accessors returning the value part of either accumulator form.
DI value_t get_value(KVP& out) const { return out.value; }
DI value_t get_value(value_t& out) const { return out; }

// Permutes core_dists in place through `map` (m entries): element i becomes the old
// core_dists[map[i]]. Built in a temporary vector, then copied back asynchronously.
void gather(const raft::resources& handle, value_idx* map)
{
auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
thrust::gather(raft::resource::get_thrust_policy(handle),
map,
map + m,
core_dists,
tmp_core_dists.data_handle());
raft::copy_async(
core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
}

// Writes the old core_dists[i] to position map[i] for the m entries of `map`. Built in a
// temporary vector, then copied back asynchronously.
void scatter(const raft::resources& handle, value_idx* map)
{
auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
thrust::scatter(raft::resource::get_thrust_policy(handle),
core_dists,
core_dists + m,
map,
tmp_core_dists.data_handle());
raft::copy_async(
core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
}
};

template <typename value_t, typename value_idx>
struct ConnectComponentsMutualReachabilityInputs {
value_idx n_row;
Expand Down Expand Up @@ -518,8 +431,8 @@ class ConnectComponentsEdgesTest
/**
* 3. cross_component_nn to fix connectivities
*/
MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t> red_op(core_dists.data(),
params.n_row);
cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t> red_op(
core_dists.data(), params.n_row);

cuvs::sparse::neighbors::cross_component_nn<value_idx, value_t>(handle,
out_edges_unbatched,
Expand Down