Commit 180ea46

Revert "update (#8359)"
This reverts commit ae0bea9.
1 parent 18e5cee · commit 180ea46

File tree

1 file changed (+2, -3 lines)


paddlenlp/transformers/llama/modeling.py

Lines changed: 2 additions & 3 deletions
@@ -96,7 +96,7 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-global npu_is_casual
+
 npu_is_casual = False
 
 def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-    global npu_is_casual
+
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,7 +1613,6 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         ) # [bs, 1, seq_len, seq_len]
-        global npu_is_casual
        if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
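
Context for the revert: the reverted commit had added `global npu_is_casual` declarations inside these functions. The sketch below is not part of the commit; it is a minimal, hypothetical illustration of the Python scoping rule involved, namely that a module-level name can be read inside a function without a `global` declaration, and `global` is only required where the name is rebound. Whether the functions in modeling.py ever rebind the flag is not shown in this diff.

# Minimal, hypothetical sketch (not from the commit): Python scoping
# behind the removed `global npu_is_casual` statements.
npu_is_casual = False   # module-level flag, as defined in modeling.py


def read_flag():
    # Reading a module-level name needs no `global` declaration;
    # name lookup falls through to the module namespace.
    return npu_is_casual


def set_flag(value):
    # Only rebinding the module-level name requires `global`.
    global npu_is_casual
    npu_is_casual = value


print(read_flag())   # False
set_flag(True)
print(read_flag())   # True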
