Skip to content

Commit a59c897

Browse files
committed
fix
1 parent f24b4dd commit a59c897

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

csrc/generation/tune_cublaslt_gemm.cu

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
174174
}
175175

176176
template <typename InT, typename OutT, typename ScaleT = OutT>
177-
void FindAlgo(cublasLtHandle_t ltHandle,
177+
void FindAlgo(const cublasLtHandle_t& ltHandle,
178178
int m,
179179
int n,
180180
int k,
@@ -475,7 +475,7 @@ public:
475475
};
476476

477477
template <typename InT, typename OutT, typename DevContext>
478-
void GEMMInt8(DevContext dev_ctx,
478+
void GEMMInt8(const DevContext& dev_ctx,
479479
const std::vector<InT>& A,
480480
const std::vector<InT>& B,
481481
std::vector<OutT>& C,
@@ -489,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
489489
}
490490

491491
template <>
492-
void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
492+
void GEMMInt8<int8_t, int32_t, CPUContext>(const CPUContext& dev_ctx,
493493
const std::vector<int8_t>& A,
494494
const std::vector<int8_t>& B,
495495
std::vector<int32_t>& C,
@@ -529,24 +529,24 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
529529

530530
// init data structure
531531

532-
cublasLtMatmulDesc_t matmul_desc_;
533-
cublasLtMatrixLayout_t A_desc_;
534-
cublasLtMatrixLayout_t B_desc_;
535-
cublasLtMatrixLayout_t C_desc_;
536-
int32_t alpha_ = 1;
537-
int32_t beta_ = 0;
532+
cublasLtMatmulDesc_t matmul_desc;
533+
cublasLtMatrixLayout_t A_desc;
534+
cublasLtMatrixLayout_t B_desc;
535+
cublasLtMatrixLayout_t C_desc;
536+
int32_t alpha = 1;
537+
int32_t beta = 0;
538538

539539
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
540540
CUDA_CHECK(
541-
cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32I));
541+
cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I));
542542
cublasOperation_t op_transpose = CUBLAS_OP_T;
543-
CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc_,
543+
CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc,
544544
CUBLASLT_MATMUL_DESC_TRANSA,
545545
&op_transpose,
546546
sizeof(op_transpose)));
547-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k));
548-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k));
549-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n));
547+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k));
548+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k));
549+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n));
550550

551551
cublasLtMatmulAlgo_t algo;
552552
int algoId;
@@ -582,10 +582,10 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
582582
B_dev,
583583
A_dev,
584584
C_dev,
585-
matmul_desc_,
586-
B_desc_,
587-
A_desc_,
588-
C_desc_,
585+
matmul_desc,
586+
B_desc,
587+
A_desc,
588+
C_desc,
589589
CUBLAS_COMPUTE_32I,
590590
CUDA_R_32I,
591591
CUDA_R_8I,
@@ -679,17 +679,17 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
679679
const int repeats = 10;
680680
for (int loop = 0; loop < repeats; loop++) {
681681
CUDA_CHECK(cublasLtMatmul(dev_ctx.handle,
682-
matmul_desc_,
683-
&alpha_,
682+
matmul_desc,
683+
&alpha,
684684
B_dev,
685-
B_desc_,
685+
B_desc,
686686
A_dev,
687-
A_desc_,
688-
&beta_,
687+
A_desc,
688+
&beta,
689689
C_dev,
690-
C_desc_,
690+
C_desc,
691691
C_dev,
692-
C_desc_,
692+
C_desc,
693693
&algo,
694694
// nullptr,
695695
workspace_ptr,

0 commit comments

Comments
 (0)