Skip to content

Commit fc4ff25

Browse files
authored
[Inference] Fused Moe Optimization (#70059)
* add gemm_config_manager * add serialize & deserialize to support getting profiles from JSON
1 parent e209215 commit fc4ff25

File tree

5 files changed

+494
-68
lines changed

5 files changed

+494
-68
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,16 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
109109
const bool is_weight_only_encoder,
110110
const bool simt_configs_only,
111111
const int sm,
112-
const int group_size) {
112+
const int group_size,
113+
const bool is_moe) {
113114
VLOG(3) << "get_candidate_tiles sm: " << sm;
114115
std::vector<CutlassTileConfig> simt_configs{
115116
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
116117

117118
std::vector<CutlassTileConfig> square_configs{
118119
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
119120
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
121+
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
120122
};
121123
std::vector<CutlassTileConfig> quant_B_configs_sm70{
122124
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
@@ -129,6 +131,13 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
129131
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64,
130132
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
131133
};
134+
if (is_moe) {
135+
quant_B_configs_sm80.push_back(
136+
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64);
137+
} else {
138+
quant_B_configs_sm80.push_back(
139+
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64);
140+
}
132141
std::vector<CutlassTileConfig> quant_B_configs_sm80_finegrained{
133142
CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
134143
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
@@ -164,13 +173,15 @@ static std::vector<CutlassGemmConfig> get_candidate_configs(
164173
const int group_size,
165174
const bool is_weight_only,
166175
const bool is_weight_only_encoder,
167-
const bool simt_configs_only) {
176+
const bool simt_configs_only,
177+
const bool is_moe) {
168178
std::vector<CutlassTileConfig> tiles =
169179
get_candidate_tiles(is_weight_only,
170180
is_weight_only_encoder,
171181
simt_configs_only,
172182
sm,
173-
group_size);
183+
group_size,
184+
is_moe);
174185

175186
std::vector<CutlassGemmConfig> candidate_configs;
176187
const int min_stages = 2;

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu

Lines changed: 107 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ See the License for the specific language governing permissions and
2929
limitations under the License. */
3030

3131
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
32+
#include <optional>
3233
#include "paddle/common/errors.h"
3334
#include "paddle/phi/core/enforce.h"
3435
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/arch_define.h"
36+
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h"
3537
#pragma GCC diagnostic push
3638
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
3739
#pragma GCC diagnostic pop
@@ -285,6 +287,29 @@ void dispatch_gemm_to_cutlass(const T* A,
285287
stream,
286288
occupancy);
287289
break;
290+
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
291+
dispatch_gemm_config<T,
292+
WeightType,
293+
arch,
294+
EpilogueTag,
295+
FineGrained,
296+
cutlass::gemm::GemmShape<128, 128, 64>,
297+
cutlass::gemm::GemmShape<128, 32, 64>>(
298+
A,
299+
B,
300+
weight_scales,
301+
biases,
302+
C,
303+
m,
304+
n,
305+
k,
306+
group_size,
307+
gemm_config,
308+
workspace,
309+
workspace_bytes,
310+
stream,
311+
occupancy);
312+
break;
288313
// config for M_16000_N_12288_K_6144 in encoder
289314
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
290315
dispatch_gemm_config<T,
@@ -573,41 +598,92 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
573598
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
574599
const bool is_weight_only_encoder = m >= 512 ? true : false;
575600
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(
576-
sm_, group_size, is_weight_only, is_weight_only_encoder, false);
577-
std::vector<int> occupancies(candidate_configs.size());
601+
sm_, group_size, is_weight_only, is_weight_only_encoder, false, false);
578602

579-
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
580-
dispatch_to_arch<EpilogueTag, FineGrained>(A,
581-
B,
582-
weight_scales,
583-
biases,
584-
C,
585-
m,
586-
n,
587-
k,
588-
group_size,
589-
candidate_configs[ii],
590-
workspace_ptr,
591-
workspace_bytes,
592-
stream,
593-
&occupancies[ii]);
594-
}
595603
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
596604
// FFN.
597605
static constexpr int num_experts = 1;
598-
CutlassGemmConfig chosen_config =
599-
estimate_best_config_from_occupancies(candidate_configs,
600-
occupancies,
601-
m,
602-
n,
603-
k,
604-
group_size,
605-
num_experts,
606-
split_k_limit,
607-
workspace_bytes,
608-
multi_processor_count_,
609-
is_weight_only,
610-
sm_);
606+
static constexpr int warm_time = 5;
607+
static constexpr int test_time = 10;
608+
609+
auto& gemmConfigManager = phi::GemmConfigManager::Instance();
610+
constexpr GemmDataType dtype = getGemmDataType<T>();
611+
constexpr GemmDataType wdtype = getGemmDataType<WeightType>();
612+
GemmIDType gemmId{n, k, GemmType::FPAINTBGEMM, dtype, wdtype, num_experts};
613+
CutlassGemmConfig chosen_config;
614+
auto chosen_config_optional = gemmConfigManager.getBestConfig(gemmId, m);
615+
if (chosen_config_optional != std::nullopt) {
616+
chosen_config = chosen_config_optional.value();
617+
} else {
618+
float best_time = std::numeric_limits<float>::max();
619+
CutlassGemmConfig best_config;
620+
int profile_m = gemmConfigManager.nextPowerOfTwo(m);
621+
bool found_one = false;
622+
623+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
624+
for (int i = 0; i < warm_time; i++) {
625+
dispatch_to_arch<EpilogueTag, FineGrained>(A,
626+
B,
627+
weight_scales,
628+
biases,
629+
C,
630+
m,
631+
n,
632+
k,
633+
group_size,
634+
candidate_configs[ii],
635+
workspace_ptr,
636+
workspace_bytes,
637+
stream);
638+
}
639+
cudaEvent_t start;
640+
cudaEvent_t stop;
641+
check_cuda_error(cudaEventCreate(&start));
642+
check_cuda_error(cudaEventCreate(&stop));
643+
check_cuda_error(cudaStreamSynchronize(stream));
644+
check_cuda_error(cudaEventRecord(start, stream));
645+
for (int i = 0; i < test_time; i++) {
646+
dispatch_to_arch<EpilogueTag, FineGrained>(A,
647+
B,
648+
weight_scales,
649+
biases,
650+
C,
651+
m,
652+
n,
653+
k,
654+
group_size,
655+
candidate_configs[ii],
656+
workspace_ptr,
657+
workspace_bytes,
658+
stream);
659+
}
660+
check_cuda_error(cudaEventRecord(stop, stream));
661+
check_cuda_error(cudaEventSynchronize(stop));
662+
found_one = true;
663+
float elapsed;
664+
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
665+
check_cuda_error(cudaEventDestroy(start));
666+
check_cuda_error(cudaEventDestroy(stop));
667+
if (elapsed < best_time) {
668+
best_time = elapsed;
669+
best_config = candidate_configs[ii];
670+
}
671+
VLOG(4) << "profile_m" << profile_m;
672+
VLOG(4) << "candidate_config tile_config"
673+
<< static_cast<int>(candidate_configs[ii].tile_config);
674+
VLOG(4) << "candidate_config split_k_style"
675+
<< static_cast<int>(candidate_configs[ii].split_k_style);
676+
VLOG(4) << "candidate_config split_k_factor "
677+
<< candidate_configs[ii].split_k_factor;
678+
VLOG(4) << "candidate_config stages " << candidate_configs[ii].stages;
679+
VLOG(4) << "elapsed time: " << elapsed;
680+
VLOG(4) << "best_time: " << best_time;
681+
}
682+
if (found_one) {
683+
gemmConfigManager.addBestConfig(gemmId, profile_m, best_config);
684+
chosen_config = best_config;
685+
}
686+
}
611687

612688
dispatch_to_arch<EpilogueTag, FineGrained>(A,
613689
B,

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"cutlass::gemm::GemmShape<32, 128, 64>",
9191
"cutlass::gemm::GemmShape<64, 128, 64>",
9292
"cutlass::gemm::GemmShape<128, 128, 64>",
93+
"cutlass::gemm::GemmShape<128, 128, 64>",
9394
"cutlass::gemm::GemmShape<128, 256, 64>",
9495
"cutlass::gemm::GemmShape<256, 128, 64>",
9596
]
@@ -98,6 +99,7 @@
9899
"cutlass::gemm::GemmShape<32, 32, 64>",
99100
"cutlass::gemm::GemmShape<64, 64, 64>",
100101
"cutlass::gemm::GemmShape<64, 64, 64>",
102+
"cutlass::gemm::GemmShape<128, 32, 64>",
101103
"cutlass::gemm::GemmShape<64, 64, 64>",
102104
"cutlass::gemm::GemmShape<64, 64, 64>",
103105
]

0 commit comments

Comments
 (0)