Skip to content

Commit 2188199

Browse files
authored
[CUDA] ternary with select op (#2283)
* cuda ternary with select op * comment + fix * fix
1 parent aa07429 commit 2188199

File tree

5 files changed

+231
-1
lines changed

5 files changed

+231
-1
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ target_sources(
3434
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
3535
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
3636
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
37+
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
3738
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
3839
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
namespace mlx::core::cu {
4+
5+
struct Select {
6+
template <typename T>
7+
__device__ T operator()(bool condition, T x, T y) {
8+
return condition ? x : y;
9+
}
10+
};
11+
12+
} // namespace mlx::core::cu

mlx/backend/cuda/device/utils.cuh

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,27 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
162162
return cuda::std::make_tuple(a_loc, b_loc);
163163
}
164164

165+
template <int NDIM, typename IdxT = int64_t>
166+
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
167+
IdxT elem,
168+
const int* shape,
169+
const int64_t* a_strides,
170+
const int64_t* b_strides,
171+
const int64_t* c_strides) {
172+
IdxT a_loc = 0;
173+
IdxT b_loc = 0;
174+
IdxT c_loc = 0;
175+
#pragma unroll
176+
for (int i = NDIM - 1; i >= 0; --i) {
177+
int dim_idx = elem % shape[i];
178+
a_loc += dim_idx * a_strides[i];
179+
b_loc += dim_idx * b_strides[i];
180+
c_loc += dim_idx * c_strides[i];
181+
elem /= shape[i];
182+
}
183+
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
184+
}
185+
165186
// Optimized version when ndim is larger than 4.
166187
template <typename IdxT = int64_t>
167188
inline __host__ __device__ IdxT
@@ -191,6 +212,26 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
191212
return cuda::std::make_tuple(a_loc, b_loc);
192213
}
193214

215+
template <typename IdxT = int64_t>
216+
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
217+
IdxT elem,
218+
const int* shape,
219+
const int64_t* a_strides,
220+
const int64_t* b_strides,
221+
const int64_t* c_strides,
222+
int ndim) {
223+
auto [a_loc, b_loc, c_loc] =
224+
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
225+
for (int i = ndim - 1; i >= 3; --i) {
226+
int dim_idx = elem % shape[i];
227+
a_loc += dim_idx * a_strides[i];
228+
b_loc += dim_idx * b_strides[i];
229+
c_loc += dim_idx * c_strides[i];
230+
elem /= shape[i];
231+
}
232+
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
233+
}
234+
194235
///////////////////////////////////////////////////////////////////////////////
195236
// Elem to loc in a loop utils
196237
///////////////////////////////////////////////////////////////////////////////

mlx/backend/cuda/primitives.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ NO_GPU(QuantizedMatmul)
9191
NO_GPU(Scan)
9292
NO_GPU(Scatter)
9393
NO_GPU(ScatterAxis)
94-
NO_GPU(Select)
9594
NO_GPU_MULTI(SVD)
9695
NO_GPU(Inverse)
9796
NO_GPU(Cholesky)

mlx/backend/cuda/ternary.cu

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Copyright © 2025 Apple Inc.
2+
#include "mlx/backend/common/ternary.h"
3+
#include "mlx/backend/cuda/device.h"
4+
#include "mlx/backend/cuda/device/ternary_ops.cuh"
5+
#include "mlx/backend/cuda/kernel_utils.cuh"
6+
#include "mlx/dtype_utils.h"
7+
#include "mlx/primitives.h"
8+
9+
#include <cooperative_groups.h>
10+
#include <nvtx3/nvtx3.hpp>
11+
12+
namespace mlx::core {
13+
14+
namespace cu {
15+
16+
namespace cg = cooperative_groups;
17+
18+
template <typename Op, typename T, typename IdxT>
19+
__global__ void
20+
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
21+
IdxT index = cg::this_grid().thread_rank();
22+
if (index < size) {
23+
out[index] = Op{}(a[index], b[index], c[index]);
24+
}
25+
}
26+
27+
template <typename Op, typename T, typename IdxT, int NDIM>
28+
__global__ void ternary_g_nd(
29+
const bool* a,
30+
const T* b,
31+
const T* c,
32+
T* out,
33+
IdxT size,
34+
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
35+
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
36+
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
37+
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
38+
IdxT index = cg::this_grid().thread_rank();
39+
if (index < size) {
40+
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
41+
index,
42+
shape.data(),
43+
a_strides.data(),
44+
b_strides.data(),
45+
c_strides.data());
46+
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
47+
}
48+
}
49+
50+
template <typename Op, typename T, typename IdxT>
51+
__global__ void ternary_g(
52+
const bool* a,
53+
const T* b,
54+
const T* c,
55+
T* out,
56+
IdxT size,
57+
const __grid_constant__ Shape shape,
58+
const __grid_constant__ Strides a_strides,
59+
const __grid_constant__ Strides b_strides,
60+
const __grid_constant__ Strides c_strides,
61+
int ndim) {
62+
IdxT index = cg::this_grid().thread_rank();
63+
if (index < size) {
64+
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
65+
index,
66+
shape.data(),
67+
a_strides.data(),
68+
b_strides.data(),
69+
c_strides.data(),
70+
ndim);
71+
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
72+
}
73+
}
74+
75+
} // namespace cu
76+
77+
template <typename Op>
78+
void ternary_op_gpu_inplace(
79+
const std::vector<array>& inputs,
80+
array& out,
81+
const Stream& s) {
82+
const auto& a = inputs[0];
83+
const auto& b = inputs[1];
84+
const auto& c = inputs[2];
85+
if (out.size() == 0) {
86+
return;
87+
}
88+
89+
auto& encoder = cu::get_command_encoder(s);
90+
encoder.set_input_array(a);
91+
encoder.set_input_array(b);
92+
encoder.set_input_array(c);
93+
encoder.set_output_array(out);
94+
encoder.launch_kernel([&](cudaStream_t stream) {
95+
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
96+
using DType = cuda_type_t<CTYPE>;
97+
98+
auto topt = get_ternary_op_type(a, b, c);
99+
if (topt == TernaryOpType::General) {
100+
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
101+
auto& a_strides = strides[0];
102+
auto& b_strides = strides[1];
103+
auto& c_strides = strides[2];
104+
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
105+
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
106+
MLX_SWITCH_BOOL(large, LARGE, {
107+
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
108+
int ndim = shape.size();
109+
if (ndim <= 3) {
110+
MLX_SWITCH_1_2_3(ndim, NDIM, {
111+
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
112+
auto [num_blocks, block_dims] =
113+
get_launch_args(kernel, out, large);
114+
kernel<<<num_blocks, block_dims, 0, stream>>>(
115+
a.data<bool>(),
116+
b.data<DType>(),
117+
c.data<DType>(),
118+
out.data<DType>(),
119+
out.data_size(),
120+
const_param<NDIM>(shape),
121+
const_param<NDIM>(a_strides),
122+
const_param<NDIM>(b_strides),
123+
const_param<NDIM>(c_strides));
124+
});
125+
} else {
126+
auto kernel = cu::ternary_g<Op, DType, IdxT>;
127+
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
128+
kernel<<<num_blocks, block_dims, 0, stream>>>(
129+
a.data<bool>(),
130+
b.data<DType>(),
131+
c.data<DType>(),
132+
out.data<DType>(),
133+
out.data_size(),
134+
const_param(shape),
135+
const_param(a_strides),
136+
const_param(b_strides),
137+
const_param(c_strides),
138+
ndim);
139+
}
140+
});
141+
} else {
142+
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
143+
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
144+
auto kernel = cu::ternary_v<Op, DType, IdxT>;
145+
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
146+
kernel<<<num_blocks, block_dims, 0, stream>>>(
147+
a.data<bool>(),
148+
b.data<DType>(),
149+
c.data<DType>(),
150+
out.data<DType>(),
151+
out.data_size());
152+
});
153+
}
154+
});
155+
});
156+
}
157+
158+
template <typename Op>
159+
void ternary_op_gpu(
160+
const std::vector<array>& inputs,
161+
array& out,
162+
const Stream& s) {
163+
auto& a = inputs[0];
164+
auto& b = inputs[1];
165+
auto& c = inputs[2];
166+
auto topt = get_ternary_op_type(a, b, c);
167+
set_ternary_op_output_data(a, b, c, out, topt);
168+
ternary_op_gpu_inplace<Op>(inputs, out, s);
169+
}
170+
171+
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
172+
nvtx3::scoped_range r("select::eval_gpu");
173+
auto& s = out.primitive().stream();
174+
ternary_op_gpu<cu::Select>(inputs, out, s);
175+
}
176+
177+
} // namespace mlx::core

0 commit comments

Comments
 (0)