Skip to content

Commit 55d4a0b

Browse files
committed
comment + fix
1 parent 8ab01e7 commit 55d4a0b

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

mlx/backend/cuda/ternary.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
#include "mlx/backend/common/ternary.h"
44
#include "mlx/backend/cuda/device.h"
5+
#include "mlx/backend/cuda/device/ternary_ops.cuh"
56
#include "mlx/backend/cuda/kernel_utils.cuh"
6-
#include "mlx/backend/cuda/kernels/ternary_ops.cuh"
77
#include "mlx/dtype_utils.h"
88
#include "mlx/primitives.h"
99

@@ -69,7 +69,7 @@ __global__ void ternary_g(
6969
b_strides.data(),
7070
c_strides.data(),
7171
ndim);
72-
out[index] = Op{}(a[a_idx], b[b_idx]);
72+
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
7373
}
7474
}
7575

@@ -79,7 +79,6 @@ template <typename Op>
7979
void ternary_op_gpu_inplace(
8080
const std::vector<array>& inputs,
8181
array& out,
82-
std::string_view op,
8382
const Stream& s) {
8483
assert(inputs.size() > 1);
8584
const auto& a = inputs[0];
@@ -162,20 +161,19 @@ template <typename Op>
162161
void ternary_op_gpu(
163162
const std::vector<array>& inputs,
164163
array& out,
165-
std::string_view op,
166164
const Stream& s) {
167165
auto& a = inputs[0];
168166
auto& b = inputs[1];
169167
auto& c = inputs[2];
170168
auto topt = get_ternary_op_type(a, b, c);
171169
set_ternary_op_output_data(a, b, c, out, topt);
172-
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
170+
ternary_op_gpu_inplace<Op>(inputs, out, s);
173171
}
174172

175173
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
176174
nvtx3::scoped_range r("select::eval_gpu");
177175
auto& s = out.primitive().stream();
178-
ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s);
176+
ternary_op_gpu<cu::Select>(inputs, out, s);
179177
}
180178

181179
} // namespace mlx::core

0 commit comments

Comments
 (0)