
Commit 67f3613

fix fused_layer_norm fused_rms_norm outputs (#69960)
* fix fused_layer_norm fused_rms_norm outputs
1 parent 19ba181 commit 67f3613

4 files changed: +43 -74 lines changed

python/paddle/incubate/nn/functional/fused_layer_norm.py

Lines changed: 4 additions & 21 deletions
@@ -18,7 +18,7 @@
 
 import paddle
 from paddle import _C_ops
-from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode
+from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
 if TYPE_CHECKING:
     from paddle import Tensor
@@ -108,8 +108,7 @@ def fused_layer_norm(
             >>> epsilon = 1e-6
             >>> paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
     """
-
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.fused_bias_residual_layernorm(
             x,
             bias,
@@ -124,23 +123,7 @@ def fused_layer_norm(
             quant_max_bound,
             quant_min_bound,
         )
-    elif in_pir_mode():
-        out, residual_out, _, _ = _C_ops.fused_bias_residual_layernorm(
-            x,
-            bias,
-            residual,
-            norm_weight,
-            norm_bias,
-            epsilon,
-            residual_alpha,
-            begin_norm_axis,
-            quant_scale,
-            quant_round_type,
-            quant_max_bound,
-            quant_min_bound,
-        )
-        return (out, residual_out) if residual is not None else out
-
+    # static mode
     helper = LayerHelper('fused_layernorm', **locals())
     out = None
     if quant_scale <= 0:
@@ -183,4 +166,4 @@ def fused_layer_norm(
         },
         outputs=outputs_dict,
     )
-    return (out, residual_out) if residual is not None else out
+    return (out, residual_out, outputs_dict['mean'], outputs_dict['variance'])
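
After this change, both the dynamic/PIR path and the static-graph path of `fused_layer_norm` expose the same four outputs `(out, residual_out, mean, variance)`, so callers index the tuple instead of branching on mode. A minimal static-graph sketch of the new convention follows; the tensor names and shapes are illustrative only, and a GPU build of Paddle is assumed since the fused kernel is not registered otherwise.

import paddle
import paddle.incubate.nn.functional as F

paddle.enable_static()  # exercise the static-graph branch touched by this commit
with paddle.static.program_guard(paddle.static.Program()):
    # hypothetical inputs, chosen only for this sketch
    x = paddle.static.data(name="x", shape=[2, 256], dtype="float32")
    w = paddle.static.data(name="w", shape=[256], dtype="float32")
    b = paddle.static.data(name="b", shape=[256], dtype="float32")

    outs = F.fused_layer_norm(x, w, b, 1e-6, begin_norm_axis=1)
    out = outs[0]  # normalized output; the updated tests fetch exactly this index
    # outs[1] is residual_out (meaningful only when a residual input is passed);
    # outs[2] and outs[3] are the mean and variance produced by the fused op.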

python/paddle/incubate/nn/functional/fused_rms_norm.py

Lines changed: 4 additions & 18 deletions
@@ -18,7 +18,7 @@
 
 import paddle
 from paddle import _C_ops
-from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode
+from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
 if TYPE_CHECKING:
     from paddle import Tensor
@@ -102,7 +102,7 @@ def fused_rms_norm(
             >>> epsilon = 1e-6
             >>> paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.rms_norm(
             x,
             bias,
@@ -116,21 +116,7 @@ def fused_rms_norm(
             quant_max_bound,
             quant_min_bound,
         )
-    if in_pir_mode():
-        out, residual_out = _C_ops.rms_norm(
-            x,
-            bias,
-            residual,
-            norm_weight,
-            norm_bias,
-            epsilon,
-            begin_norm_axis,
-            quant_scale,
-            quant_round_type,
-            quant_max_bound,
-            quant_min_bound,
-        )
-        return (out, residual_out) if residual is not None else out
+    # static mode
     helper = LayerHelper('rms_norm', **locals())
     out = None
     if quant_scale <= 0:
@@ -167,4 +153,4 @@ def fused_rms_norm(
         },
         outputs=outputs_dict,
     )
-    return (out, residual_out) if residual is not None else out
+    return (out, residual_out, outputs_dict['inv_var'])
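
The same normalization applies to `fused_rms_norm`: its static-graph branch now returns the full `(out, residual_out, inv_var)` tuple instead of a mode-dependent shape. A minimal sketch of unpacking it, again with illustrative names and shapes and assuming a GPU build of Paddle:

import paddle
import paddle.incubate.nn.functional as F

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    # hypothetical inputs, chosen only for this sketch
    x = paddle.static.data(name="x", shape=[2, 256], dtype="float32")
    w = paddle.static.data(name="w", shape=[256], dtype="float32")
    b = paddle.static.data(name="b", shape=[256], dtype="float32")

    out, residual_out, inv_var = F.fused_rms_norm(x, w, b, 1e-6, begin_norm_axis=1)
    # out is the tensor most callers need; residual_out is populated only when a
    # residual input is given, and inv_var is the inverse-variance statistic the op emits.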

test/legacy_test/test_fused_layernorm_op.py

Lines changed: 20 additions & 20 deletions
@@ -448,7 +448,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
                 beta_static,
                 self.epsilon,
                 begin_norm_axis=1,
-            )
+            )[0]
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
                 feed={
@@ -498,7 +498,7 @@ def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype):
                 quant_round_type=self.quant_round_type,
                 quant_max_bound=self.quant_max_bound,
                 quant_min_bound=self.quant_min_bound,
-            )
+            )[0]
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
                 feed={
@@ -546,7 +546,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
                 quant_round_type=self.quant_round_type,
                 quant_max_bound=self.quant_max_bound,
                 quant_min_bound=self.quant_min_bound,
-            )
+            )[0]
 
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
@@ -556,7 +556,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
                     "bias_static": bias_np.astype(dtype),
                 },
                 fetch_list=[
-                    outs[0]
+                    outs
                 ],  # NOTE: Only fetch `out`, because `residual_out` will not be initialized if both `norm_weight` and `norm_bias` are None.
             )
         return out_s, paddle_naive_residual_out
@@ -597,7 +597,7 @@ def check_residual_bias_layernorm(
             beta_static = paddle.static.data(
                 name="beta_static", shape=[self.cols], dtype='float32'
             )
-            outs = paddle.incubate.nn.functional.fused_layer_norm(
+            outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
                 x_static,
                 gamma_static,
                 beta_static,
@@ -606,7 +606,7 @@ def check_residual_bias_layernorm(
                 residual_alpha=self.residual_alpha,
                 bias=bias_static,
                 residual=residual_static,
-            )
+            )[:2]
 
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
@@ -617,7 +617,7 @@ def check_residual_bias_layernorm(
                     "residual_static": residual_np.astype(dtype),
                     "bias_static": bias_np.astype(dtype),
                 },
-                fetch_list=[outs],
+                fetch_list=[outs, residual],
             )
         return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
 
@@ -667,7 +667,7 @@ def check_residual_bias_layernorm_int8(
             beta_static = paddle.static.data(
                 name="beta_static", shape=[self.cols], dtype='float32'
             )
-            outs = paddle.incubate.nn.functional.fused_layer_norm(
+            outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
                 x_static,
                 gamma_static,
                 beta_static,
@@ -680,7 +680,7 @@ def check_residual_bias_layernorm_int8(
                 quant_round_type=self.quant_round_type,
                 quant_max_bound=self.quant_max_bound,
                 quant_min_bound=self.quant_min_bound,
-            )
+            )[:2]
 
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
@@ -691,7 +691,7 @@ def check_residual_bias_layernorm_int8(
                     "residual_static": residual_np.astype(dtype),
                     "bias_static": bias_np.astype(dtype),
                 },
-                fetch_list=[outs],
+                fetch_list=[outs, residual],
             )
         return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
 
@@ -847,7 +847,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
 
         paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
            x, gamma, beta, self.epsilon, begin_norm_axis=1
-        )
+        )[0]
         paddle_naive_layernorm_out = naive_layer_norm(
            x, gamma, beta, self.epsilon
         )
@@ -869,7 +869,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
            bias=bias,
            residual=residual,
            residual_alpha=self.residual_alpha,
-        )
+        )[0]
 
         paddle_naive_residual_out = naive_residual_bias_add(
            x, residual, bias, self.residual_alpha
@@ -919,7 +919,7 @@ def test_residual_bias_add(self):
            self.x_np, self.residual_np, self.bias_np, 'float32'
         )
         np.testing.assert_allclose(
-            paddle_residual_bias_out[0].numpy(),
+            paddle_residual_bias_out.numpy(),
            paddle_naive_residual_bias_out.numpy(),
            rtol=1e-3,
            atol=1e-3,
@@ -931,7 +931,7 @@ def test_layernorm(self):
         )
 
         np.testing.assert_allclose(
-            paddle_layernorm[0].numpy(),
+            paddle_layernorm.numpy(),
            paddle_naive_layernorm.numpy(),
            rtol=1e-3,
            atol=1e-3,
@@ -1016,7 +1016,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
                 beta_static,
                 self.epsilon,
                 begin_norm_axis=1,
-            )
+            )[0]
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
                 feed={
@@ -1060,7 +1060,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
                 bias=bias_static,
                 residual=residual_static,
                 residual_alpha=self.residual_alpha,
-            )
+            )[0]
 
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
@@ -1070,7 +1070,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
                     "bias_static": bias_np.astype(dtype),
                 },
                 fetch_list=[
-                    outs[0]
+                    outs
                 ],  # NOTE: Only fetch `out`, because `residual_out` will not be initialized if both `norm_weight` and `norm_bias` are None.
             )
         return out_s, paddle_naive_residual_out
@@ -1111,7 +1111,7 @@ def check_residual_bias_layernorm(
             beta_static = paddle.static.data(
                 name="beta_static", shape=[self.cols], dtype='float32'
             )
-            outs = paddle.incubate.nn.functional.fused_layer_norm(
+            outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
                 x_static,
                 gamma_static,
                 beta_static,
@@ -1120,7 +1120,7 @@ def check_residual_bias_layernorm(
                 residual_alpha=self.residual_alpha,
                 bias=bias_static,
                 residual=residual_static,
-            )
+            )[:2]
 
             exe = paddle.static.Executor(self.place)
             out_s = exe.run(
@@ -1131,7 +1131,7 @@ def check_residual_bias_layernorm(
                     "residual_static": residual_np.astype(dtype),
                     "bias_static": bias_np.astype(dtype),
                 },
-                fetch_list=[outs],
+                fetch_list=[outs, residual],
             )
         return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
 
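
The test updates above all follow the same convention: every call site now indexes the returned tuple (`[0]` for `out`, `[:2]` for `(out, residual_out)`) instead of relying on a mode-dependent return shape. A hedged dynamic-mode sketch of that pattern, with made-up inputs, requiring a CUDA device because the fused kernel is GPU-only:

import numpy as np
import paddle
import paddle.incubate.nn.functional as F

# illustrative inputs, not taken from the test file
x = paddle.rand([2, 256], dtype="float32")
gamma = paddle.rand([256], dtype="float32")
beta = paddle.rand([256], dtype="float32")

# as in the updated tests, element 0 of the result is the normalized output
out = F.fused_layer_norm(x, gamma, beta, 1e-6, begin_norm_axis=1)[0]

# compare against a plain layer_norm reference, mirroring the tests' tolerances
ref = paddle.nn.functional.layer_norm(x, [256], weight=gamma, bias=beta, epsilon=1e-6)
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-3, atol=1e-3)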
