Commit eb12e62

fix eval branch of prim vjp of batch_norm in amp mode (#53598)
1 parent: aec4e38

2 files changed: +38 -6 lines

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 6 additions & 0 deletions
@@ -1699,6 +1699,9 @@ void batch_norm_grad(const Tensor& x,
       if (use_global_stats) {
         auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
         auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
+        }
         set_output<T>(nchw_x_grad, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
@@ -1732,6 +1735,9 @@ void batch_norm_grad(const Tensor& x,
           sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
       if (use_global_stats) {
         auto x_grad_data = scale * rsqrt_var * out_grad_data;
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          x_grad_data = cast<T>(x_grad_data, x.dtype());
+        }
         set_output<T>(x_grad_data, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
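
Why the cast is needed: in both hunks the eval path (use_global_stats) builds the input gradient from scale and rsqrt_var, which are kept in float32 under AMP, so the product comes out as float32 even when x (and therefore x_grad) is float16; the added cast<T> brings the result back to x.dtype() before set_output. A minimal Python sketch of the same dtype promotion, with made-up shapes and values (it is not the composite rule itself):

import paddle

# Illustrative only: BN scale and running statistics stay in float32 under AMP,
# so the gradient product is computed in float32 and must be cast back to float16.
scale = paddle.ones([4], dtype="float32")           # gamma, float32
rsqrt_var = paddle.full([4], 0.5, dtype="float32")  # 1/sqrt(running_var + eps), float32
out_grad = paddle.ones([2, 4], dtype="float16")     # incoming gradient, float16 under AMP

x_grad = scale * rsqrt_var * out_grad.astype("float32")  # promoted to float32
x_grad = x_grad.astype("float16")                        # cast back to match x.dtype
assert x_grad.dtype == paddle.float16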

test/prim/composite_ops/test_composite_batch_norm.py

Lines changed: 32 additions & 6 deletions
@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn):


 class PrimeNet(paddle.nn.Layer):
-    def __init__(self, data_layout='NCHW'):
+    def __init__(self, data_layout='NCHW', is_test=False):
         super().__init__()
         self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
-        self.bn = BatchNorm(4, act="relu", data_layout=data_layout)
+        self.bn = BatchNorm(
+            4, act="relu", data_layout=data_layout, is_test=is_test
+        )

     def forward(self, x):
         y = self.conv(x)
@@ -408,10 +410,10 @@ def setUp(self):
         self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
         self.x.stop_gradient = False

-    def train(self, use_prim, data_layout="NCHW"):
+    def train(self, use_prim, data_layout="NCHW", is_test=False):
         core._set_prim_all_enabled(use_prim)
         paddle.seed(2022)
-        net = PrimeNet(data_layout)
+        net = PrimeNet(data_layout=data_layout, is_test=is_test)
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
@@ -429,8 +431,19 @@ def train(self, use_prim, data_layout="NCHW"):

     def test_amp_nchw(self):
         if not isinstance(framework._current_expected_place(), core.CPUPlace):
-            expected = self.train(False)
-            actual = self.train(True)
+            expected = self.train(use_prim=False)
+            actual = self.train(use_prim=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+    def test_amp_nchw_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(use_prim=False, is_test=True)
+            actual = self.train(use_prim=True, is_test=True)
             np.testing.assert_allclose(
                 expected,
                 actual,
@@ -449,6 +462,19 @@ def test_amp_nhwc(self):
                 atol=1e-3,
             )

+    def test_amp_nhwc_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(
+                use_prim=False, data_layout="NHWC", is_test=True
+            )
+            actual = self.train(use_prim=True, data_layout="NHWC", is_test=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+

 class TestPrimEvalBranch(unittest.TestCase):
     """
