@@ -121,8 +121,8 @@ struct CollectiveMma<
   using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
   using PipelineParams = typename MainloopPipeline::Params;

-  // Two threads per CTA are producers (1 for operand tile and 1 for scales)
-  static constexpr int NumProducerThreadEvents = 2;
+  // Thirty-three threads per CTA are producers (1 for operand tile and 32 for scales)
+  static constexpr int NumProducerThreadEvents = 33;

   static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
   static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
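
For context (not part of the diff): a minimal standalone sketch of how ScaleGranularityM_ resolves to ScaleMsPerTile, assuming a CTA tile with 128 rows in M; the helper name below is made up for illustration.

  // scale_ms_per_tile() mirrors the two constexpr definitions above:
  // ScaleGranularityM_ == 0 means "one scale block spanning the whole CTA tile".
  constexpr int kTileShapeM = 128; // assumed size<0>(TileShape{})

  constexpr int scale_ms_per_tile(int scale_granularity_m) {
    int g = (scale_granularity_m == 0) ? kTileShapeM : scale_granularity_m;
    return kTileShapeM / g; // number of scales along M held per CTA tile
  }

  static_assert(scale_ms_per_tile(0)   == 1,   "default: one scale per CTA tile");
  static_assert(scale_ms_per_tile(1)   == 128, "per-row scaling: 128 scales per tile");
  static_assert(scale_ms_per_tile(128) == 1,   "granularity equal to the tile size");
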
@@ -148,8 +148,7 @@ struct CollectiveMma<
       cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

   // Block scaling gmem-to-smem copy atom
-  using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
-  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<BlockScaleCopyTypeA>, ElementBlockScale>;
+  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
   using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

   // Block scaling smem layout
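
A worked sizing note on this copy-atom change, assuming ElementBlockScale is a 4-byte float and ScaleMsPerTile is 128 (both values are assumptions, not fixed by the diff): the old vectorized copy type packed the scales into 16-byte cp.async transactions issued by a single thread, while the new per-element atom lets 32 producer threads issue scalar copies that can be predicated individually.

  #include <algorithm>

  // Old scheme: the widest cp.async type that fits, capped at 16 bytes, so 128 float
  // scales moved as 32 vectorized 16-byte transactions from one producer thread.
  static_assert(std::min(int(sizeof(float)) * 128, 16) == 16, "BlockScaleCopyTypeA was uint128_t");
  static_assert(128 * sizeof(float) / 16 == 32, "32 x 16B copies from a single thread");

  // New scheme: plain per-element cp.async, so 32 producer threads each issue
  // 4 scalar copies, and every element can be masked by copy_if (see the mainloop hunk below).
  static_assert(128 / 32 == 4, "4 scalar copies per producer thread");
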
@@ -326,17 +325,20 @@ struct CollectiveMma<
     Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
     Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)

+    constexpr auto scales_m = Int<ScaleMsPerTile>{};
+    auto tM = get<2>(gA_mkl.shape());
+    auto tN = get<2>(gB_nkl.shape());
+    auto tK = get<3>(gA_mkl.shape());
+
     // Make the tiled views of scale tensors
-    auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int<ScaleMsPerTile>{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l)
-    auto scale_dA = make_stride(Int<ScaleMsPerTile>{}, Int<1>{}, get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int<ScaleMsPerTile>{});
-    auto scaleA_layout = make_layout(scaleA_shape, scale_dA);
-    auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l)
-    auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape()));
-    auto scaleB_layout = make_layout(scaleB_shape, scale_dB);
+    auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
+    auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
+    auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
+    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

     // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
     // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
-    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l)
+    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
     Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)

     return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
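
To make the two ordered layouts concrete, a small standalone sketch with made-up extents (not taken from the diff): Step<_0,_1,_2> gives the scale_m mode of the A-scales the smallest stride, while Step<_1,_0,_2> does the same for the k mode of the B-scales.

  #include <cute/tensor.hpp>
  using namespace cute;

  void scale_layout_demo() {
    // A-scales: shape (scale_m = 256, k-tiles = 8, l = 1); scale_m is the contiguous mode.
    auto scaleA_layout = make_ordered_layout(make_shape(Int<256>{}, Int<8>{}, Int<1>{}),
                                             Step<_0,_1,_2>{}); // strides (_1,_256,_2048)
    // B-scales: shape (n-tiles = 4, k-tiles = 8, l = 1); k is the contiguous mode.
    auto scaleB_layout = make_ordered_layout(make_shape(Int<4>{}, Int<8>{}, Int<1>{}),
                                             Step<_1,_0,_2>{}); // strides (_8,_1,_32)
    print(scaleA_layout); print("\n");
    print(scaleB_layout); print("\n");
  }
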
@@ -363,102 +365,120 @@ struct CollectiveMma<
     int lane_predicate = cute::elect_one_sync();

     // Blockscaling: Tma loads for load_input and CpAsync for load_scale
-    if (lane_predicate) {
-      Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
-      Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
-      Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
-      Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
+    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
+    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
+    Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
+    Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)

-      //
-      // Prepare the TMA loads for A and B
-      //
+    //
+    // Prepare the TMA loads for A and B
+    //

-      constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
-      uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
+    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
+    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

-      Tensor gA_mkl = get<0>(load_inputs);
-      Tensor gB_nkl = get<1>(load_inputs);
+    Tensor gA_mkl = get<0>(load_inputs);
+    Tensor gB_nkl = get<1>(load_inputs);

-      auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
-      auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
+    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
+    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

-      // Partition the inputs based on the current block coordinates.
-      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
-      Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
-      Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
+    // Partition the inputs based on the current block coordinates.
+    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
+    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
+    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)


-      // Block scaling: load_scale has scaling tensors in global memory which are not tiled
-      Tensor mScaleA_mkl = get<2>(load_inputs);
-      Tensor mScaleB_nkl = get<3>(load_inputs);
+    // Block scaling: load_scale has scaling tensors in global memory which are not tiled
+    Tensor mScaleA_mkl = get<2>(load_inputs);
+    Tensor mScaleB_nkl = get<3>(load_inputs);
+    auto scales_m = get<0>(mScaleA_mkl.shape());

-      Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1)
-      Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
-
-      TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout<Shape<_1>>{}, Layout<Shape<Int<ScaleMsPerTile>>>{}); // (1,ScaleMsPerTile,1)
-      TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
-      ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
-      ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
-
-      Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
-      Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
-
-      Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
-      Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
+    Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
+
+    Tensor gScaleA = local_tile(
+      mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
+      make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
+    Tensor cScaleA = local_tile(
+      cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
+      make_coord(m_coord,_,l_coord));
+    Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
+
+    // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
+    TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
+      Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // 32 threads x 4 scales = 128 per k-tile
+    TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
+      Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
+    ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
+    ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
+
+    Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
+    Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
+    Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
+
+    Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
+    Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);

-      // Applies the mapping from block_tma_a
-      Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
-      Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
+    // Applies the mapping from block_tma_a
+    Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
+    Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

-      Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
-      Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
+    Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
+    Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

-      uint16_t mcast_mask_a = 0;
-      uint16_t mcast_mask_b = 0;
+    uint16_t mcast_mask_a = 0;
+    uint16_t mcast_mask_b = 0;

-      // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
-      // Maps the tile -> block, value
-      if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
-        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
-        for (int n = 0; n < size<1>(block_layout); ++n) {
-          mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
-        }
+    // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
+    // Maps the tile -> block, value
+    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+      for (int n = 0; n < size<1>(block_layout); ++n) {
+        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
       }
+    }

-      if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
-        auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
-        for (int m = 0; m < size<0>(block_layout); ++m) {
-          mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
-        }
+    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
+      for (int m = 0; m < size<0>(block_layout); ++m) {
+        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
       }
+    }

-      // Mainloop
-      CUTLASS_PRAGMA_NO_UNROLL
-      for ( ; k_tile_count > 0; --k_tile_count) {
-        // LOCK smem_pipe_write for _writing_
-        pipeline.producer_acquire(smem_pipe_write);
+    // Allocate predicate tensors for a_scales (we cannot guarantee that all
+    // scales are valid, since there may be partial tiles along M)
+    Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
+    #pragma unroll
+    for (int i = 0; i < size(tApA_ScaleA); ++i) {
+      tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
+    }
+
+    // Mainloop
+    CUTLASS_PRAGMA_NO_UNROLL
+    for ( ; k_tile_count > 0; --k_tile_count) {
+      // LOCK smem_pipe_write for _writing_
+      pipeline.producer_acquire(smem_pipe_write);

-        //
-        // Copy gmem to smem for *k_tile_iter
-        //
-        int write_stage = smem_pipe_write.index();
-        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
-        BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
+      //
+      // Copy gmem to smem for *k_tile_iter
+      //
+      int write_stage = smem_pipe_write.index();
+      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
+      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

-        // Copy operands A and B from global memory to shared memory
-        copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
-        copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+      // Copy operands A and B from global memory to shared memory
+      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
+      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));

-        // Copy scale tensors from global memory to shared memory
-        copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
-        copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
-        pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
+      // Copy scale tensors from global memory to shared memory
+      copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
+      copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
+      pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);

-        ++k_tile_iter;
+      ++k_tile_iter;

-        // Advance smem_pipe_write
-        ++smem_pipe_write;
-      }
+      // Advance smem_pipe_write
+      ++smem_pipe_write;
     }
   }
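
The predicated A-scale load above follows a standard CuTe identity-tensor idiom. Below is a standalone sketch of that idiom with hypothetical names (it is not the patch's code), assuming scales_m counts the valid scale rows of the current, possibly partial, M tile. The B-scale copy needs no predicate because it moves exactly one scale per (N tile, K tile).

  #include <cute/tensor.hpp>
  #include <cutlass/cutlass.h>
  using namespace cute;

  template <class TiledCopyT, class GTensor, class STensor>
  CUTLASS_DEVICE void copy_scales_predicated(TiledCopyT const& scale_copy, int scales_m,
                                             GTensor const& gScale, STensor& sScale) {
    auto thr_copy = scale_copy.get_slice(threadIdx.x);

    // Identity tensor: element i holds its own logical coordinate within the tile.
    Tensor cScale = make_identity_tensor(shape(gScale));
    Tensor tGg = thr_copy.partition_S(gScale); // gmem fragment owned by this thread
    Tensor tGc = thr_copy.partition_S(cScale); // matching coordinates
    Tensor tGs = thr_copy.partition_D(sScale); // smem destination fragment

    // Predicate: only rows below scales_m are copied; masked elements issue no cp.async.
    Tensor tGp = make_tensor<bool>(shape(tGg));
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(tGp); ++i) {
      tGp(i) = get<0>(tGc(i)) < scales_m;
    }
    copy_if(scale_copy, tGp, tGg, tGs);
  }
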