Skip to content

Commit 81902d2

Browse files
Hongbosherlocklulmer
authored andcommitted
[Kernel]Add streamK for block-quantized CUTLASS kernels (vllm-project#12978)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent cf98d66 commit 81902d2

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
3030
}
3131

3232
template <typename GemmKernel>
33-
void cutlass_gemm_caller(torch::Device device,
34-
cute::Shape<int, int, int, int> prob_shape,
35-
typename GemmKernel::MainloopArguments mainloop_args,
36-
typename GemmKernel::EpilogueArguments epilogue_args) {
33+
void cutlass_gemm_caller(
34+
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
35+
typename GemmKernel::MainloopArguments mainloop_args,
36+
typename GemmKernel::EpilogueArguments epilogue_args,
37+
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
38+
cutlass::KernelHardwareInfo hw_info;
3739
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
38-
prob_shape, mainloop_args, epilogue_args};
40+
prob_shape,
41+
mainloop_args,
42+
epilogue_args,
43+
hw_info,
44+
scheduler};
3945

4046
// Launch the CUTLASS GEMM kernel.
4147
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ namespace vllm {
2222

2323
using namespace cute;
2424

25-
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
26-
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
25+
template <typename SchedulerType, typename OutType, int GroupSizeM_,
26+
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
27+
class ClusterShape = Shape<_1, _2, _1>>
2728
struct cutlass_3x_gemm_fp8_blockwise {
2829
using GroupSizeM = Int<GroupSizeM_>;
2930
using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
8485

8586
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
8687
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
87-
cutlass::gemm::PersistentScheduler>>;
88+
SchedulerType>>;
8889

8990
struct GemmKernel : public KernelType {};
9091

@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
150151
typename GemmKernel::EpilogueArguments epilogue_args{
151152
{}, c_ptr, c_stride, c_ptr, c_stride};
152153

154+
typename GemmKernel::TileSchedulerArguments scheduler;
155+
156+
static constexpr bool UsesStreamKScheduler =
157+
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
158+
cutlass::gemm::StreamKScheduler>;
159+
160+
if constexpr (UsesStreamKScheduler) {
161+
using DecompositionMode = typename cutlass::gemm::kernel::detail::
162+
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
163+
using ReductionMode = typename cutlass::gemm::kernel::detail::
164+
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
165+
166+
scheduler.decomposition_mode = DecompositionMode::StreamK;
167+
scheduler.reduction_mode = ReductionMode::Nondeterministic;
168+
}
169+
153170
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
154-
epilogue_args);
171+
epilogue_args, scheduler);
155172
}
156173

157174
template <typename OutType>
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
160177
torch::Tensor const& b,
161178
torch::Tensor const& a_scales,
162179
torch::Tensor const& b_scales) {
163-
cutlass_gemm_caller_blockwise<
164-
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
165-
b_scales);
180+
auto k = a.size(1);
181+
auto n = b.size(1);
182+
183+
if (k > 3 * n) {
184+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
185+
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
186+
out, a, b, a_scales, b_scales);
187+
} else {
188+
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
189+
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
190+
out, a, b, a_scales, b_scales);
191+
}
166192
}
167193

168194
} // namespace vllm

0 commit comments

Comments
 (0)