Skip to content

Commit 1a0e884

Browse files
committed
fix adding input arrays in matmul / sort
1 parent 72e21b7 commit 1a0e884

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

mlx/backend/cuda/allocator.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright © 2025 Apple Inc.
22

3-
#include "mlx/utils.h"
43
#include "mlx/backend/cuda/allocator.h"
54
#include "mlx/backend/cuda/utils.h"
65
#include "mlx/backend/cuda/worker.h"

mlx/backend/cuda/matmul.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
362362
a_batch_strides.back(),
363363
b_batch_strides.back());
364364

365+
encoder.set_input_array(a);
366+
encoder.set_input_array(b);
367+
encoder.set_output_array(out);
368+
auto nbatch = batch_count / batch_shape.back();
369+
if (nbatch == 1) {
370+
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
371+
return;
372+
}
373+
365374
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
366375
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
367-
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
376+
for (size_t i = 0; i < nbatch; ++i) {
368377
matmul.run(
369378
encoder,
370379
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -448,10 +457,21 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
448457
b_batch_strides.back(),
449458
c_batch_strides.back());
450459

460+
encoder.set_input_array(a);
461+
encoder.set_input_array(b);
462+
encoder.set_input_array(c);
463+
encoder.set_output_array(out);
464+
465+
auto nbatch = batch_count / batch_shape.back();
466+
if (nbatch == 1) {
467+
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>(), c.data<int8_t>(), alpha_, beta_);
468+
return;
469+
}
470+
451471
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
452472
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
453473
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
454-
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
474+
for (size_t i = 0; i < nbatch; ++i) {
455475
matmul.run(
456476
encoder,
457477
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

mlx/backend/cuda/sort.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
7979
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
8080
array out = out_;
8181
auto& encoder = cu::get_command_encoder(s);
82-
encoder.set_input_array(in);
83-
encoder.set_output_array(out);
84-
8582
if (axis < 0) {
8683
axis += in.ndim();
8784
}
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
106103
in.flags());
107104
}
108105

106+
encoder.set_input_array(in);
107+
encoder.set_output_array(out);
109108
encoder.launch_kernel([&](cudaStream_t stream) {
110109
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
111110
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

0 commit comments

Comments
 (0)