Skip to content

Commit af5cda2

Browse files
yuanlehome authored and DrownFish19 committed
Fix tune_cublaslt_gemm compile bug (PaddlePaddle#8844)
* fix tune_cublaslt_gemm * Fix: * fix
1 parent 877302b commit af5cda2

File tree

1 file changed

+65
-52
lines changed

1 file changed

+65
-52
lines changed

csrc/generation/tune_cublaslt_gemm.cu

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ limitations under the License. */
1515
#include <cublas_v2.h>
1616
#include <cuda_runtime_api.h>
1717
#include <sys/time.h>
18+
1819
#include <algorithm>
1920
#include <fstream>
2021
#include <iostream>
2122
#include <limits>
2223
#include <list>
2324
#include <vector>
25+
2426
#include "helper.h"
2527

2628
template <typename T>
@@ -172,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
172174
}
173175

174176
template <typename InT, typename OutT, typename ScaleT = OutT>
175-
void FindAlgo(cublasLtHandle_t ltHandle,
177+
void FindAlgo(const cublasLtHandle_t& ltHandle,
176178
int m,
177179
int n,
178180
int k,
@@ -466,15 +468,14 @@ class DevContext {};
466468
class CPUContext : public DevContext {};
467469

468470
class CUBLASLTContext : public DevContext {
469-
public:
470-
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }
471+
public:
472+
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }
471473

472-
private:
473-
cublasLtHandle_t handle_;
474+
cublasLtHandle_t handle;
474475
};
475476

476477
template <typename InT, typename OutT, typename DevContext>
477-
void GEMMInt8(DevContext dev_ctx,
478+
void GEMMInt8(const DevContext& dev_ctx,
478479
const std::vector<InT>& A,
479480
const std::vector<InT>& B,
480481
std::vector<OutT>& C,
@@ -488,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
488489
}
489490

490491
template <>
491-
void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
492+
void GEMMInt8<int8_t, int32_t, CPUContext>(const CPUContext& dev_ctx,
492493
const std::vector<int8_t>& A,
493494
const std::vector<int8_t>& B,
494495
std::vector<int32_t>& C,
@@ -502,7 +503,7 @@ void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
502503
}
503504

504505
template <>
505-
void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
506+
void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
506507
const std::vector<int8_t>& AVec,
507508
const std::vector<int8_t>& BVec,
508509
std::vector<int32_t>& CVec,
@@ -528,24 +529,24 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
528529

529530
// init data structure
530531

531-
cublasLtMatmulDesc_t matmul_desc_;
532-
cublasLtMatrixLayout_t A_desc_;
533-
cublasLtMatrixLayout_t B_desc_;
534-
cublasLtMatrixLayout_t C_desc_;
535-
int32_t alpha_ = 1;
536-
int32_t beta_ = 0;
532+
cublasLtMatmulDesc_t matmul_desc;
533+
cublasLtMatrixLayout_t A_desc;
534+
cublasLtMatrixLayout_t B_desc;
535+
cublasLtMatrixLayout_t C_desc;
536+
int32_t alpha = 1;
537+
int32_t beta = 0;
537538

538539
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
539540
CUDA_CHECK(
540-
cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32I));
541+
cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I));
541542
cublasOperation_t op_transpose = CUBLAS_OP_T;
542-
CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc_,
543+
CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc,
543544
CUBLASLT_MATMUL_DESC_TRANSA,
544545
&op_transpose,
545546
sizeof(op_transpose)));
546-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k));
547-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k));
548-
CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n));
547+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k));
548+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k));
549+
CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n));
549550

550551
cublasLtMatmulAlgo_t algo;
551552
int algoId;
@@ -574,17 +575,17 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
574575
if (is_test) {
575576
std::vector<algoSelect_t> algos;
576577
// Select //
577-
FindAlgo(dev_ctx.handle_,
578+
FindAlgo(dev_ctx.handle,
578579
m,
579580
n,
580581
k,
581582
B_dev,
582583
A_dev,
583584
C_dev,
584-
matmul_desc_,
585-
B_desc_,
586-
A_desc_,
587-
C_desc_,
585+
matmul_desc,
586+
B_desc,
587+
A_desc,
588+
C_desc,
588589
CUBLAS_COMPUTE_32I,
589590
CUDA_R_32I,
590591
CUDA_R_8I,
@@ -643,7 +644,7 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
643644
paddle::DataType::UINT8,
644645
paddle::GPUPlace());
645646
void* workspace_ptr = workspace.data<uint8_t>();
646-
CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle_,
647+
CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle,
647648
cudaComputeType,
648649
CUDA_R_32I,
649650
CUDA_R_8I,
@@ -677,18 +678,18 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
677678
auto start = std::chrono::high_resolution_clock::now();
678679
const int repeats = 10;
679680
for (int loop = 0; loop < repeats; loop++) {
680-
CUDA_CHECK(cublasLtMatmul(dev_ctx.handle_,
681-
matmul_desc_,
682-
&alpha_,
681+
CUDA_CHECK(cublasLtMatmul(dev_ctx.handle,
682+
matmul_desc,
683+
&alpha,
683684
B_dev,
684-
B_desc_,
685+
B_desc,
685686
A_dev,
686-
A_desc_,
687-
&beta_,
687+
A_desc,
688+
&beta,
688689
C_dev,
689-
C_desc_,
690+
C_desc,
690691
C_dev,
691-
C_desc_,
692+
C_desc,
692693
&algo,
693694
// nullptr,
694695
workspace_ptr,
@@ -711,8 +712,8 @@ void TuneCublasltGemm(const paddle::Tensor& M,
711712
bool is_test,
712713
bool is_read_from_file,
713714
const std::string& path) {
714-
715-
// Ensure that M, K, and N are all one-dimensional Tensors. is_test != is_read_from_file
715+
// Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
716+
// is_read_from_file
716717
assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
717718
assert(is_test != is_read_from_file);
718719

@@ -730,22 +731,34 @@ void TuneCublasltGemm(const paddle::Tensor& M,
730731

731732
int m_data = (int)M_data[0];
732733
assert(m_data > 0 && 4 <= 8192);
733-
734+
734735
std::vector<int> mm;
735736

736737
int m = 1, step = 1;
737-
while (m <= m_data) {
738+
while (m <= m_data) {
738739
mm.push_back(m);
739740
m += step;
740741

741742
// update step
742743
switch (m) {
743-
case 4: step = 4; break;
744-
case 16: step = 16; break;
745-
case 64: step = 32; break;
746-
case 256: step = 64; break;
747-
case 512: step = 128; break;
748-
case 1024: step = 1024; break;
744+
case 4:
745+
step = 4;
746+
break;
747+
case 16:
748+
step = 16;
749+
break;
750+
case 64:
751+
step = 32;
752+
break;
753+
case 256:
754+
step = 64;
755+
break;
756+
case 512:
757+
step = 128;
758+
break;
759+
case 1024:
760+
step = 1024;
761+
break;
749762
}
750763
}
751764

@@ -761,15 +774,15 @@ void TuneCublasltGemm(const paddle::Tensor& M,
761774
if (dtype == "int8") {
762775
CUBLASLTContext dev_ctx;
763776
GEMMInt8(dev_ctx,
764-
A,
765-
B,
766-
C,
767-
m,
768-
k,
769-
n,
770-
is_test, /*is_test*/
771-
is_read_from_file, /*is_read_from_file*/
772-
path);
777+
A,
778+
B,
779+
C,
780+
m,
781+
k,
782+
n,
783+
is_test, /*is_test*/
784+
is_read_from_file, /*is_read_from_file*/
785+
path);
773786
} else {
774787
// other dtype
775788
std::cout << "Not currently supported" << std::endl;

0 commit comments

Comments (0)