Skip to content

Commit 7a81d80

Browse files
LucasWilkinson authored and sumitd2 committed
[Bugfix] Machete garbage results for some models (large K dim) (vllm-project#9212)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent f7ccb08 commit 7a81d80

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

csrc/quantization/machete/machete_mainloop.cuh

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
591591
tma_load_b = make_tma_copy_B(
592592
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
593593

594+
int32_t scale_k =
595+
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
596+
int32_t group_size = (ModeHasScales) ? args.group_size : 0;
597+
594598
if constexpr (ModeHasScales) {
595-
tma_load_scale = make_tma_copy_scale(make_logical_tensor(
596-
args.ptr_S, make_shape(M, args.group_size, L), args.dS));
599+
tma_load_scale = make_tma_copy_scale(
600+
make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
597601
}
598602

599603
if constexpr (KernelConversionMode ==
600604
ConversionMode::ConvertAndScaleWithZero) {
601-
tma_load_zero = make_tma_copy_zero(make_logical_tensor(
602-
args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
605+
tma_load_zero = make_tma_copy_zero(
606+
make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
603607
}
604608

605-
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
606-
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
607-
} else if constexpr (ModeHasScales) {
608-
auto scale_k = (K + args.group_size - 1) / args.group_size;
609-
609+
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
610+
KernelConversionMode == ConversionMode::ConvertAndScale ||
611+
KernelConversionMode ==
612+
ConversionMode::ConvertAndScaleWithZero) {
610613
return {tma_load_a, tma_load_b, tma_load_scale,
611-
tma_load_zero, scale_k, args.group_size};
614+
tma_load_zero, scale_k, group_size};
612615
} else {
613616
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
614617
"Conversion mode not handled in to_underlying_arguments.");

tests/kernels/test_machete_gemm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
(1, 128, 128),
2525
(1, 512, 1024),
2626
(1, 4096, 4096),
27+
(1, 8192, 28672),
2728
(13, 8192, 4096),
2829
(26, 4096, 8192),
29-
(1, 4096, 4096),
30+
(64, 4096, 4096),
31+
(64, 8192, 28672),
3032
(257, 128, 4096),
3133
(257, 4224, 4160),
3234
(257, 4096, 4096),
33-
(64, 4096, 4096),
3435
(1024, 4096, 8192),
3536
(1024, 8192, 4096),
3637
]

0 commit comments

Comments
 (0)