paddlenlp/transformers/llama: 1 file changed, +4 -2 lines

@@ -366,10 +366,12 @@ def forward(self, hidden_states):
 
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
-                variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
+                hidden_states = hidden_states.astype("float32")
+                variance = hidden_states.pow(2).mean(-1, keepdim=True)
                 hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
         else:
-            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.astype("float32")
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
 
         if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
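
The change upcasts hidden_states to float32 before the variance is computed, so the pow/mean/rsqrt sequence in both branches runs entirely in float32 instead of only the variance, with the result cast back to the weight dtype by the unchanged code below the hunk. A minimal self-contained sketch of an RMSNorm layer written this way is shown next; the class name RMSNormSketch and its constructor arguments are illustrative, not the exact PaddleNLP implementation.

import paddle
import paddle.nn as nn


class RMSNormSketch(nn.Layer):
    """Illustrative RMSNorm: statistics and normalization computed in float32."""

    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = self.create_parameter(
            shape=[hidden_size],
            default_initializer=nn.initializer.Constant(1.0),
        )
        self.variance_epsilon = epsilon

    def forward(self, hidden_states):
        # Upcast first so pow/mean/rsqrt all run in float32, mirroring the diff above.
        hidden_states = hidden_states.astype("float32")
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
        # Cast back to the parameter dtype before applying the learned scale,
        # as the surrounding (unchanged) code in the original forward does.
        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
        return hidden_states * self.weight

Keeping the float32 tensor around for the multiply (rather than chaining .astype("float32") only inside the variance expression) is what makes the normalization itself numerically stable under float16/bfloat16 training; the earlier code normalized the original half-precision activations.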