Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle/phi/backends/gpu/cuda/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include <cuda_runtime.h> // NOLINT

#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"

namespace phi {
namespace backends {
Expand Down Expand Up @@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() {
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
#endif
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"DataType %s is unsupported for CUDA.",
paddle::experimental::DataTypeToString(
paddle::experimental::CppTypeToDataType<T>::Type())));
}
}

Expand Down
18 changes: 3 additions & 15 deletions paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */

#include <cuda_runtime_api.h>

#include "cuda.h"  // NOLINT
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
Expand All @@ -27,19 +28,6 @@ namespace funcs {

enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };

template <typename T>
cudaDataType_t ConvertToCudaDataType() {
  // Maps a C++ element type to the matching cuBLAS `cudaDataType_t` enum.
  // Supported: float, double, phi::dtype::float16, phi::dtype::bfloat16.
  if (std::is_same<T, float>::value) {
    return CUDA_R_32F;
  } else if (std::is_same<T, double>::value) {
    return CUDA_R_64F;
  } else if (std::is_same<T, phi::dtype::float16>::value) {
    return CUDA_R_16F;
  } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
    return CUDA_R_16BF;
  } else {
    // Previously control could fall off the end of this non-void function
    // for any other T (undefined behavior, garbage enum returned). Fail
    // loudly instead, mirroring the error style used in cuda_helper.h.
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported data type for cuBLASLt: only float, double, float16 "
        "and bfloat16 are supported."));
  }
}

template <typename T>
cublasComputeType_t GetCudaComputeType() {
if (std::is_same<T, double>::value) {
Expand Down Expand Up @@ -68,8 +56,8 @@ struct MatmulDescriptor {
int64_t stride_out = 0) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;

cudaDataType_t mat_type = ConvertToCudaDataType<T>();
cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
cublasComputeType_t compute_type = GetCudaComputeType<T>();

// Create operation descriptor; see cublasLtMatmulDescAttributes_t for
Expand Down