Skip to content

Commit abb0d3c

Browse files
authored
[Bug Fix] fix paddle multipy_fwd_func warning message (#7818)
1 parent c5d8d5b commit abb0d3c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

paddlenlp/transformers/llama/modeling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,12 @@ def forward(self, hidden_states):
366366

367367
if paddle.in_dynamic_mode():
368368
with paddle.amp.auto_cast(False):
369-
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
369+
hidden_states = hidden_states.astype("float32")
370+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
370371
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
371372
else:
372-
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
373+
hidden_states = hidden_states.astype("float32")
374+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
373375
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
374376

375377
if self.weight.dtype in [paddle.float16, paddle.bfloat16]:

0 commit comments

Comments
 (0)