25 changes: 13 additions & 12 deletions paddle/phi/kernels/gpu/group_norm_grad_kernel.cu
@@ -51,8 +51,8 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
if (x_scale != static_cast<T>(0)) x_scale_inv = static_cast<T>(1.0) / x_scale;
AccT d_mean_data = static_cast<AccT>(0);
AccT d_var_data = static_cast<AccT>(0);
T d_scale_data = static_cast<T>(0);
T d_bias_data = static_cast<T>(0);
AccT d_scale_data = static_cast<AccT>(0);
AccT d_bias_data = static_cast<AccT>(0);

for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
AccT val, dval;
@@ -67,8 +67,8 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
d_mean_data += dval * static_cast<AccT>(x_scale);

val = val * static_cast<AccT>(x_scale_inv);
d_bias_data += static_cast<T>(dval);
d_scale_data += static_cast<T>(val * dval);
d_bias_data += dval;
d_scale_data += val * dval;
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]),
static_cast<AccT>(d_mean_data));
@@ -77,16 +77,16 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,

if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
phi::CudaAtomicAdd(&(d_scale[ccid]), static_cast<T>(d_scale_data));
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
CudaAtomicAddWithWarp(&(d_scale[ccid]), static_cast<T>(d_scale_data));
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
phi::CudaAtomicAdd(&(d_bias[ccid]), static_cast<T>(d_bias_data));
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
CudaAtomicAddWithWarp(&(d_bias[ccid]), static_cast<T>(d_bias_data));
#endif
}
}
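
The substance of this hunk: the per-thread partial sums d_scale_data and d_bias_data now accumulate in AccT (float when T is float16/bfloat16) and are cast back to T only at the final atomic add. A rough NumPy sketch of the numerical motivation (illustration only, not Paddle code; the element count and distribution below are made up):

import numpy as np

# Illustration of low-precision vs mixed-precision accumulation (not Paddle code).
rng = np.random.default_rng(0)
vals = rng.random(1 << 14).astype(np.float16)  # many small per-element contributions

acc_t = np.float16(0)            # old behaviour: accumulate in T (float16)
for v in vals:
    acc_t = np.float16(acc_t + v)

acc_acct = np.float32(0)         # new behaviour: accumulate in AccT (float32)
for v in vals:
    acc_acct += np.float32(v)

exact = vals.astype(np.float64).sum()
print(abs(acc_t - exact) / exact)                 # large relative error
print(abs(np.float16(acc_acct) - exact) / exact)  # only the final cast loses precision
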
@@ -128,7 +128,7 @@ __global__ void GroupNormBackward(const T* x,
: static_cast<AccT>(1);
AccT x_bias =
(flags & kHasBias) ? static_cast<AccT>(bias[ccid]) : static_cast<AccT>(0);
AccT x_scale_inv = static_cast<T>(0);
AccT x_scale_inv = static_cast<AccT>(0);
if (x_scale != static_cast<AccT>(0))
x_scale_inv = static_cast<AccT>(1.0) / x_scale;

@@ -220,7 +220,7 @@ __global__ void GetBackwardParamsCUDAKernel(int imsize,
sum1 += static_cast<AccT>(ds[index]) * scale_v;
sum2 += static_cast<AccT>(db[index]) * scale_v;
const AccT scale_c =
scale == nullptr ? static_cast<AccT>(0) : static_cast<T>(scale[c]);
scale == nullptr ? static_cast<AccT>(0) : static_cast<AccT>(scale[c]);
p1[index] = static_cast<AccT>(scale_c) * var_inv;
}

@@ -402,7 +402,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
p1_data,
p2_data,
p3_data);
GetXGradientCUDAKernel<T>
GetXGradientCUDAKernel<T, AccT>
<<<grid, threads, 0, dev_ctx.stream()>>>(imsize,
C,
group_size,
@@ -424,7 +424,7 @@

DenseTensor temp_var;
temp_var.Resize(var.dims());
dev_ctx.template Alloc<T>(&temp_var);
dev_ctx.template Alloc<AccT>(&temp_var);
set_zero_AccT(dev_ctx, &temp_var, static_cast<AccT>(0));
auto* temp_var_data = temp_var.data<AccT>();

@@ -483,4 +483,5 @@ PD_REGISTER_KERNEL(group_norm_grad,
phi::GroupNormGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
7 changes: 6 additions & 1 deletion paddle/phi/kernels/gpu/group_norm_kernel.cu
@@ -20,6 +20,10 @@
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/group_norm_utils.h"

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename AccT>
@@ -124,7 +128,7 @@ void GroupNormKernel(const Context& dev_ctx,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
using AccT = typename kps::details::MPTypeTrait<T>::Type;
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
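
The AccT alias now comes from phi::dtype::MPTypeTrait instead of kps::details: the trait picks the type used for intermediate math, float for float16/bfloat16 and the input type itself otherwise. A conceptual Python analogue, purely illustrative (the names below are not Paddle APIs; bfloat16 tensors show up as uint16 bit patterns on the Python test side):

import numpy as np

# Conceptual analogue of MPTypeTrait<T>::Type, not a Paddle API.
MP_TYPE = {
    np.dtype('float16'): np.float32,   # half precision accumulates in float32
    np.dtype('uint16'): np.float32,    # bfloat16 carried as uint16 bit patterns
    np.dtype('float32'): np.float32,
    np.dtype('float64'): np.float64,
}

def accumulation_dtype(x):
    """Return the dtype group_norm would accumulate in for input x."""
    return MP_TYPE[np.asarray(x).dtype]
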
@@ -342,4 +346,5 @@ PD_REGISTER_KERNEL(group_norm,
phi::GroupNormKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
125 changes: 122 additions & 3 deletions python/paddle/fluid/tests/unittests/test_group_norm_op.py
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
from testsuite import create_op

import paddle
@@ -94,8 +94,8 @@ def setUp(self):
self.attrs['data_layout'] = self.data_format

def test_check_output(self):
atol = 0.0
inplace_atol = 0.0
atol = 0
inplace_atol = 0
place = core.CPUPlace()

self.check_output_with_place(place, atol=atol)
@@ -161,16 +161,133 @@ def init_test_case(self):
pass


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_float16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestGroupNormFP16OP(TestGroupNormOp):
def test_check_output(self):
atol = 1e-3
inplace_atol = 1e-3

place = core.CUDAPlace(0)
# group_norm uses AtomicAdd on CUDAPlace, which does not guarantee the
# order in which multiple threads update the same address, so the result
# of group_norm is non-deterministic for floating-point inputs.
# When inplace_atol is not None, the inplace check uses numpy.allclose
# instead of numpy.array_equal to compare the inplace result.
# Reference: https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html
self.check_output_with_place(place, atol=atol, inplace_atol=inplace_atol)

def test_check_grad(self):
if self.compare_between_place:
return

place = core.CUDAPlace(0)
self.check_grad_with_place(place, set(['X', 'Scale', 'Bias']), 'Y')

def init_test_case(self):
self.dtype = np.float16


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestGroupNormBF16Op(OpTest):
def setUp(self):
self.op_type = "group_norm"
self.data_format = "NCHW"
self.dtype = np.uint16
self.shape = (2, 100, 3, 5)
self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
self.compare_between_place = False
self.init_test_case()

input = np.random.random(self.shape).astype(np.float32)
if self.data_format == "NHWC":
input = np.transpose(input, (0, 2, 3, 1))
scale = np.random.random([self.shape[1]]).astype(np.float32)
bias = np.random.random([self.shape[1]]).astype(np.float32)
output, mean, var = group_norm_naive(
input,
scale,
bias,
self.attrs['epsilon'],
self.attrs['groups'],
self.data_format,
)

self.inputs = {
'X': convert_float_to_uint16(input),
'Scale': convert_float_to_uint16(scale),
'Bias': convert_float_to_uint16(bias),
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
self.attrs['data_layout'] = self.data_format

def test_check_output(self):
atol = 1e-2
inplace_atol = 1e-2

place = core.CUDAPlace(0)
# group_norm uses AtomicAdd on CUDAPlace, which does not guarantee the
# order in which multiple threads update the same address, so the result
# of group_norm is non-deterministic for floating-point inputs.
# When inplace_atol is not None, the inplace check uses numpy.allclose
# instead of numpy.array_equal to compare the inplace result.
# Reference: https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html
self.check_output_with_place(place, atol=atol, inplace_atol=inplace_atol)

def test_check_grad(self):
if self.compare_between_place:
return

place = core.CUDAPlace(0)
self.check_grad_with_place(place, set(['X', 'Scale', 'Bias']), 'Y')

def init_test_case(self):
pass
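
TestGroupNormBF16Op feeds its inputs through convert_float_to_uint16 because bfloat16 tensors are carried as uint16 bit patterns on the test side. A rough sketch of that representation, assuming plain truncation (the real op_test helper may round to nearest instead):

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa
    # bits of a float32, i.e. the high 16 bits of its bit pattern (truncation).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(u16):
    # Re-expand by placing the 16 bits back into the high half of a float32.
    return (np.asarray(u16, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.array([0.5, 1.0, 3.14159], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # roughly 2-3 significant digits
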


class TestGroupNormOp1(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 1


class TestGroupNormFP16Op1(TestGroupNormFP16OP):
def init_test_case(self):
self.attrs['groups'] = 1
self.dtype = np.float16


class TestGroupNormBF16Op1(TestGroupNormBF16Op):
def init_test_case(self):
self.attrs['groups'] = 1


class TestGroupNormOp2(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 4


class TestGroupNormFP16Op2(TestGroupNormFP16OP):
def init_test_case(self):
self.attrs['groups'] = 4
self.dtype = np.float16


class TestGroupNormBF16Op2(TestGroupNormBF16Op):
def init_test_case(self):
self.attrs['groups'] = 4


class TestGroupNormOpBigEps1(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 1
@@ -244,6 +361,8 @@ def init_test_case(self):


class TestGroupNormAPI_With_NHWC(unittest.TestCase):
paddle.enable_static()

def test_case1(self):
data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float64')
out1 = paddle.static.nn.group_norm(
5 changes: 4 additions & 1 deletion python/paddle/static/nn/common.py
@@ -693,7 +693,10 @@ def group_norm(
helper = LayerHelper('group_norm', **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'group_norm'
input,
'input',
['float16', 'uint16', 'float32', 'float64'],
'group_norm',
)
# create intput and parameters
inputs = {'X': input}
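
With 'float16' and 'uint16' (the bfloat16 carrier type) added to the whitelist, the static-graph group_norm API accepts half-precision inputs at graph-construction time. A hypothetical build-only snippet (names and shapes are made up; actually running the program needs a CUDA place with the GPU kernels registered above):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Hypothetical shapes; previously a float16 input failed the dtype check here.
    x = paddle.static.data(name='x', shape=[2, 8, 32, 32], dtype='float16')
    y = paddle.static.nn.group_norm(input=x, groups=4)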