#pragma once

#include <algorithm>
#include <cstdint>

#include "paddle/extension.h"

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include <hipcub/hipcub.hpp>
#include <hiprand.h>
#include <hiprand_kernel.h>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#include <curand_kernel.h>
#endif
2030
// Threads per block assumed by GetNumBlocks when sizing grids.
constexpr int kBlockSize = 256;
// Grid-size cap factor: at most kNumWaves "waves" of resident blocks
// (see the sm_count * tpm / kBlockSize * kNumWaves bound in GetNumBlocks).
constexpr int kNumWaves = 16;
2333
34+ #ifdef PADDLE_WITH_HIP
// Chooses a grid size for an elementwise kernel over `n` items launched with
// kBlockSize-thread blocks: enough blocks to cover `n`, capped at kNumWaves
// "waves" of maximally-resident blocks so the grid does not oversubscribe
// the device.
//
// Returns hipSuccess, or the first failing HIP runtime-API error code (in
// which case *num_blocks is left unmodified). Result is always >= 1.
inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) {
  int dev = 0;
  hipError_t err = hipGetDevice(&dev);
  if (err != hipSuccess) {
    return err;
  }

  int sm_count = 0;
  err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount,
                              dev);
  if (err != hipSuccess) {
    return err;
  }

  int tpm = 0;  // max resident threads per multiprocessor
  err = hipDeviceGetAttribute(
      &tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
  if (err != hipSuccess) {
    return err;
  }

  // Blocks needed to cover n (ceil-div); stays in 64-bit so large n cannot
  // overflow the intermediate.
  const int64_t blocks_for_n = (n + kBlockSize - 1) / kBlockSize;
  // Device-wide cap: kNumWaves waves of fully-resident blocks. Widen before
  // multiplying — the original computed sm_count * tpm in 32-bit.
  const int64_t max_blocks =
      static_cast<int64_t>(sm_count) * tpm / kBlockSize * kNumWaves;
  // Explicit cast replaces the original implicit int64_t -> int narrowing in
  // std::max<int>(1, std::min<int64_t>(...)); the capped value fits in int
  // because max_blocks does.
  *num_blocks = static_cast<int>(
      std::max<int64_t>(1, std::min(blocks_for_n, max_blocks)));
  return hipSuccess;
}
55+ #else
2456inline cudaError_t GetNumBlocks (int64_t n, int * num_blocks) {
2557 int dev;
2658 {
@@ -41,6 +73,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
4173 sm_count * tpm / kBlockSize * kNumWaves ));
4274 return cudaSuccess;
4375}
76+ #endif
4477
4578template <typename T>
4679__device__ T max_func (const T a, const T b) {
@@ -74,7 +107,11 @@ class PDTraits<paddle::DataType::FLOAT16> {
// PDTraits specialization tying paddle's BFLOAT16 dtype to the backend's
// native 16-bit bfloat type: hip_bfloat16 under ROCm, __nv_bfloat16 under
// CUDA. data_t stays the framework-facing paddle::bfloat16.
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
 public:
#ifdef PADDLE_WITH_HIP
  using DataType = hip_bfloat16;
#else
  using DataType = __nv_bfloat16;
#endif
  using data_t = paddle::bfloat16;
};
80117