
Commit d963eb4

fix scale loads not being predicated
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent f67c068 · commit d963eb4

File tree: 6 files changed (+144, -101 lines)

csrc/cutlass_extensions/gemm/collective/collective_builder.hpp

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ struct CollectiveBuilder<
   static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
       KernelTmaWarpSpecializedCooperative,
       KernelPtrArrayTmaWarpSpecializedCooperative,
-      KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>>;
+      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
   using AtomLayoutMNK = cute::conditional_t<IsCooperative,
       Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp

Lines changed: 106 additions & 86 deletions
@@ -121,8 +121,8 @@ struct CollectiveMma<
   using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
   using PipelineParams = typename MainloopPipeline::Params;

-  // Two threads per CTA are producers (1 for operand tile and 1 for scales)
-  static constexpr int NumProducerThreadEvents = 2;
+  // Two threads per CTA are producers (1 for operand tile and 32 for scales)
+  static constexpr int NumProducerThreadEvents = 33;

   static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
   static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
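
The producer event count tracks the new copy scheme for the A scales introduced further down: the TMA loads are still issued by a single elected thread, but the scale tiles are now brought in by a full warp of cp.async threads. A minimal sketch of the arithmetic behind the new value (my reading of the change, not code from this commit; the k-prefixed names are illustrative):

// Assumption: one elected thread issues the TMA loads for the A/B operand tiles.
constexpr int kTmaProducerThreads = 1;
// Assumption: one warp (32 threads) issues cp.async copies for the scale tiles,
// matching the Layout<Shape<_32,_1>> thread layout of scale_copy_a below.
constexpr int kScaleProducerThreads = 32;
// Arrivals the producer barrier must see per pipeline stage.
constexpr int kNumProducerThreadEvents = kTmaProducerThreads + kScaleProducerThreads;
static_assert(kNumProducerThreadEvents == 33, "1 TMA thread + 32 scale-copy threads");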
@@ -148,8 +148,7 @@ struct CollectiveMma<
       cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

   // Block scaling gmem-to-smem copy atom
-  using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
-  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
+  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
   using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

   // Block scaling smem layout
@@ -326,17 +325,20 @@ struct CollectiveMma<
     Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
     Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)

+    constexpr auto scales_m = Int<ScaleMsPerTile>{};
+    auto tM = get<2>(gA_mkl.shape());
+    auto tN = get<2>(gB_nkl.shape());
+    auto tK = get<3>(gA_mkl.shape());
+
     // Make the tiled views of scale tensors
-    auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int<ScaleMsPerTile>{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l)
-    auto scale_dA = make_stride(Int<ScaleMsPerTile>{}, Int<1>{}, get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{});
-    auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
-    auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l)
-    auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape()));
-    auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
+    auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
+    auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
+    auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
+    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

     // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
     // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
-    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
+    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
     Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)

     return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
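
make_ordered_layout builds a compact layout whose stride ordering follows the given Step, so the scale tensors keep their expected majorness without hand-written strides. A small host-side sketch of what the two layouts above resolve to (assumes the CuTe headers and a C++17 host compiler; the concrete extents are made up for illustration):

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // (scale_m, k, l) with Step<_0,_1,_2>: scale_m is the fastest-varying mode.
  auto scaleA_layout = make_ordered_layout(make_shape(4, 3, 2), Step<_0, _1, _2>{});
  print(scaleA_layout); print("\n");  // strides (1, 4, 12) -> "M major" scales

  // (n, k, l) with Step<_1,_0,_2>: k is the fastest-varying mode.
  auto scaleB_layout = make_ordered_layout(make_shape(5, 3, 2), Step<_1, _0, _2>{});
  print(scaleB_layout); print("\n");  // strides (3, 1, 15) -> "K major" scales
  return 0;
}

The same majorness is what the dispatch code checks further down with its "a_scales must be M major" / "b_scales must be K major" TORCH_CHECKs.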
@@ -363,102 +365,120 @@ struct CollectiveMma<
     int lane_predicate = cute::elect_one_sync();

     // Blockscaling: Tma loads for load_input and CpAsync for load_scale
-    if (lane_predicate) {
-      Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
-      Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
-      Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
-      Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
+    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
+    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
+    Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
+    Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)

-      //
-      // Prepare the TMA loads for A and B
-      //
+    //
+    // Prepare the TMA loads for A and B
+    //

-      constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
-      uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
+    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
+    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

-      Tensor gA_mkl = get<0>(load_inputs);
-      Tensor gB_nkl = get<1>(load_inputs);
+    Tensor gA_mkl = get<0>(load_inputs);
+    Tensor gB_nkl = get<1>(load_inputs);

-      auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
-      auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
+    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
+    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

-      // Partition the inputs based on the current block coordinates.
-      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
-      Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
-      Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
+    // Partition the inputs based on the current block coordinates.
+    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
+    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
+    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)


-      // Block scaling: load_scale has scaling tensors in global memory which are not tiled
-      Tensor mScaleA_mkl = get<2>(load_inputs);
-      Tensor mScaleB_nkl = get<3>(load_inputs);
+    // Block scaling: load_scale has scaling tensors in global memory which are not tiled
+    Tensor mScaleA_mkl = get<2>(load_inputs);
+    Tensor mScaleB_nkl = get<3>(load_inputs);
+    auto scales_m = get<0>(mScaleA_mkl.shape());

-      Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
-      Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
-
-      TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
-      TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
-      ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
-      ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
-
-      Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
-      Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
-
-      Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
-      Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
+    Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
+
+    Tensor gScaleA = local_tile(
+      mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
+      make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
+    Tensor cScaleA = local_tile(
+      cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
+      make_coord(m_coord,_,l_coord));
+    Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
+
+    // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
+    TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
+      Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
+    TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
+      Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
+    ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
+    ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
+
+    Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
+    Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
+    Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
+
+    Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
+    Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);

-      // Applies the mapping from block_tma_a
-      Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
-      Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
+    // Applies the mapping from block_tma_a
+    Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
+    Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

-      Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
-      Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
+    Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
+    Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

-      uint16_t mcast_mask_a = 0;
-      uint16_t mcast_mask_b = 0;
+    uint16_t mcast_mask_a = 0;
+    uint16_t mcast_mask_b = 0;

-      // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
-      // Maps the tile -> block, value
-      if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
-        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
-        for (int n = 0; n < size<1>(block_layout); ++n) {
-          mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
-        }
+    // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
+    // Maps the tile -> block, value
+    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+      for (int n = 0; n < size<1>(block_layout); ++n) {
+        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
       }
+    }

-      if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
-        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
-        for (int m = 0; m < size<0>(block_layout); ++m) {
-          mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
-        }
+    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+      for (int m = 0; m < size<0>(block_layout); ++m) {
+        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
       }
+    }

-      // Mainloop
-      CUTLASS_PRAGMA_NO_UNROLL
-      for ( ; k_tile_count > 0; --k_tile_count) {
-        // LOCK smem_pipe_write for _writing_
-        pipeline.producer_acquire(smem_pipe_write);
+    // Allocate predicate tensors for a_scales (since we can't gaurantee that
+    // all scales are valid, since we could have a partial tiles along M)
+    Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
+    #pragma unroll
+    for (int i = 0; i < size(tApA_ScaleA); ++i) {
+      tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
+    }
+
+    // Mainloop
+    CUTLASS_PRAGMA_NO_UNROLL
+    for ( ; k_tile_count > 0; --k_tile_count) {
+      // LOCK smem_pipe_write for _writing_
+      pipeline.producer_acquire(smem_pipe_write);

-        //
-        // Copy gmem to smem for *k_tile_iter
-        //
-        int write_stage = smem_pipe_write.index();
-        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
-        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
+      //
+      // Copy gmem to smem for *k_tile_iter
+      //
+      int write_stage = smem_pipe_write.index();
+      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
+      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

-        // Copy operands A and B from global memory to shared memory
-        copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
-        copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+      // Copy operands A and B from global memory to shared memory
+      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
+      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));

-        // Copy scale tensors from global memory to shared memory
-        copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
-        copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
-        pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
+      // Copy scale tensors from global memory to shared memory
+      copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
+      copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
+      pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);

-        ++k_tile_iter;
+      ++k_tile_iter;

-        // Advance smem_pipe_write
-        ++smem_pipe_write;
-      }
+      // Advance smem_pipe_write
+      ++smem_pipe_write;
     }
   }
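
The predication idiom above is the heart of the fix: an identity tensor (cScaleA) carries each scale element's logical coordinate, a boolean tensor (tApA_ScaleA) records whether that coordinate is in bounds along M, and copy_if skips the out-of-range elements of a partial tile instead of reading past the end of the scale buffer. A reduced host-side sketch of the same idiom (assumes the CuTe headers; the extents and names such as kScaleRowsPerTile and valid_rows are illustrative, not from this commit):

#include <cute/tensor.hpp>
#include <cstdio>
using namespace cute;

int main() {
  constexpr int kScaleRowsPerTile = 8;  // stand-in for ScaleMsPerTile
  constexpr int kKBlocks = 2;           // k-blocks covered by the toy tile
  int valid_rows = 5;                   // stand-in for scales_m on a partial tile

  float src_data[kScaleRowsPerTile * kKBlocks];
  float dst_data[kScaleRowsPerTile * kKBlocks] = {};
  for (int i = 0; i < kScaleRowsPerTile * kKBlocks; ++i) { src_data[i] = float(i + 1); }

  Tensor src = make_tensor(&src_data[0], make_layout(make_shape(Int<kScaleRowsPerTile>{}, Int<kKBlocks>{})));
  Tensor dst = make_tensor(&dst_data[0], make_layout(make_shape(Int<kScaleRowsPerTile>{}, Int<kKBlocks>{})));

  // Identity tensor: each element holds its own (m, k) coordinate, like cScaleA above.
  Tensor coords = make_identity_tensor(shape(src));

  // Predicate on the M coordinate only, like tApA_ScaleA above.
  Tensor pred = make_tensor<bool>(shape(src));
  for (int i = 0; i < size(pred); ++i) {
    pred(i) = get<0>(coords(i)) < valid_rows;
  }

  // Elements whose predicate is false are neither read nor written.
  copy_if(pred, src, dst);

  for (int m = 0; m < kScaleRowsPerTile; ++m) {
    for (int k = 0; k < kKBlocks; ++k) { printf("%5.1f ", dst(m, k)); }
    printf("\n");  // rows >= valid_rows stay zero
  }
  return 0;
}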

csrc/cutlass_extensions/gemm/dispatch_policy.hpp

Lines changed: 4 additions & 3 deletions
@@ -27,9 +27,10 @@ struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
     : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
                                          KernelSchedule> {
   static_assert(
-      cute::is_same_v<KernelSchedule,
-                      KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
-                          ScaleGranularityM>>,
+      cute::is_same_v<
+          KernelSchedule,
+          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
+              ScaleGranularityM>>,
       "KernelSchedule must be one of the warp specialized policies");
 };

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_c3x_sm90_fp8_dispatch.cuh

Lines changed: 16 additions & 4 deletions
@@ -59,8 +59,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
   using ClusterShape = Shape<_1, _2, _1>;

-  using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
+  using KernelSchedule = cutlass::gemm::
+      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
           GroupSizeM_>;
   using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@@ -124,14 +124,26 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

+  // Check is the t is contiguous and is 1D or 2D with one of the dimensions
+  // being 1 (i.e. a row or column vector)
+  auto is_contiguous_vector = [](const torch::Tensor& t) {
+    auto t_sizes = t.sizes();
+    return t.is_contiguous() &&
+           (t.dim() == 1 ||
+            (t.dim() == 2 &&
+             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
+  };
+
   // TODO(lucas): lets clean-up the kernel so that we pass in Strides so
   // we don't have to deal with enforcing implicit layouts
   TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
   TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
-  TORCH_CHECK(a_scales.stride(0) == 1, "a_scales must be M major");
+  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
+              "a_scales must be M major");
   TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
   TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
-  TORCH_CHECK(b_scales.stride(0) == 1, "b_scales must be K major");
+  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
+              "b_scales must be K major");

   uint32_t mma_promotion_interval = 4;
   typename GemmKernel::MainloopArguments mainloop_args{
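
The relaxed checks matter when one of the scale dimensions collapses to a single group: such a tensor is a contiguous row or column vector whose leading stride is not 1, yet its layout is still unambiguous. A small hedged illustration of the case the old check rejected (assumes libtorch; the sizes are made up):

#include <torch/torch.h>
#include <algorithm>
#include <cassert>

int main() {
  // One scale group along M (m / GroupSizeM == 1), sixteen groups along K.
  auto a_scales = torch::rand({1, 16}, torch::kFloat32);

  // Contiguous (1, 16) tensor: stride(0) == 16, so the old `stride(0) == 1`
  // test failed even though the data is trivially "M major".
  assert(a_scales.stride(0) == 16);

  // The is_contiguous_vector-style test added above accepts it.
  bool ok = a_scales.is_contiguous() &&
            (a_scales.dim() == 1 ||
             (a_scales.dim() == 2 &&
              std::min(a_scales.size(0), a_scales.size(1)) == 1));
  assert(ok);
  return 0;
}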

csrc/quantization/machete/machete_mainloop.cuh

Lines changed: 4 additions & 0 deletions
@@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
   using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;

   using PipelineParams = typename MainloopPipeline::Params;
+
+  // One threads per CTA are producers (1 for operand tile)
+  static constexpr int NumProducerThreadEvents = 1;
+
   using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
                                              shape<1>(SmemLayoutAtomScale{})));
