2
2
3
3
#include " mlx/backend/common/ternary.h"
4
4
#include " mlx/backend/cuda/device.h"
5
+ #include " mlx/backend/cuda/device/ternary_ops.cuh"
5
6
#include " mlx/backend/cuda/kernel_utils.cuh"
6
- #include " mlx/backend/cuda/kernels/ternary_ops.cuh"
7
7
#include " mlx/dtype_utils.h"
8
8
#include " mlx/primitives.h"
9
9
@@ -69,7 +69,7 @@ __global__ void ternary_g(
69
69
b_strides.data (),
70
70
c_strides.data (),
71
71
ndim);
72
- out[index] = Op{}(a[a_idx], b[b_idx]);
72
+ out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx] );
73
73
}
74
74
}
75
75
@@ -79,7 +79,6 @@ template <typename Op>
79
79
void ternary_op_gpu_inplace (
80
80
const std::vector<array>& inputs,
81
81
array& out,
82
- std::string_view op,
83
82
const Stream& s) {
84
83
assert (inputs.size () > 1 );
85
84
const auto & a = inputs[0 ];
@@ -162,20 +161,19 @@ template <typename Op>
162
161
void ternary_op_gpu (
163
162
const std::vector<array>& inputs,
164
163
array& out,
165
- std::string_view op,
166
164
const Stream& s) {
167
165
auto & a = inputs[0 ];
168
166
auto & b = inputs[1 ];
169
167
auto & c = inputs[2 ];
170
168
auto topt = get_ternary_op_type (a, b, c);
171
169
set_ternary_op_output_data (a, b, c, out, topt);
172
- ternary_op_gpu_inplace<Op>(inputs, out, op, s);
170
+ ternary_op_gpu_inplace<Op>(inputs, out, s);
173
171
}
174
172
175
173
void Select::eval_gpu (const std::vector<array>& inputs, array& out) {
176
174
nvtx3::scoped_range r (" select::eval_gpu" );
177
175
auto & s = out.primitive ().stream ();
178
- ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string ( this ), s);
176
+ ternary_op_gpu<cu::Select>(inputs, out, s);
179
177
}
180
178
181
179
} // namespace mlx::core
0 commit comments