Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions cpp/src/preprocessing/quantize/detail/scalar.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include <cuvs/preprocessing/quantize/scalar.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/sample_rows.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/system/omp/execution_policy.h>
Expand Down Expand Up @@ -72,10 +72,11 @@ struct quantize_op {
}
};

template <typename T>
std::tuple<T, T> quantile_min_max(raft::resources const& res,
raft::device_matrix_view<const T, int64_t> dataset,
double quantile)
template <typename T, typename IdxT = int64_t, typename accessor>
std::tuple<T, T> quantile_min_max(
raft::resources const& res,
raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, accessor> dataset,
double quantile)
{
// settings for quantile approximation
constexpr size_t max_num_samples = 1000000;
Expand All @@ -85,14 +86,15 @@ std::tuple<T, T> quantile_min_max(raft::resources const& res,

// select subsample
raft::random::RngState rng(seed);
size_t n_elements = dataset.extent(0) * dataset.extent(1);
size_t subset_size = std::min(max_num_samples, n_elements);
auto subset = raft::make_device_vector<T>(res, subset_size);
auto dataset_view = raft::make_device_vector_view<const T>(dataset.data_handle(), n_elements);
raft::random::sample_without_replacement(
res, rng, dataset_view, std::nullopt, subset.view(), std::nullopt);

// quantile / sort and pick for now
size_t n_rows = dataset.extent(0);
size_t dim = dataset.extent(1);
size_t n_sample_rows = std::min<size_t>(std::ceil(max_num_samples / dim), n_rows);

// select subsample rows (this returns device data for both device and host input)
auto subset = raft::matrix::sample_rows(res, rng, dataset, (IdxT)n_sample_rows);

// quantile / sort element-wise and pick for now
size_t subset_size = n_sample_rows * dim;
thrust::sort(raft::resource::get_thrust_policy(res),
subset.data_handle(),
subset.data_handle() + subset_size);
Expand All @@ -105,39 +107,9 @@ std::tuple<T, T> quantile_min_max(raft::resources const& res,
raft::update_host(&(minmax_h[0]), subset.data_handle() + pos_min, 1, stream);
raft::update_host(&(minmax_h[1]), subset.data_handle() + pos_max, 1, stream);
raft::resource::sync_stream(res);

return {minmax_h[0], minmax_h[1]};
}

/**
 * @brief Approximate symmetric quantile bounds of a host-resident dataset.
 *
 * Draws at most 1,000,000 elements from the dataset with std::sample (fixed
 * seed 137, so repeated calls on the same input pick the same sample), sorts
 * the sample on the host with Thrust's OpenMP policy using the `fp_lt<T>`
 * comparator, and reads two order statistics from the sorted sample.
 *
 * @tparam T element type of the dataset
 *
 * @param[in] res      raft resources handle; not used by this host overload
 * @param[in] dataset  host matrix; read through data_handle() as
 *                     extent(0) * extent(1) consecutive elements
 *                     (NOTE(review): assumes a contiguous layout - confirm)
 * @param[in] quantile fraction of the sample that should lie between the two
 *                     returned values (NOTE(review): presumably validated to be
 *                     within (0, 1] by the caller - confirm)
 *
 * @return tuple {lower, upper}: the sample values at the mirrored sorted
 *         positions pos_min and pos_max
 *
 * @throws std::out_of_range if a computed position falls outside the sample,
 *         which is the case for an empty dataset
 */
template <typename T>
std::tuple<T, T> quantile_min_max(raft::resources const& res,
                                  raft::host_matrix_view<const T, int64_t> dataset,
                                  double quantile)
{
  // settings for quantile approximation
  constexpr size_t max_num_samples = 1000000;
  constexpr int seed               = 137;

  // select subsample
  std::mt19937 rng(seed);
  const size_t n_elements  = dataset.extent(0) * dataset.extent(1);
  const size_t subset_size = std::min(max_num_samples, n_elements);
  std::vector<T> subset;
  // the final size is known up front: allocate once instead of growing
  // repeatedly through the back_inserter
  subset.reserve(subset_size);
  std::sample(dataset.data_handle(),
              dataset.data_handle() + n_elements,
              std::back_inserter(subset),
              subset_size,
              rng);

  // quantile / sort and pick for now
  thrust::sort(thrust::omp::par, subset.data(), subset.data() + subset_size, fp_lt<T>);

  // pos_max is the last index inside the lower (0.5 + 0.5 * quantile) fraction
  // of the sorted sample; pos_min mirrors it from the other end
  const double half_quantile_pos = (0.5 + 0.5 * quantile) * subset_size;
  const int pos_max              = std::ceil(half_quantile_pos) - 1;
  const int pos_min              = subset_size - pos_max - 1;

  // bounds-checked access: an empty sample (pos_max == -1, pos_min == 0) or an
  // out-of-range quantile raises std::out_of_range instead of reading out of
  // bounds
  return {subset.at(pos_min), subset.at(pos_max)};
}

template <typename T>
cuvs::preprocessing::quantize::scalar::quantizer<T> train(
raft::resources const& res,
Expand Down