1
1
// Copyright © 2025 Apple Inc.
2
-
3
2
#include " mlx/backend/common/ternary.h"
4
3
#include " mlx/backend/cuda/device.h"
5
4
#include " mlx/backend/cuda/device/ternary_ops.cuh"
@@ -80,7 +79,6 @@ void ternary_op_gpu_inplace(
80
79
const std::vector<array>& inputs,
81
80
array& out,
82
81
const Stream& s) {
83
- assert (inputs.size () > 1 );
84
82
const auto & a = inputs[0 ];
85
83
const auto & b = inputs[1 ];
86
84
const auto & c = inputs[2 ];
@@ -94,7 +92,7 @@ void ternary_op_gpu_inplace(
94
92
encoder.set_input_array (c);
95
93
encoder.set_output_array (out);
96
94
encoder.launch_kernel ([&](cudaStream_t stream) {
97
- MLX_SWITCH_ALL_TYPES (a .dtype (), CTYPE, {
95
+ MLX_SWITCH_ALL_TYPES (out .dtype (), CTYPE, {
98
96
using DType = cuda_type_t <CTYPE>;
99
97
100
98
auto topt = get_ternary_op_type (a, b, c);
@@ -110,7 +108,7 @@ void ternary_op_gpu_inplace(
110
108
int ndim = shape.size ();
111
109
if (ndim <= 3 ) {
112
110
MLX_SWITCH_1_2_3 (ndim, NDIM, {
113
- auto kernel = & cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
111
+ auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
114
112
auto [num_blocks, block_dims] =
115
113
get_launch_args (kernel, out, large);
116
114
kernel<<<num_blocks, block_dims, 0 , stream>>> (
0 commit comments