@@ -96,7 +96,7 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-global npu_is_casual
+
 npu_is_casual = False
 
 def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-    global npu_is_casual
+
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,7 +1613,6 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
-        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
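
The diff drops `global npu_is_casual` at three sites. As background on the language rule involved (my reading of the change, not a statement from the patch author): `global` has no effect at module scope, and inside a function it is only required when the function assigns to the module-level name; merely reading the flag works without it. A minimal Python sketch of that rule, independent of the PaddleNLP source:

# Minimal sketch of Python's `global` rule (illustrative; not the PaddleNLP code).

npu_is_casual = False  # module-level flag; `global` is meaningless at this scope


def read_flag():
    # Reading a module-level name needs no `global` declaration.
    return npu_is_casual


def set_flag(value):
    # Only assignment requires `global`; without it, `npu_is_casual`
    # would be created as a new local variable inside this function.
    global npu_is_casual
    npu_is_casual = value


set_flag(True)
print(read_flag())  # prints: True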