Skip to content

Commit f48e9aa

Browse files
authored
Add support for float16 to the python pairwise distance api (#547)
Authors: Ben Frederickson (https://github.com/benfred). Approvers: Corey J. Nolet (https://github.com/cjnolet). URL: #547
1 parent 89ebf15 commit f48e9aa

File tree

3 files changed: 24 additions (+24) and 9 deletions (-9).

cpp/src/distance/pairwise_distance_c.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030
namespace {
3131

32-
template <typename T>
32+
template <typename T, typename DistT>
3333
void _pairwise_distance(cuvsResources_t res,
3434
DLManagedTensor* x_tensor,
3535
DLManagedTensor* y_tensor,
@@ -40,7 +40,7 @@ void _pairwise_distance(cuvsResources_t res,
4040
auto res_ptr = reinterpret_cast<raft::resources*>(res);
4141

4242
using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
43-
using distances_mdspan_type = raft::device_matrix_view<T, int64_t, raft::row_major>;
43+
using distances_mdspan_type = raft::device_matrix_view<DistT, int64_t, raft::row_major>;
4444

4545
auto x_mds = cuvs::core::from_dlpack<mdspan_type>(x_tensor);
4646
auto y_mds = cuvs::core::from_dlpack<mdspan_type>(y_tensor);
@@ -71,9 +71,14 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res,
7171
}
7272

7373
if (x_dt.bits == 32) {
74-
_pairwise_distance<float>(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
74+
_pairwise_distance<float, float>(
75+
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
76+
} else if (x_dt.bits == 16) {
77+
_pairwise_distance<half, float>(
78+
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
7579
} else if (x_dt.bits == 64) {
76-
_pairwise_distance<double>(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
80+
_pairwise_distance<double, double>(
81+
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
7782
} else {
7883
RAFT_FAIL("Unsupported DLtensor dtype: %d and bits: %d", x_dt.code, x_dt.bits);
7984
}

python/cuvs/cuvs/distance/distance.pyx

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,10 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
100100
n = y_cai.shape[0]
101101

102102
if out is None:
103-
out = device_ndarray.empty((m, n), dtype=y_cai.dtype)
103+
output_dtype = y_cai.dtype
104+
if np.issubdtype(y_cai.dtype, np.float16):
105+
output_dtype = np.float32
106+
out = device_ndarray.empty((m, n), dtype=output_dtype)
104107
out_cai = wrap_array(out)
105108

106109
x_k = x_cai.shape[1]
@@ -119,7 +122,7 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
119122
y_dt = y_cai.dtype
120123
d_dt = out_cai.dtype
121124

122-
if x_dt != y_dt or x_dt != d_dt:
125+
if x_dt != y_dt:
123126
raise ValueError("Inputs must have the same dtypes")
124127

125128
cdef cydlpack.DLManagedTensor* x_dlpack = \

python/cuvs/cuvs/test/test_distance.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
],
4141
)
4242
@pytest.mark.parametrize("inplace", [True, False])
43-
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
43+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16])
4444
def test_distance(n_rows, n_cols, inplace, metric, dtype):
4545
input1 = np.random.random_sample((n_rows, n_cols))
4646
input1 = np.asarray(input1).astype(dtype)
@@ -55,7 +55,10 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype):
5555
norm = np.sum(input1, axis=1)
5656
input1 = (input1.T / norm).T
5757

58-
output = np.zeros((n_rows, n_rows), dtype=dtype)
58+
output_dtype = dtype
59+
if np.issubdtype(dtype, np.float16):
60+
output_dtype = np.float32
61+
output = np.zeros((n_rows, n_rows), dtype=output_dtype)
5962

6063
if metric == "inner_product":
6164
expected = np.matmul(input1, input1.T)
@@ -76,4 +79,8 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype):
7679

7780
actual = output_device.copy_to_host()
7881

79-
assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)
82+
tol = 1e-3
83+
if np.issubdtype(dtype, np.float16):
84+
tol = 1e-1
85+
86+
assert np.allclose(expected, actual, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)