Skip to content

[CUDA] Fix gemv regression #2445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mlx/backend/cuda/device/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
};

template <int N, typename T>
// Returns true when `x` is aligned to N * sizeof(T) bytes, i.e. it can be
// reinterpreted as a pointer to AlignedVector<T, N> (which carries
// `alignas(sizeof(T) * N)`) for vectorized loads/stores.
//
// Marked __host__ __device__ so host-side dispatch code (e.g. the gemv
// launcher in this PR) can test pointer alignment before selecting a
// vectorized kernel variant.
inline __host__ __device__ bool is_aligned(T* x) {
  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}

template <int N, typename T>
// Loads N consecutive elements of T starting at element offset * N of `ptr`
// as a single AlignedVector<T, N> read. "Unsafe" because there is no scalar
// fallback: `ptr` is presumably required to satisfy is_aligned<N>(ptr)
// (callers in gemv check this before dispatching) — undefined behavior
// otherwise.
inline __device__ AlignedVector<T, N> unsafe_load_vector(
    const T* ptr,
    uint32_t offset) {
  return reinterpret_cast<const AlignedVector<T, N>*>(ptr)[offset];
}

template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
Expand Down Expand Up @@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
}
}

template <int N, typename T>
// Stores `vec` as one vectorized write of N elements of T at element offset
// offset * N of `ptr`. Counterpart of unsafe_load_vector: "unsafe" because
// `ptr` is presumably required to satisfy is_aligned<N>(ptr); no unaligned
// fallback path exists.
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
  reinterpret_cast<AlignedVector<T, N>*>(ptr)[offset] = vec;
}

template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
Expand Down
15 changes: 10 additions & 5 deletions mlx/backend/cuda/gemms/gemv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
Expand Down Expand Up @@ -127,9 +128,13 @@ void gemv(
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t = 4;
while (K % (n_per_t * WARP_SIZE) != 0) {
n_per_t >>= 1;
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
Expand Down
13 changes: 10 additions & 3 deletions python/tests/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __gemm_test(
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))

def test_matmul_unaligned(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return

for dtype in self.dtypes:
Expand All @@ -61,8 +61,15 @@ def test_matmul_unaligned(self):
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)

def test_matvec_unaligned(self):
    """Regression test for gemv with an unaligned vector operand.

    Slicing one element off a (129,) buffer yields a vector whose base
    pointer is offset by one element, so it is not aligned to the
    vectorized-load width the CUDA gemv kernel may select.
    """
    a = mx.random.normal(shape=(4, 128))
    # [1:] produces an unaligned view — the pointer no longer sits on a
    # 2- or 4-element vector boundary.
    b = mx.random.normal(shape=(129,))[1:]
    out = a @ b
    np_out = np.array(a) @ np.array(b)
    # Use the same tolerance as __gemm_test: np.allclose's default
    # atol (1e-8) is too tight for float32 GPU-vs-NumPy accumulation.
    self.assertTrue(np.allclose(out, np_out, atol=1e-5))

def test_matmul_shapes(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return

shapes = [
Expand Down Expand Up @@ -1274,7 +1281,7 @@ def segmented_mm_ref(a, b, s):
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256
if mx.metal.is_available():
if mx.is_available(mx.gpu):
t = mx.bfloat16
a = mx.random.normal([1, N]).astype(t)
b = mx.concatenate([a, a], axis=0).astype(t)
Expand Down