@@ -22,8 +22,9 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
-          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
+template <typename SchedulerType, typename OutType, int GroupSizeM_,
+          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
+          class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
   using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
 
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
 
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
+                      cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args);
+                                       epilogue_args, scheduler);
 }
 
 template <typename OutType>
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 }  // namespace vllm
0 commit comments