Skip to content

Commit f07eb68

Browse files
committed
fix
1 parent 850ad01 commit f07eb68

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

mlx/backend/cuda/ternary.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2025 Apple Inc.
2-
32
#include "mlx/backend/common/ternary.h"
43
#include "mlx/backend/cuda/device.h"
54
#include "mlx/backend/cuda/device/ternary_ops.cuh"
@@ -80,7 +79,6 @@ void ternary_op_gpu_inplace(
8079
const std::vector<array>& inputs,
8180
array& out,
8281
const Stream& s) {
83-
assert(inputs.size() > 1);
8482
const auto& a = inputs[0];
8583
const auto& b = inputs[1];
8684
const auto& c = inputs[2];
@@ -94,7 +92,7 @@ void ternary_op_gpu_inplace(
9492
encoder.set_input_array(c);
9593
encoder.set_output_array(out);
9694
encoder.launch_kernel([&](cudaStream_t stream) {
97-
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, {
95+
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
9896
using DType = cuda_type_t<CTYPE>;
9997

10098
auto topt = get_ternary_op_type(a, b, c);
@@ -110,7 +108,7 @@ void ternary_op_gpu_inplace(
110108
int ndim = shape.size();
111109
if (ndim <= 3) {
112110
MLX_SWITCH_1_2_3(ndim, NDIM, {
113-
auto kernel = &cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
111+
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
114112
auto [num_blocks, block_dims] =
115113
get_launch_args(kernel, out, large);
116114
kernel<<<num_blocks, block_dims, 0, stream>>>(

0 commit comments

Comments
 (0)