@@ -15,12 +15,14 @@ limitations under the License. */
1515#include < cublas_v2.h>
1616#include < cuda_runtime_api.h>
1717#include < sys/time.h>
18+
1819#include < algorithm>
1920#include < fstream>
2021#include < iostream>
2122#include < limits>
2223#include < list>
2324#include < vector>
25+
2426#include " helper.h"
2527
2628template <typename T>
@@ -172,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
172174}
173175
174176template <typename InT, typename OutT, typename ScaleT = OutT>
175- void FindAlgo (cublasLtHandle_t ltHandle,
177+ void FindAlgo (const cublasLtHandle_t& ltHandle,
176178 int m,
177179 int n,
178180 int k,
@@ -466,15 +468,14 @@ class DevContext {};
// Empty tag type derived from DevContext. It is used as the DevContext
// template argument of the CPU specialization of GEMMInt8 in this file.
466468class CPUContext : public DevContext {};
467469
// Context type for the cuBLASLt path: the constructor creates a cuBLASLt
// handle with cublasLtCreate and stores it in a data member.
// In this diff the member is renamed from `handle_` to `handle` and the
// `private:` label is removed, so the handle becomes publicly accessible.
// NOTE(review): no destructor or cublasLtDestroy call is visible in this
// class, so the handle may never be released -- confirm against the full file.
// NOTE(review): CUDA_CHECK is defined outside this view; confirm that it
// accepts the status type returned by cublasLtCreate.
468470class CUBLASLTContext : public DevContext {
469- public:
470- CUBLASLTContext () { CUDA_CHECK (cublasLtCreate (&handle_ )); }
471+ public:
472+ CUBLASLTContext () { CUDA_CHECK (cublasLtCreate (&handle )); }
471473
472- private:
473- cublasLtHandle_t handle_;
// Read as `dev_ctx.handle` by the CUBLASLTContext specialization of GEMMInt8.
474+ cublasLtHandle_t handle;
474475};
475476
476477template <typename InT, typename OutT, typename DevContext>
477- void GEMMInt8 (DevContext dev_ctx,
478+ void GEMMInt8 (const DevContext& dev_ctx,
478479 const std::vector<InT>& A,
479480 const std::vector<InT>& B,
480481 std::vector<OutT>& C,
@@ -488,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
488489}
489490
490491template <>
491- void GEMMInt8<int8_t , int32_t , CPUContext>(CPUContext dev_ctx,
492+ void GEMMInt8<int8_t , int32_t , CPUContext>(const CPUContext& dev_ctx,
492493 const std::vector<int8_t >& A,
493494 const std::vector<int8_t >& B,
494495 std::vector<int32_t >& C,
@@ -502,7 +503,7 @@ void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
502503}
503504
504505template <>
505- void GEMMInt8<int8_t , int32_t , CUBLASLTContext>(CUBLASLTContext dev_ctx,
506+ void GEMMInt8<int8_t , int32_t , CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
506507 const std::vector<int8_t >& AVec,
507508 const std::vector<int8_t >& BVec,
508509 std::vector<int32_t >& CVec,
@@ -528,24 +529,24 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
528529
529530 // init data structure
530531
531- cublasLtMatmulDesc_t matmul_desc_ ;
532- cublasLtMatrixLayout_t A_desc_ ;
533- cublasLtMatrixLayout_t B_desc_ ;
534- cublasLtMatrixLayout_t C_desc_ ;
535- int32_t alpha_ = 1 ;
536- int32_t beta_ = 0 ;
532+ cublasLtMatmulDesc_t matmul_desc ;
533+ cublasLtMatrixLayout_t A_desc ;
534+ cublasLtMatrixLayout_t B_desc ;
535+ cublasLtMatrixLayout_t C_desc ;
536+ int32_t alpha = 1 ;
537+ int32_t beta = 0 ;
537538
538539 cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
539540 CUDA_CHECK (
540- cublasLtMatmulDescCreate (&matmul_desc_ , cudaComputeType, CUDA_R_32I));
541+ cublasLtMatmulDescCreate (&matmul_desc , cudaComputeType, CUDA_R_32I));
541542 cublasOperation_t op_transpose = CUBLAS_OP_T;
542- CUDA_CHECK (cublasLtMatmulDescSetAttribute (matmul_desc_ ,
543+ CUDA_CHECK (cublasLtMatmulDescSetAttribute (matmul_desc ,
543544 CUBLASLT_MATMUL_DESC_TRANSA,
544545 &op_transpose,
545546 sizeof (op_transpose)));
546- CUDA_CHECK (cublasLtMatrixLayoutCreate (&B_desc_ , CUDA_R_8I, k, n, k));
547- CUDA_CHECK (cublasLtMatrixLayoutCreate (&A_desc_ , CUDA_R_8I, k, m, k));
548- CUDA_CHECK (cublasLtMatrixLayoutCreate (&C_desc_ , CUDA_R_32I, n, m, n));
547+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&B_desc , CUDA_R_8I, k, n, k));
548+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&A_desc , CUDA_R_8I, k, m, k));
549+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&C_desc , CUDA_R_32I, n, m, n));
549550
550551 cublasLtMatmulAlgo_t algo;
551552 int algoId;
@@ -574,17 +575,17 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
574575 if (is_test) {
575576 std::vector<algoSelect_t> algos;
576577 // Select //
577- FindAlgo (dev_ctx.handle_ ,
578+ FindAlgo (dev_ctx.handle ,
578579 m,
579580 n,
580581 k,
581582 B_dev,
582583 A_dev,
583584 C_dev,
584- matmul_desc_ ,
585- B_desc_ ,
586- A_desc_ ,
587- C_desc_ ,
585+ matmul_desc ,
586+ B_desc ,
587+ A_desc ,
588+ C_desc ,
588589 CUBLAS_COMPUTE_32I,
589590 CUDA_R_32I,
590591 CUDA_R_8I,
@@ -643,7 +644,7 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
643644 paddle::DataType::UINT8,
644645 paddle::GPUPlace ());
645646 void * workspace_ptr = workspace.data <uint8_t >();
646- CUDA_CHECK (cublasLtMatmulAlgoInit (dev_ctx.handle_ ,
647+ CUDA_CHECK (cublasLtMatmulAlgoInit (dev_ctx.handle ,
647648 cudaComputeType,
648649 CUDA_R_32I,
649650 CUDA_R_8I,
@@ -677,18 +678,18 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
677678 auto start = std::chrono::high_resolution_clock::now ();
678679 const int repeats = 10 ;
679680 for (int loop = 0 ; loop < repeats; loop++) {
680- CUDA_CHECK (cublasLtMatmul (dev_ctx.handle_ ,
681- matmul_desc_ ,
682- &alpha_ ,
681+ CUDA_CHECK (cublasLtMatmul (dev_ctx.handle ,
682+ matmul_desc ,
683+ &alpha ,
683684 B_dev,
684- B_desc_ ,
685+ B_desc ,
685686 A_dev,
686- A_desc_ ,
687- &beta_ ,
687+ A_desc ,
688+ &beta ,
688689 C_dev,
689- C_desc_ ,
690+ C_desc ,
690691 C_dev,
691- C_desc_ ,
692+ C_desc ,
692693 &algo,
693694 // nullptr,
694695 workspace_ptr,
@@ -711,8 +712,8 @@ void TuneCublasltGemm(const paddle::Tensor& M,
711712 bool is_test,
712713 bool is_read_from_file,
713714 const std::string& path) {
714-
715- // Ensure that M, K, and N are all one-dimensional Tensors. is_test != is_read_from_file
715+ // Ensure that M, K, and N are all one-dimensional Tensors, and that exactly
716+ // one of is_test and is_read_from_file is true.
716717 assert (M.dims ().size () == 1 && K.dims ().size () == 1 && N.dims ().size () == 1 );
717718 assert (is_test != is_read_from_file);
718719
@@ -730,22 +731,34 @@ void TuneCublasltGemm(const paddle::Tensor& M,
730731
// Upper limit for the m values produced by the loop below, taken from the
// first element of M_data.
731732 int m_data = (int )M_data[0 ];
// NOTE(review): `4 <= 8192` is a constant-true comparison, so this assert
// effectively checks only `m_data > 0`. An upper bound such as
// `m_data <= 8192` may have been intended -- confirm.
732733 assert (m_data > 0 && 4 <= 8192 );
733-
734+
734735 std::vector<int > mm;
735736
736737 int m = 1 , step = 1 ;
737- while (m <= m_data) {
738+ while (m <= m_data) {
738739 mm.push_back (m);
739740 m += step;
740741
741742 // update step
742743 switch (m) {
743- case 4 : step = 4 ; break ;
744- case 16 : step = 16 ; break ;
745- case 64 : step = 32 ; break ;
746- case 256 : step = 64 ; break ;
747- case 512 : step = 128 ; break ;
748- case 1024 : step = 1024 ; break ;
744+ case 4 :
745+ step = 4 ;
746+ break ;
747+ case 16 :
748+ step = 16 ;
749+ break ;
750+ case 64 :
751+ step = 32 ;
752+ break ;
753+ case 256 :
754+ step = 64 ;
755+ break ;
756+ case 512 :
757+ step = 128 ;
758+ break ;
759+ case 1024 :
760+ step = 1024 ;
761+ break ;
749762 }
750763 }
751764
@@ -761,15 +774,15 @@ void TuneCublasltGemm(const paddle::Tensor& M,
761774 if (dtype == " int8" ) {
762775 CUBLASLTContext dev_ctx;
763776 GEMMInt8 (dev_ctx,
764- A,
765- B,
766- C,
767- m,
768- k,
769- n,
770- is_test, /* is_test*/
771- is_read_from_file, /* is_read_from_file*/
772- path);
777+ A,
778+ B,
779+ C,
780+ m,
781+ k,
782+ n,
783+ is_test, /* is_test*/
784+ is_read_from_file, /* is_read_from_file*/
785+ path);
773786 } else {
774787 // other dtype
775788 std::cout << " Not currently supported" << std::endl;
0 commit comments