5 changes: 3 additions & 2 deletions paddle/fluid/framework/pten_utils.cc
@@ -185,8 +185,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
 }
 
 KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
-  return KernelSignature(op_proto_->type(), GetInputArgsNames(),
-                         GetAttrsArgsNames(), GetOutputArgsNames());
+  return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
+                         GetInputArgsNames(), GetAttrsArgsNames(),
+                         GetOutputArgsNames());
 }
 
 std::once_flag kernel_sig_map_init_flag;
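For context on the pten_utils.cc hunk above: GetKernelSignature() now routes the fluid op type through pten::TransToPtenKernelName before constructing the KernelSignature, so the signature is keyed on the pten kernel name rather than the raw op type. A minimal, self-contained sketch of that pattern; MySignature and TranslateName are hypothetical stand-ins, and the matmul_v2 -> matmul mapping is assumed purely for illustration:

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for KernelSignature and pten::TransToPtenKernelName;
// the names and the example mapping below are illustrative only.
std::string TranslateName(const std::string& op_type) {
  static const std::unordered_map<std::string, std::string> alias = {
      {"matmul_v2", "matmul"}};  // assumed mapping, for illustration
  auto it = alias.find(op_type);
  return it == alias.end() ? op_type : it->second;
}

struct MySignature {
  std::string name;  // the (translated) kernel name the signature is keyed on
};

int main() {
  // Mirrors the hunk above: translate the fluid op type before building the
  // signature, so downstream lookups use the pten kernel name.
  MySignature sig{TranslateName("matmul_v2")};
  std::cout << sig.name << std::endl;  // prints "matmul"
  return 0;
}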
243 changes: 243 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
@@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

// TODO(kexinzhao): add processing code for compute capability < 80 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif // CUDA_VERSION >= 11000
}
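Both bfloat16 GEMM specializations above rely on the comment about Fortran order: cuBLAS is column-major, so a row-major C = A * B is computed as C^T = B^T * A^T, which is why B and ldb are passed before A and lda in the cublasGemmEx call and the leading dimension of C is N. A hedged usage sketch of the new overload; the device buffers, sizes, and dev_ctx are assumed to exist, and GetBlas is assumed to be the usual Blas accessor in paddle::operators::math:

// Assumed: dev_ctx is a platform::CUDADeviceContext; d_A (M x K), d_B (K x N)
// and d_C (M x N) are row-major platform::bfloat16 buffers already on the GPU.
auto blas = paddle::operators::math::GetBlas<platform::CUDADeviceContext,
                                             platform::bfloat16>(dev_ctx);
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K,
          static_cast<platform::bfloat16>(1.0f), d_A, d_B,
          static_cast<platform::bfloat16>(0.0f), d_C);
// Dispatches to cublasGemmEx with CUDA_R_16BF storage and FP32 compute, so it
// needs CUDA >= 11.0 and compute capability >= 80 at runtime.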

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
@@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat16 gemv, we use the bfloat16 GEMM
// above to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat16 gemv, we use the bfloat16 GEMM
// above to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
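The bfloat16 GEMV specializations above fall back to the bfloat16 GEMM because cuBLAS offers no bf16 gemv: y = A * x becomes an (M x N) * (N x 1) GEMM, and y = A^T * x a (1 x M) * (M x N) GEMM with the vector passed as the left operand. A hedged usage sketch, with the same assumed blas handle and device buffers as above:

// Assumed: d_A is M x N, d_x has length N, d_y has length M, all
// platform::bfloat16 device buffers. Computes d_y = d_A * d_x.
blas.GEMV(/*trans_a=*/false, M, N,
          static_cast<platform::bfloat16>(1.0f), d_A, d_x,
          static_cast<platform::bfloat16>(0.0f), d_y);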

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 9010
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}
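In both strided-batched specializations above, strideC is fixed to M * N, so the output batches are assumed densely packed, while strideA and strideB stay caller-provided (an input can be packed, or broadcast with stride 0). A hedged usage sketch for the densely packed case, with the same assumed handle and buffers:

// Assumed: d_A packs batchCount matrices of M x K, d_B packs batchCount
// matrices of K x N, d_C receives batchCount matrices of M x N.
blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, M, N, K,
                 static_cast<platform::bfloat16>(1.0f), d_A, d_B,
                 static_cast<platform::bfloat16>(0.0f), d_C, batchCount,
                 /*strideA=*/static_cast<int64_t>(M) * K,
                 /*strideB=*/static_cast<int64_t>(K) * N);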

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
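The pointer-array overloads above have no fused cuBLAS call for bfloat16; they simply loop over the batch and launch one bf16 GEMM per entry, trading performance for a uniform interface. A hedged usage sketch; the host-side vectors of device pointers are assumed:

// Assumed: a_ptrs and b_ptrs are std::vector<const platform::bfloat16 *> and
// c_ptrs is std::vector<platform::bfloat16 *>, each holding one device
// pointer per batch entry.
blas.BatchedGEMM(CblasNoTrans, CblasNoTrans, M, N, K,
                 static_cast<platform::bfloat16>(1.0f),
                 a_ptrs.data(), b_ptrs.data(),
                 static_cast<platform::bfloat16>(0.0f),
                 c_ptrs.data(), batchCount);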

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,