Skip to content

Commit b8871d7

Browse files
authored
fixed bugs in batchnorm backward decomp (#69688)
1 parent 27b5a28 commit b8871d7

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

paddle/fluid/primitive/decomp_utils/decomp_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class BatchNormDecompHelper {
447447
} else {
448448
auto x_shape = shape<T>(x);
449449
auto nhw = get_slice<T>(x_shape, 0);
450-
for (int64_t i = 0; i < x_rank_; ++i) {
450+
for (int64_t i = 1; i < x_rank_; ++i) {
451451
if (i == channel_axis_) {
452452
continue;
453453
}

test/prim/pir_prim/test_prim_sub_graph_abcde_backward_dynamic_shape.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ def batch_norm_net4(x, y, z):
6969
)
7070

7171

72+
def batch_norm_net5(x, y, z):
73+
var = paddle.ones([40], dtype="float32")
74+
mean = paddle.zeros([40], dtype='float32')
75+
return paddle.nn.functional.batch_norm(
76+
x, mean, var, y, z, use_global_stats=False, training=True
77+
)
78+
79+
7280
def ceil_net(x):
7381
return paddle.ceil(x)
7482

@@ -524,6 +532,25 @@ def setUp(self):
524532
self.tol = 1e-5
525533

526534

535+
class TestPrimBatchNormWithGrad12(TestPrimThreeWithGrad):
536+
def setUp(self):
537+
np.random.seed(2023)
538+
self.op_name = "pd_op.batch_norm_grad"
539+
self.dtype = "float32"
540+
self.x_shape = [30, 40]
541+
self.init_x_shape = [None, None]
542+
self.y_shape = [40]
543+
self.init_y_shape = [None]
544+
self.z_shape = [40]
545+
self.init_z_shape = [None]
546+
self.x = np.random.random(self.x_shape).astype(self.dtype)
547+
self.y = np.random.random(self.y_shape).astype(self.dtype)
548+
self.z = np.random.random(self.z_shape).astype(self.dtype)
549+
self.net = batch_norm_net5
550+
self.enable_cinn = False
551+
self.tol = 1e-5
552+
553+
527554
class TestPrimCeilWithGrad(TestPrimBaseWithGrad):
528555
def setUp(self):
529556
np.random.seed(2024)

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,27 @@ def setUp(self):
11431143
self.tol = 1e-5
11441144

11451145

1146+
class TestPrimBatchNormNC(TestPrimThree):
1147+
def setUp(self):
1148+
np.random.seed(2023)
1149+
self.shape_x = [30, 40]
1150+
self.shape_y = [40]
1151+
self.shape_z = [40]
1152+
self.dtype_x = "float32"
1153+
self.dtype_y = "float32"
1154+
self.dtype_z = "float32"
1155+
self.init_x_shape = [None, None]
1156+
self.init_y_shape = [None]
1157+
self.init_z_shape = [None]
1158+
self.x = np.random.random(self.shape_x).astype(self.dtype_x)
1159+
self.y = np.random.random(self.shape_y).astype(self.dtype_y)
1160+
self.z = np.random.random(self.shape_z).astype(self.dtype_z)
1161+
self.net = batch_norm_net2
1162+
self.necessary_ops = "pd_op.batch_norm_"
1163+
self.enable_cinn = False
1164+
self.tol = 1e-5
1165+
1166+
11461167
class TestPrimLogLoss1(TestPrimTwo):
11471168
def setUp(self):
11481169
np.random.seed(2023)

0 commit comments

Comments
 (0)