@@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
591
591
tma_load_b = make_tma_copy_B (
592
592
make_logical_tensor (ptr_B, make_shape (N, K, L), args.dB ));
593
593
594
+ int32_t scale_k =
595
+ (ModeHasScales) ? (K + args.group_size - 1 ) / args.group_size : 0 ;
596
+ int32_t group_size = (ModeHasScales) ? args.group_size : 0 ;
597
+
594
598
if constexpr (ModeHasScales) {
595
- tma_load_scale = make_tma_copy_scale (make_logical_tensor (
596
- args.ptr_S , make_shape (M, args. group_size , L), args.dS ));
599
+ tma_load_scale = make_tma_copy_scale (
600
+ make_logical_tensor ( args.ptr_S , make_shape (M, scale_k , L), args.dS ));
597
601
}
598
602
599
603
if constexpr (KernelConversionMode ==
600
604
ConversionMode::ConvertAndScaleWithZero) {
601
- tma_load_zero = make_tma_copy_zero (make_logical_tensor (
602
- args.ptr_Z , make_shape (M, args. group_size , L), args.dS ));
605
+ tma_load_zero = make_tma_copy_zero (
606
+ make_logical_tensor ( args.ptr_Z , make_shape (M, scale_k , L), args.dS ));
603
607
}
604
608
605
- if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
606
- return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0 , 0 };
607
- } else if constexpr (ModeHasScales) {
608
- auto scale_k = (K + args.group_size - 1 ) / args.group_size ;
609
-
609
+ if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
610
+ KernelConversionMode == ConversionMode::ConvertAndScale ||
611
+ KernelConversionMode ==
612
+ ConversionMode::ConvertAndScaleWithZero) {
610
613
return {tma_load_a, tma_load_b, tma_load_scale,
611
- tma_load_zero, scale_k, args. group_size };
614
+ tma_load_zero, scale_k, group_size};
612
615
} else {
613
616
static_assert (cutlass::detail::dependent_false<KernelSchedule>,
614
617
" Conversion mode not handled in to_underlying_arguments." );
0 commit comments