@@ -362,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
362
362
a_batch_strides.back (),
363
363
b_batch_strides.back ());
364
364
365
+ encoder.set_input_array (a);
366
+ encoder.set_input_array (b);
367
+ encoder.set_output_array (out);
368
+ auto nbatch = batch_count / batch_shape.back ();
369
+ if (nbatch == 1 ) {
370
+ matmul.run (encoder, out.data <int8_t >(), a.data <int8_t >(), b.data <int8_t >());
371
+ return ;
372
+ }
373
+
365
374
ContiguousIterator a_it (batch_shape, a_batch_strides, batch_shape.size () - 1 );
366
375
ContiguousIterator b_it (batch_shape, b_batch_strides, batch_shape.size () - 1 );
367
- for (size_t i = 0 ; i < batch_count / batch_shape. back () ; ++i) {
376
+ for (size_t i = 0 ; i < nbatch ; ++i) {
368
377
matmul.run (
369
378
encoder,
370
379
out.data <int8_t >() + out.itemsize () * i * batch_shape.back () * M * N,
@@ -448,10 +457,21 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
448
457
b_batch_strides.back (),
449
458
c_batch_strides.back ());
450
459
460
+ encoder.set_input_array (a);
461
+ encoder.set_input_array (b);
462
+ encoder.set_input_array (c);
463
+ encoder.set_output_array (out);
464
+
465
+ auto nbatch = batch_count / batch_shape.back ();
466
+ if (nbatch == 1 ) {
467
+ matmul.run (encoder, out.data <int8_t >(), a.data <int8_t >(), b.data <int8_t >(), c.data <int8_t >(), alpha_, beta_);
468
+ return ;
469
+ }
470
+
451
471
ContiguousIterator a_it (batch_shape, a_batch_strides, batch_shape.size () - 1 );
452
472
ContiguousIterator b_it (batch_shape, b_batch_strides, batch_shape.size () - 1 );
453
473
ContiguousIterator c_it (batch_shape, c_batch_strides, batch_shape.size () - 1 );
454
- for (size_t i = 0 ; i < batch_count / batch_shape. back () ; ++i) {
474
+ for (size_t i = 0 ; i < nbatch ; ++i) {
455
475
matmul.run (
456
476
encoder,
457
477
out.data <int8_t >() + out.itemsize () * i * batch_shape.back () * M * N,
0 commit comments