@@ -29,9 +29,11 @@ See the License for the specific language governing permissions and
29
29
limitations under the License. */
30
30
31
31
#include " paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
32
+ #include < optional>
32
33
#include " paddle/common/errors.h"
33
34
#include " paddle/phi/core/enforce.h"
34
35
#include " paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/arch_define.h"
36
+ #include " paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h"
35
37
#pragma GCC diagnostic push
36
38
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
37
39
#pragma GCC diagnostic pop
@@ -285,6 +287,29 @@ void dispatch_gemm_to_cutlass(const T* A,
285
287
stream,
286
288
occupancy);
287
289
break ;
290
+ case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
291
+ dispatch_gemm_config<T,
292
+ WeightType,
293
+ arch,
294
+ EpilogueTag,
295
+ FineGrained,
296
+ cutlass::gemm::GemmShape<128 , 128 , 64 >,
297
+ cutlass::gemm::GemmShape<128 , 32 , 64 >>(
298
+ A,
299
+ B,
300
+ weight_scales,
301
+ biases,
302
+ C,
303
+ m,
304
+ n,
305
+ k,
306
+ group_size,
307
+ gemm_config,
308
+ workspace,
309
+ workspace_bytes,
310
+ stream,
311
+ occupancy);
312
+ break ;
288
313
// config for M_16000_N_12288_K_6144 in encoder
289
314
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
290
315
dispatch_gemm_config<T,
@@ -573,41 +598,92 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag,
573
598
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
574
599
const bool is_weight_only_encoder = m >= 512 ? true : false ;
575
600
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs (
576
- sm_, group_size, is_weight_only, is_weight_only_encoder, false );
577
- std::vector<int > occupancies (candidate_configs.size ());
601
+ sm_, group_size, is_weight_only, is_weight_only_encoder, false , false );
578
602
579
- for (size_t ii = 0 ; ii < candidate_configs.size (); ++ii) {
580
- dispatch_to_arch<EpilogueTag, FineGrained>(A,
581
- B,
582
- weight_scales,
583
- biases,
584
- C,
585
- m,
586
- n,
587
- k,
588
- group_size,
589
- candidate_configs[ii],
590
- workspace_ptr,
591
- workspace_bytes,
592
- stream,
593
- &occupancies[ii]);
594
- }
595
603
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
596
604
// FFN.
597
605
static constexpr int num_experts = 1 ;
598
- CutlassGemmConfig chosen_config =
599
- estimate_best_config_from_occupancies (candidate_configs,
600
- occupancies,
601
- m,
602
- n,
603
- k,
604
- group_size,
605
- num_experts,
606
- split_k_limit,
607
- workspace_bytes,
608
- multi_processor_count_,
609
- is_weight_only,
610
- sm_);
606
+ static constexpr int warm_time = 5 ;
607
+ static constexpr int test_time = 10 ;
608
+
609
+ auto & gemmConfigManager = phi::GemmConfigManager::Instance ();
610
+ constexpr GemmDataType dtype = getGemmDataType<T>();
611
+ constexpr GemmDataType wdtype = getGemmDataType<WeightType>();
612
+ GemmIDType gemmId{n, k, GemmType::FPAINTBGEMM, dtype, wdtype, num_experts};
613
+ CutlassGemmConfig chosen_config;
614
+ auto chosen_config_optional = gemmConfigManager.getBestConfig (gemmId, m);
615
+ if (chosen_config_optional != std::nullopt) {
616
+ chosen_config = chosen_config_optional.value ();
617
+ } else {
618
+ float best_time = std::numeric_limits<float >::max ();
619
+ CutlassGemmConfig best_config;
620
+ int profile_m = gemmConfigManager.nextPowerOfTwo (m);
621
+ bool found_one = false ;
622
+
623
+ for (size_t ii = 0 ; ii < candidate_configs.size (); ++ii) {
624
+ for (int i = 0 ; i < warm_time; i++) {
625
+ dispatch_to_arch<EpilogueTag, FineGrained>(A,
626
+ B,
627
+ weight_scales,
628
+ biases,
629
+ C,
630
+ m,
631
+ n,
632
+ k,
633
+ group_size,
634
+ candidate_configs[ii],
635
+ workspace_ptr,
636
+ workspace_bytes,
637
+ stream);
638
+ }
639
+ cudaEvent_t start;
640
+ cudaEvent_t stop;
641
+ check_cuda_error (cudaEventCreate (&start));
642
+ check_cuda_error (cudaEventCreate (&stop));
643
+ check_cuda_error (cudaStreamSynchronize (stream));
644
+ check_cuda_error (cudaEventRecord (start, stream));
645
+ for (int i = 0 ; i < test_time; i++) {
646
+ dispatch_to_arch<EpilogueTag, FineGrained>(A,
647
+ B,
648
+ weight_scales,
649
+ biases,
650
+ C,
651
+ m,
652
+ n,
653
+ k,
654
+ group_size,
655
+ candidate_configs[ii],
656
+ workspace_ptr,
657
+ workspace_bytes,
658
+ stream);
659
+ }
660
+ check_cuda_error (cudaEventRecord (stop, stream));
661
+ check_cuda_error (cudaEventSynchronize (stop));
662
+ found_one = true ;
663
+ float elapsed;
664
+ check_cuda_error (cudaEventElapsedTime (&elapsed, start, stop));
665
+ check_cuda_error (cudaEventDestroy (start));
666
+ check_cuda_error (cudaEventDestroy (stop));
667
+ if (elapsed < best_time) {
668
+ best_time = elapsed;
669
+ best_config = candidate_configs[ii];
670
+ }
671
+ VLOG (4 ) << " profile_m" << profile_m;
672
+ VLOG (4 ) << " candidate_config tile_config"
673
+ << static_cast <int >(candidate_configs[ii].tile_config );
674
+ VLOG (4 ) << " candidate_config split_k_style"
675
+ << static_cast <int >(candidate_configs[ii].split_k_style );
676
+ VLOG (4 ) << " candidate_config split_k_factor "
677
+ << candidate_configs[ii].split_k_factor ;
678
+ VLOG (4 ) << " candidate_config stages " << candidate_configs[ii].stages ;
679
+ VLOG (4 ) << " elapsed time: " << elapsed;
680
+ VLOG (4 ) << " best_time: " << best_time;
681
+ }
682
+ if (found_one) {
683
+ gemmConfigManager.addBestConfig (gemmId, profile_m, best_config);
684
+ chosen_config = best_config;
685
+ }
686
+ }
611
687
612
688
dispatch_to_arch<EpilogueTag, FineGrained>(A,
613
689
B,
0 commit comments