@@ -174,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
174174}
175175
176176template <typename InT, typename OutT, typename ScaleT = OutT>
177- void FindAlgo (cublasLtHandle_t ltHandle,
177+ void FindAlgo (const cublasLtHandle_t& ltHandle,
178178 int m,
179179 int n,
180180 int k,
@@ -475,7 +475,7 @@ public:
475475};
476476
477477template <typename InT, typename OutT, typename DevContext>
478- void GEMMInt8 (DevContext dev_ctx,
478+ void GEMMInt8 (const DevContext& dev_ctx,
479479 const std::vector<InT>& A,
480480 const std::vector<InT>& B,
481481 std::vector<OutT>& C,
@@ -489,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
489489}
490490
491491template <>
492- void GEMMInt8<int8_t , int32_t , CPUContext>(CPUContext dev_ctx,
492+ void GEMMInt8<int8_t , int32_t , CPUContext>(const CPUContext& dev_ctx,
493493 const std::vector<int8_t >& A,
494494 const std::vector<int8_t >& B,
495495 std::vector<int32_t >& C,
@@ -529,24 +529,24 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
529529
530530 // init data structure
531531
532- cublasLtMatmulDesc_t matmul_desc_ ;
533- cublasLtMatrixLayout_t A_desc_ ;
534- cublasLtMatrixLayout_t B_desc_ ;
535- cublasLtMatrixLayout_t C_desc_ ;
536- int32_t alpha_ = 1 ;
537- int32_t beta_ = 0 ;
532+ cublasLtMatmulDesc_t matmul_desc ;
533+ cublasLtMatrixLayout_t A_desc ;
534+ cublasLtMatrixLayout_t B_desc ;
535+ cublasLtMatrixLayout_t C_desc ;
536+ int32_t alpha = 1 ;
537+ int32_t beta = 0 ;
538538
539539 cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
540540 CUDA_CHECK (
541- cublasLtMatmulDescCreate (&matmul_desc_ , cudaComputeType, CUDA_R_32I));
541+ cublasLtMatmulDescCreate (&matmul_desc , cudaComputeType, CUDA_R_32I));
542542 cublasOperation_t op_transpose = CUBLAS_OP_T;
543- CUDA_CHECK (cublasLtMatmulDescSetAttribute (matmul_desc_ ,
543+ CUDA_CHECK (cublasLtMatmulDescSetAttribute (matmul_desc ,
544544 CUBLASLT_MATMUL_DESC_TRANSA,
545545 &op_transpose,
546546 sizeof (op_transpose)));
547- CUDA_CHECK (cublasLtMatrixLayoutCreate (&B_desc_ , CUDA_R_8I, k, n, k));
548- CUDA_CHECK (cublasLtMatrixLayoutCreate (&A_desc_ , CUDA_R_8I, k, m, k));
549- CUDA_CHECK (cublasLtMatrixLayoutCreate (&C_desc_ , CUDA_R_32I, n, m, n));
547+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&B_desc , CUDA_R_8I, k, n, k));
548+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&A_desc , CUDA_R_8I, k, m, k));
549+ CUDA_CHECK (cublasLtMatrixLayoutCreate (&C_desc , CUDA_R_32I, n, m, n));
550550
551551 cublasLtMatmulAlgo_t algo;
552552 int algoId;
@@ -582,10 +582,10 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
582582 B_dev,
583583 A_dev,
584584 C_dev,
585- matmul_desc_ ,
586- B_desc_ ,
587- A_desc_ ,
588- C_desc_ ,
585+ matmul_desc ,
586+ B_desc ,
587+ A_desc ,
588+ C_desc ,
589589 CUBLAS_COMPUTE_32I,
590590 CUDA_R_32I,
591591 CUDA_R_8I,
@@ -679,17 +679,17 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
679679 const int repeats = 10 ;
680680 for (int loop = 0 ; loop < repeats; loop++) {
681681 CUDA_CHECK (cublasLtMatmul (dev_ctx.handle ,
682- matmul_desc_ ,
683- &alpha_ ,
682+ matmul_desc ,
683+ &alpha ,
684684 B_dev,
685- B_desc_ ,
685+ B_desc ,
686686 A_dev,
687- A_desc_ ,
688- &beta_ ,
687+ A_desc ,
688+ &beta ,
689689 C_dev,
690- C_desc_ ,
690+ C_desc ,
691691 C_dev,
692- C_desc_ ,
692+ C_desc ,
693693 &algo,
694694 // nullptr,
695695 workspace_ptr,
0 commit comments