
Commit 38afd03

[DCU] fix some failed ut (#65716)
1 parent baea4ab commit 38afd03

File tree

4 files changed: +28, -98 lines changed


paddle/phi/kernels/gpu/check_numerics_kernel.cu

Lines changed: 1 addition & 90 deletions
@@ -58,82 +58,6 @@ static void InitMultiGPUOpVarMap() {
   multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex);
 }

-template <typename T>
-__device__ __forceinline__ void PrintNanInfKernel(const T* value,
-                                                  const size_t numel,
-                                                  int print_num,
-                                                  char* debug_info) {
-  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-
-  __shared__ unsigned int nan_count, inf_count, num_count;
-  if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
-  __syncthreads();
-
-  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
-    unsigned int count = 0;
-    if (isnan(value[i])) {
-      count = atomicAdd(&nan_count, 1);
-    } else if (isinf(value[i])) {
-      count = atomicAdd(&inf_count, 1);
-    } else {
-      count = atomicAdd(&num_count, 1);
-    }
-    // for cuda, print in every block
-    if (count < print_num) {
-      printf("numel:%lu idx:%lu value:%f\n",
-             static_cast<uint64_t>(numel),
-             static_cast<uint64_t>(i),
-             static_cast<float>(value[i]));
-    }
-  }
-  __syncthreads();
-
-#ifdef __HIPCC__
-  if (true && hipThreadIdx_x == 0) {
-    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
-           hipBlockIdx_x,
-           nan_count,
-           inf_count,
-           num_count);
-#else
-  if (true && threadIdx.x == 0) {
-    printf("In block %d, there has %u,%u,%u nan,inf,num\n",
-           blockIdx.x,
-           nan_count,
-           inf_count,
-           num_count);
-#endif
-    PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
-  }
-}
-
-// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s
-template <typename T>
-__global__ void CheckNanInfKernel(const T* value,
-                                  const size_t numel,
-                                  int print_num,
-                                  char* debug_info) {
-  /// step 1, judge wheater has nan or inf
-  __shared__ volatile int has_nan_inf;
-  if (threadIdx.x == 0) has_nan_inf = false;
-  __syncthreads();
-
-  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  T sum = static_cast<T>(0.0);
-  // Todo(wangxi). simd speed up
-  for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
-    sum += (value[i] - value[i]);
-  }
-
-  if (isnan(sum) || isinf(sum)) has_nan_inf = true;
-  __syncthreads();
-
-  /// Note. different blocks may behave differently
-  if (!has_nan_inf) return;
-
-  PrintNanInfKernel(value, numel, print_num, debug_info);
-}
-
 template <typename T, int ReduceType>
 __device__ T BlockReduce(T value) {
   __shared__ T shared_mem[1024];
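The removed CheckNanInfKernel relied on the identity that v - v is 0 for every finite v but NaN when v is NaN or infinite, so accumulating value[i] - value[i] and testing the sum flags any bad element without checking each value individually. A minimal host-side C++ sketch of just that trick (illustration only, not Paddle code):

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// (v - v) is 0.0f for finite v and NaN for NaN or +/-Inf, and NaN propagates
// through the running sum, so one isnan/isinf test on the sum covers the
// whole range.
bool HasNanInf(const std::vector<float>& values) {
  float sum = 0.0f;
  for (float v : values) sum += (v - v);
  return std::isnan(sum) || std::isinf(sum);
}

int main() {
  std::vector<float> ok = {1.0f, -2.5f, 3.0f};
  std::vector<float> bad = {1.0f, std::numeric_limits<float>::infinity()};
  std::printf("ok=%d bad=%d\n", HasNanInf(ok), HasNanInf(bad));  // ok=0 bad=1
  return 0;
}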
@@ -509,19 +433,7 @@ void CheckNumericsKernel(const Context& ctx,
   size_t blocks =
       std::min(static_cast<size_t>(128),
                static_cast<size_t>((tensor.numel() + threads - 1) / threads));
-#ifdef __HIPCC__
-  int print_num = 3;
-
-  hipLaunchKernelGGL(CheckNanInfKernel,
-                     dim3(blocks),
-                     dim3(threads),
-                     0,
-                     ctx.stream(),
-                     tensor.data<T>(),
-                     tensor.numel(),
-                     print_num,
-                     gpu_str_ptr);
-#else
+
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;

   int64_t numel_max_min = blocks;
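The deleted HIP branch launched the old CheckNanInfKernel through hipLaunchKernelGGL rather than the triple-chevron syntax; with the branch removed, HIP builds now take the same code path as CUDA. For reference, hipLaunchKernelGGL(kernel, grid, block, sharedMemBytes, stream, args...) is HIP's macro equivalent of kernel<<<grid, block, sharedMemBytes, stream>>>(args...), and recent hipcc also accepts the chevron form, which is what makes a single launch path workable. A toy sketch of the two forms (illustrative kernel, not Paddle's):

#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

__global__ void FillKernel(float* data, size_t n, float v) {
  size_t i = threadIdx.x + static_cast<size_t>(blockIdx.x) * blockDim.x;
  if (i < n) data[i] = v;
}

void LaunchFill(float* data, size_t n, float v) {
  dim3 block(256);
  dim3 grid(static_cast<unsigned>((n + 255) / 256));
#ifdef __HIPCC__
  // Macro form: kernel, grid, block, dynamic shared memory bytes, stream, args...
  hipLaunchKernelGGL(FillKernel, grid, block, 0, 0, data, n, v);
#else
  // Chevron form with the same grid/block/shared-memory/stream arguments.
  FillKernel<<<grid, block, 0, 0>>>(data, n, v);
#endif
}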
@@ -586,7 +498,6 @@ void CheckNumericsKernel(const Context& ctx,
   if (check_nan_inf_level == 0 && stack_height_limit > 0) {
     PrintStack<T>(ctx, *stats, op_type, var_name, dev_id);
   }
-#endif
 }

 }  // namespace phi

paddle/phi/kernels/gpu/group_norm_kernel.cu

Lines changed: 15 additions & 0 deletions
@@ -880,8 +880,23 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x,
   }
   x_mean /= number * imsize;
   x_var /= number * imsize;
+
+#ifdef __NVCC__
   CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
   CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
+#endif
+#ifdef __HIPCC__
+  // Note(wangyanpeng04): When the block size is less than the warp size,
+  // WarpReduce will result in all zeros. It seems to be an internal problem of
+  // hipcub on DCU.
+  if (blockDim.x < phi::kps::details::kWarpSize) {
+    phi::CudaAtomicAdd(&mean[bid * groups + gid], x_mean);
+    phi::CudaAtomicAdd(&var[bid * groups + gid], x_var);
+  } else {
+    CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
+    CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
+  }
+#endif
 }

 template <typename T, typename AccT, int flags>
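The new HIP branch works around what the note describes as a hipcub WarpReduce problem on DCU: when blockDim.x is smaller than the warp size, the warp-assisted atomic add produces zeros, so the kernel falls back to a plain per-thread atomic add. phi::CudaAtomicAdd and CudaAtomicAddWithWarp are Paddle helpers; the sketch below only shows the same fallback shape with raw CUDA primitives (the helper name and the NVIDIA warp size of 32 are assumptions for illustration, and the warp path assumes blockDim.x is a multiple of 32):

// Illustrative CUDA-only sketch, not Paddle's CudaAtomicAddWithWarp.
__device__ void AtomicAddWithWarpFallback(float* dst, float val) {
  constexpr int kWarpSize = 32;
  if (blockDim.x < kWarpSize) {
    // Block narrower than a warp: skip the warp reduction entirely and let
    // every thread contribute its own atomicAdd.
    atomicAdd(dst, val);
  } else {
    // Reduce within the warp, then lane 0 issues a single atomicAdd.
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      val += __shfl_down_sync(0xffffffffu, val, offset);
    }
    if ((threadIdx.x & (kWarpSize - 1)) == 0) atomicAdd(dst, val);
  }
}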

paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu

Lines changed: 4 additions & 0 deletions
@@ -41,7 +41,11 @@ struct AsciiCaseConverter<phi::GPUContext, CharConverter> {
                   const pstring* in,
                   pstring* out,
                   size_t num) const {
+#ifdef PADDLE_WITH_HIP
+    dim3 block_size = dim3(256, 1);
+#else
     dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
+#endif
     dim3 grid_size =
         dim3((num + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
     StringCaseConvertCUDAKernel<CharConverter>
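Under PADDLE_WITH_HIP the block is pinned to 256 threads while the grid is still derived from PREDEFINED_BLOCK_SIZE, possibly to keep the block within what the DCU target accepts for this kernel. A small host-side sketch of that launch-dimension choice (the PREDEFINED_BLOCK_SIZE value of 512 here is only an assumed placeholder; Paddle defines the real one elsewhere):

#include <cstddef>
#include <cstdio>

// Placeholder value for illustration only.
#ifndef PREDEFINED_BLOCK_SIZE
#define PREDEFINED_BLOCK_SIZE 512
#endif

struct LaunchDims { unsigned block_x; unsigned grid_x; };

// Mirrors the patched logic: block size is 256 under HIP, otherwise
// PREDEFINED_BLOCK_SIZE; the grid is always sized from PREDEFINED_BLOCK_SIZE.
LaunchDims CaseConvertDims(size_t num, bool with_hip) {
  LaunchDims d;
  d.block_x = with_hip ? 256u : PREDEFINED_BLOCK_SIZE;
  d.grid_x = static_cast<unsigned>((num + PREDEFINED_BLOCK_SIZE - 1) /
                                   PREDEFINED_BLOCK_SIZE);
  return d;
}

int main() {
  LaunchDims d = CaseConvertDims(10000, true);
  std::printf("block=%u grid=%u\n", d.block_x, d.grid_x);
  return 0;
}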

test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h

Lines changed: 8 additions & 8 deletions
@@ -128,13 +128,6 @@ class TestElementwiseOpGradGrad {
     }
     auto *out_ptr = cpu_out.data<T>();
     size_t numel = static_cast<size_t>(common::product(dims_));
-#ifdef PADDLE_WITH_HIP
-    auto is_equal = std::equal(
-        out_ptr,
-        out_ptr + numel,
-        expected_outs_[out_name].data(),
-        [](const float &l, const float &r) { return fabs(l - r) < 1e-8; });
-#else
     bool is_equal;
     if (op_type_ == "elementwise_div_grad_grad") {
       is_equal = std::equal(out_ptr,
@@ -144,10 +137,17 @@ class TestElementwiseOpGradGrad {
                               return fabs(l - r) < 0.0005;
                             });
     } else {
+#ifdef PADDLE_WITH_HIP
+      is_equal = std::equal(
+          out_ptr,
+          out_ptr + numel,
+          expected_outs_[out_name].data(),
+          [](const float &l, const float &r) { return fabs(l - r) < 1e-8; });
+#else
       is_equal = std::equal(
           out_ptr, out_ptr + numel, expected_outs_[out_name].data());
-    }
 #endif
+    }
     if (!is_equal) {
       all_equal = false;
       break;
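The HIP-specific 1e-8 absolute-tolerance comparison now sits inside the else branch, so on HIP elementwise_div_grad_grad also uses the looser 0.0005 tolerance instead of the near-exact check. The comparison pattern itself is just std::equal with a tolerance predicate; a minimal standalone C++ sketch of that pattern (illustrative values, not the test's data):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Element-wise approximate equality, mirroring the test's use of std::equal
// with an absolute-tolerance predicate.
bool AllClose(const std::vector<float>& got, const std::vector<float>& want,
              float atol) {
  return got.size() == want.size() &&
         std::equal(got.begin(), got.end(), want.begin(),
                    [atol](float l, float r) { return std::fabs(l - r) < atol; });
}

int main() {
  std::vector<float> a = {1.0f, 2.0f};
  std::vector<float> b = {1.0f, 2.0004f};
  // Fails the tight 1e-8 tolerance, passes the looser 0.0005 tolerance.
  std::printf("%d %d\n", AllClose(a, b, 1e-8f), AllClose(a, b, 0.0005f));  // 0 1
  return 0;
}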
