Skip to content

Commit 7539ea9

Browse files
authored
[Opt] Expose the detail::popc as public API (rapidsai#2346)
- Resolves cuVS issue rapidsai/cuvs#158. Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#2346
1 parent b7f6a84 commit 7539ea9

File tree

7 files changed

+344
-8
lines changed

7 files changed

+344
-8
lines changed

cpp/bench/prims/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,13 @@ endfunction()
7575

7676
if(BUILD_PRIMS_BENCH)
7777
ConfigureBench(
78-
NAME CORE_BENCH PATH core/bitset.cu core/copy.cu main.cpp
78+
NAME
79+
CORE_BENCH
80+
PATH
81+
core/bitset.cu
82+
core/copy.cu
83+
core/popc.cu
84+
main.cpp
7985
)
8086

8187
ConfigureBench(

cpp/bench/prims/core/popc.cu

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <common/benchmark.hpp>

#include <raft/core/popc.hpp>

#include <algorithm>
#include <cstdint>
#include <random>
#include <sstream>
#include <vector>
20+
21+
namespace raft::bench::core {
22+
23+
/**
 * @brief Parameters describing a single popc benchmark case.
 *
 * @tparam index_t integer type used for the two bitmap dimensions.
 */
template <typename index_t>
struct PopcInputs {
  index_t n_rows;  ///< first bitmap dimension
  index_t n_cols;  ///< second bitmap dimension
  float sparsity;  ///< ratio applied to the n_rows * n_cols bit count by the benchmark
};
29+
30+
/**
 * @brief Write a benchmark case to a stream as `n_rows#n_cols#sparsity`.
 *
 * @param[in,out] os     stream to write to
 * @param[in]     params benchmark case to print
 * @return `os`, to allow chaining.
 */
template <typename index_t>
inline auto operator<<(std::ostream& os, const PopcInputs<index_t>& params) -> std::ostream&
{
  return os << params.n_rows << "#" << params.n_cols << "#" << params.sparsity;
}
36+
37+
/**
 * @brief Benchmark fixture that runs raft::popc over a randomly populated device bitmap.
 *
 * The bitmap logically holds `n_rows * n_cols` bits packed into `bits_t` words.
 *
 * @tparam index_t integer type of the bitmap dimensions and of the resulting count.
 * @tparam bits_t  unsigned word type the bits are packed into.
 */
template <typename index_t, typename bits_t = uint32_t>
struct popc_bench : public fixture {
  /**
   * @brief Allocate the device bitmap and the device scalar receiving the count.
   *
   * @param[in] p benchmark case (dimensions and ratio of set bits)
   */
  popc_bench(const PopcInputs<index_t>& p)
    : params(p),
      n_element(raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bits_t) * 8))),
      bits_d{raft::make_device_vector<bits_t, index_t>(res, n_element)},
      nnz_actual_d{raft::make_device_scalar<index_t>(res, 0)}
  {
  }

  /**
   * @brief Fill a host bitmap with randomly placed set bits.
   *
   * All words are cleared first; then `m * n * sparsity` distinct bit positions, drawn
   * uniformly from `[0, m * n)`, are set. The requested count is clamped to `[0, m * n]`
   * so that the rejection loop below always terminates.
   *
   * @param[in]  m        first bitmap dimension
   * @param[in]  n        second bitmap dimension
   * @param[in]  sparsity ratio of the `m * n` bits to set
   * @param[out] bitmap   host words to fill; must hold at least `m * n` bits
   * @return number of bits that were set.
   */
  index_t create_bitmap(index_t m, index_t n, float sparsity, std::vector<bits_t>& bitmap)
  {
    constexpr index_t bits_per_item = static_cast<index_t>(8 * sizeof(bits_t));

    std::fill(bitmap.begin(), bitmap.end(), bits_t{0});

    const index_t total = m * n;
    // uniform_int_distribution(0, total - 1) requires total >= 1.
    if (total <= 0) { return index_t{0}; }

    const index_t requested = static_cast<index_t>((total * 1.0f) * sparsity);
    // Asking for more ones than there are bits would spin forever in the loop below.
    const index_t num_ones = std::clamp(requested, index_t{0}, total);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<index_t> dis(0, total - 1);

    index_t remaining = num_ones;
    while (remaining > 0) {
      const index_t index = dis(gen);

      bits_t& element = bitmap[index / bits_per_item];
      // Build the mask in bits_t so the shift is valid for every bit of the word,
      // independently of the width of index_t.
      const bits_t mask = static_cast<bits_t>(bits_t{1} << (index % bits_per_item));

      // Only count positions that were not set yet, so exactly num_ones bits end up set.
      if ((element & mask) == 0) {
        element |= mask;
        --remaining;
      }
    }
    return num_ones;
  }

  /**
   * @brief Upload a random bitmap once, then run raft::popc inside the benchmark loop.
   *
   * @param[in,out] state Google Benchmark state; its label is set to the printed case.
   */
  void run_benchmark(::benchmark::State& state) override
  {
    std::ostringstream label_stream;
    label_stream << params;
    state.SetLabel(label_stream.str());

    std::vector<bits_t> bits_h(n_element);
    auto stream = raft::resource::get_cuda_stream(res);

    create_bitmap(params.n_rows, params.n_cols, params.sparsity, bits_h);
    update_device(bits_d.data_handle(), bits_h.data(), bits_h.size(), stream);

    // Synchronize on the stream used for the upload before entering the benchmark loop.
    resource::sync_stream(res);

    loop_on_state(state, [this]() {
      auto bits_view =
        raft::make_device_vector_view<const bits_t, index_t>(bits_d.data_handle(), bits_d.size());

      index_t max_len      = params.n_rows * params.n_cols;
      auto max_len_view    = raft::make_host_scalar_view<index_t>(&max_len);
      auto nnz_actual_view = nnz_actual_d.view();
      // NOTE(review): `handle` is not declared in this struct; presumably it is a member
      // of the `fixture` base, distinct from `res` used for the allocations. Confirm.
      raft::popc(this->handle, bits_view, max_len_view, nnz_actual_view);
    });
  }

 private:
  // Declared before the device buffers: they are allocated with `res` in the
  // constructor's initializer list, and members are initialized in declaration order.
  raft::resources res;
  PopcInputs<index_t> params;
  index_t n_element;  // number of bits_t words needed for n_rows * n_cols bits

  raft::device_vector<bits_t, index_t> bits_d;
  raft::device_scalar<index_t> nnz_actual_d;
};
108+
109+
/**
 * @brief Benchmark cases, each written as {n_rows, n_cols, sparsity}.
 *
 * Several cases appear more than once; the list is kept in its original order.
 */
template <typename index_t>
const std::vector<PopcInputs<index_t>> popc_input_vecs{
  // Mixed shapes.
  {2, 131072, 0.4f},
  {8, 131072, 0.5f},
  {16, 131072, 0.2f},
  {2, 8192, 0.4f},
  {16, 8192, 0.5f},
  {128, 8192, 0.2f},

  // 1024 x 8192 with sparsity 0.1, eight times.
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},
  {1024, 8192, 0.1f},

  // 1024 x 8192 cycling through sparsity 0.4, 0.5, 0.2.
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},

  // 1024 x 8192 cycling through sparsity 0.5, 0.2, 0.4.
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f},
  {1024, 8192, 0.4f},
  {1024, 8192, 0.5f},
  {1024, 8192, 0.2f}};
122+
123+
// Fixture instantiated with 64-bit indices (bits_t keeps its default), handed to
// RAFT_BENCH_REGISTER together with the int64_t case list.
using PopcBenchI64 = popc_bench<int64_t>;

RAFT_BENCH_REGISTER(PopcBenchI64, "", popc_input_vecs<int64_t>);
126+
127+
} // namespace raft::bench::core

cpp/include/raft/core/bitset.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#pragma once
1818

1919
#include <raft/core/bitset.hpp>
20-
#include <raft/core/detail/popc.cuh>
2120
#include <raft/core/device_container_policy.hpp>
2221
#include <raft/core/device_mdarray.hpp>
22+
#include <raft/core/popc.hpp>
2323
#include <raft/core/resource/thrust_policy.hpp>
2424
#include <raft/core/resources.hpp>
2525
#include <raft/linalg/map.cuh>
@@ -167,9 +167,10 @@ template <typename bitset_t, typename index_t>
167167
void bitset<bitset_t, index_t>::count(const raft::resources& res,
168168
raft::device_scalar_view<index_t> count_gpu_scalar)
169169
{
170+
auto max_len = raft::make_host_scalar_view<index_t>(&bitset_len_);
170171
auto values =
171172
raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
172-
raft::detail::popc(res, values, bitset_len_, count_gpu_scalar);
173+
raft::popc(res, values, max_len, count_gpu_scalar);
173174
}
174175

175176
} // end namespace raft::core

cpp/include/raft/core/detail/popc.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <raft/core/detail/mdspan_util.cuh>
1919
#include <raft/core/device_mdarray.hpp>
20+
#include <raft/core/host_mdspan.hpp>
2021
#include <raft/core/resources.hpp>
2122
#include <raft/linalg/coalesced_reduction.cuh>
2223

@@ -28,15 +29,15 @@ namespace raft::detail {
2829
* @tparam value_t the value type of the vector.
2930
* @tparam index_t the index type of vector and scalar.
3031
*
31-
* @param[in] res raft handle for managing expensive resources
32-
* @param[in] values Number of row in the matrix.
32+
* @param[in] res RAFT handle for managing expensive resources
33+
* @param[in] values Device vector view containing the values to be processed.
3334
* @param[in] max_len Maximum number of bits to count.
34-
* @param[out] counter Number of bits that are set to 1.
35+
* @param[out] counter Device scalar view to store the number of bits that are set to 1.
3536
*/
3637
template <typename value_t, typename index_t>
3738
void popc(const raft::resources& res,
3839
device_vector_view<value_t, index_t> values,
39-
index_t max_len,
40+
raft::host_scalar_view<index_t> max_len,
4041
raft::device_scalar_view<index_t> counter)
4142
{
4243
auto values_size = values.size();
@@ -46,7 +47,7 @@ void popc(const raft::resources& res,
4647

4748
static constexpr index_t len_per_item = sizeof(value_t) * 8;
4849

49-
value_t tail_len = (max_len % len_per_item);
50+
value_t tail_len = (max_len[0] % len_per_item);
5051
value_t tail_mask = tail_len ? (value_t)((value_t{1} << tail_len) - value_t{1}) : ~value_t{0};
5152
raft::linalg::coalesced_reduction(
5253
res,

cpp/include/raft/core/popc.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
#include <raft/core/detail/popc.cuh>
19+
namespace raft {
20+
21+
/**
22+
* @brief Count the number of bits that are set to 1 in a vector.
23+
*
24+
* @tparam value_t the value type of the vector.
25+
* @tparam index_t the index type of vector and scalar.
26+
*
27+
* @param[in] res RAFT handle for managing expensive resources
28+
* @param[in] values Device vector view containing the values to be processed.
29+
* @param[in] max_len Host scalar view to store the Maximum number of bits to count.
30+
* @param[out] counter Device scalar view to store the number of bits that are set to 1.
31+
*/
32+
template <typename value_t, typename index_t>
33+
void popc(const raft::resources& res,
34+
device_vector_view<value_t, index_t> values,
35+
raft::host_scalar_view<index_t> max_len,
36+
raft::device_scalar_view<index_t> counter)
37+
{
38+
detail::popc(res, values, max_len, counter);
39+
}
40+
41+
} // namespace raft

cpp/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ if(BUILD_TESTS)
122122
core/math_host.cpp
123123
core/operators_device.cu
124124
core/operators_host.cpp
125+
core/popc.cu
125126
core/handle.cpp
126127
core/interruptible.cu
127128
core/nvtx.cpp

0 commit comments

Comments
 (0)