@@ -18,11 +18,11 @@ limitations under the License. */
1818
1919#include < algorithm>
2020#include < fstream>
21+ #include < iomanip>
2122#include < iostream>
2223#include < limits>
2324#include < list>
2425#include < vector>
25- #include < iomanip>
2626
2727#include " helper.h"
2828
@@ -105,6 +105,13 @@ static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a,
105105 return (algo_para_a.time < algo_para_b.time );
106106}
107107
108+ // Returns the amount of free memory on the current GPU, in bytes.
109+ size_t get_remaining_memory () {
110+ size_t free , total;  // both are filled in by the query below; `total` is not used afterwards
111+ CUDA_CHECK (cudaMemGetInfo (&free , &total));
112+ return free ;
113+ }
114+
108115template <typename InT, typename OutT, typename ScaleT = OutT>
109116static void TestMatmulRun (cublasLtHandle_t ltHandle,
110117 cublasLtMatmulDesc_t matmulDesc,
@@ -122,7 +129,10 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
122129 cublasLtMatmulHeuristicResult_t heurResult;
123130 cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck (
124131 ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
125- if (algoStatus == CUBLAS_STATUS_SUCCESS) {
132+
133+ auto remainingMemorySize = 0.95 * get_remaining_memory ();
134+ if (algoStatus == CUBLAS_STATUS_SUCCESS &&
135+ remainingMemorySize > heurResult.workspaceSize ) {
126136 ScaleT alpha = static_cast <ScaleT>(1 ), beta = static_cast <ScaleT>(0 );
127137 void * workSpace;
128138 CUDA_CHECK (cudaMalloc (&workSpace, heurResult.workspaceSize ));
@@ -166,8 +176,13 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
166176 }
167177 CUDA_CHECK (cudaFree (workSpace));
168178 } else {
169- std::cerr << " not enough workspace! current workspace is "
170- << heurResult.workspaceSize ;
179+ std::cerr << " Not enough workspace! Required "
180+ << static_cast <double >(heurResult.workspaceSize ) / 1024.0 /
181+ 1024.0 / 1024.0
182+ << " GiB" << " , But remaining "
183+ << static_cast <double >(remainingMemorySize) / 1024.0 / 1024.0 /
184+ 1024.0
185+ << " GiB" << std::endl;
171186 perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
172187 }
173188}
@@ -442,7 +457,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
442457 if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
443458 std::clog << " algo " << algos[i].algoId << " tile " << algos[i].tile
444459 << " stages " << algos[i].stages << " splitK_val "
445- << algos[i].splitK_val ;
460+ << algos[i].splitK_val << std::endl ;
446461 algos[i].time = std::numeric_limits<float >::max ();
447462 std::cerr << " TestMatmulRun with status " << perfResults[i].status
448463 << std::endl;
@@ -467,7 +482,7 @@ class DevContext {};
467482class CPUContext : public DevContext {};  // Empty class: declares no members of its own beyond DevContext.
468483
469484class CUBLASLTContext : public DevContext {
470- public:
485+ public:
471486 CUBLASLTContext () { CUDA_CHECK (cublasLtCreate (&handle)); }
472487
473488 cublasLtHandle_t handle;
@@ -709,64 +724,51 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
709724 CUDA_CHECK (cudaFree (workSpace));
710725}
711726
712- void TuneCublasltGemm (const paddle::Tensor& M,
713- const paddle::Tensor& K,
727+ void TuneCublasltGemm (const paddle::Tensor& K,
714728 const paddle::Tensor& N,
729+ const int M_start,
730+ const int M_end,
715731 const std::string& dtype,
716- bool is_test,
717- bool is_read_from_file,
732+ const bool is_test,
733+ const bool is_read_from_file,
718734 const std::string& path) {
719- // Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
720- // is_read_from_file
721- assert (M. dims (). size () == 1 && K.dims ().size () == 1 && N.dims ().size () == 1 );
735+ assert (M_end >= M_start);
736+ assert (M_start >= 1 );
737+ assert (K.dims ().size () == 1 && N.dims ().size () == 1 );
722738 assert (is_test != is_read_from_file);
723739
724- auto M_cpu = M.copy_to (paddle::CPUPlace (), false );
725740 auto K_cpu = K.copy_to (paddle::CPUPlace (), false );
726741 auto N_cpu = N.copy_to (paddle::CPUPlace (), false );
727- int64_t * M_data = M_cpu.data <int64_t >();
728742 int64_t * K_data = K_cpu.data <int64_t >();
729743 int64_t * N_data = N_cpu.data <int64_t >();
730744
731- int M_size = M.numel ();
732745 int K_size = K.numel ();
733746 int N_size = N.numel ();
734747 assert (K_size == N_size);
735748
736- int m_data = (int )M_data[0 ];
737- assert (m_data > 0 );
738-
739749 std::vector<int > mm;
740-
741- int m = 1 , step = 1 ;
742- while (m <= m_data) {
743- mm.push_back (m);
744- m += step;
745-
750+ int m = M_start, step = 1 ;
751+ while (m <= M_end) {
746752 // update step
747- switch (m) {
748- case 4 :
749- step = 4 ;
750- break ;
751- case 16 :
752- step = 16 ;
753- break ;
754- case 64 :
755- step = 32 ;
756- break ;
757- case 256 :
758- step = 64 ;
759- break ;
760- case 512 :
761- step = 128 ;
762- break ;
763- case 1024 :
764- step = 1024 ;
765- break ;
766- case 8192 :
767- step = 4096 ;
768- break ;
753+ if (m >= 8192 ) {
754+ step = 4096 ;
755+ } else if (m >= 1024 ) {
756+ step = 1024 ;
757+ } else if (m >= 512 ) {
758+ step = 128 ;
759+ } else if (m >= 256 ) {
760+ step = 64 ;
761+ } else if (m >= 64 ) {
762+ step = 32 ;
763+ } else if (m >= 16 ) {
764+ step = 16 ;
765+ } else if (m >= 4 ) {
766+ step = 4 ;
767+ } else {
768+ step = 1 ;
769769 }
770+ mm.push_back (m);
771+ m += step;
770772 }
771773
772774 for (int j = 0 ; j < mm.size (); j++) {
@@ -792,16 +794,18 @@ void TuneCublasltGemm(const paddle::Tensor& M,
792794 path);
793795 } else {
794796 // other dtype
795- std::cout << " Not currently supported" << std::endl ;
797+ throw std::runtime_error (dtype + " not currently supported" ) ;
796798 }
797799 }
798800 }
799801}
800802
801803PD_BUILD_OP (tune_cublaslt_gemm)
802- .Inputs({" M " , " K" , " N" })
804+ .Inputs({" K" , " N" })
803805 .Outputs({})
804- .Attrs({" dtype: std::string" ,
806+ .Attrs({" M_start: int" ,
807+ " M_end: int" ,
808+ " dtype: std::string" ,
805809 " is_test: bool" ,
806810 " is_read_from_file: bool" ,
807811 " path: std::string" })
0 commit comments