 from ..activations import ACT2FN
 from ..conversion_utils import StateDictNameMapping, init_name_mappings
 from ..linear_utils import Linear
+from ..llama import fusion_ops
 from ..model_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ..model_utils import PretrainedModel, register_base_model
-from ..utils import caculate_llm_flops
+from ..utils import caculate_llm_flops, logger
 from .configuration import Qwen2Config

 try:
@@ -156,6 +157,7 @@ def scaled_dot_product_attention(
     value_states,
     attention_mask,
     output_attentions,
+    attn_mask_startend_row_indices=None,
     training=True,
     sequence_parallel=False,
 ):
@@ -166,32 +168,16 @@ def scaled_dot_product_attention(
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

-        version = paddle.version.full_version
-        if version != "0.0.0" and version <= "2.5.2":
-            attn_output, attn_weights = flash_attention(
-                query_states,
-                key_states,
-                value_states,
-                causal=True,
-                return_softmax=output_attentions,
-            )
-        else:
-            attn_output = F.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-                is_causal=attention_mask is None,
-                dropout_p=config.attention_dropout if training else 0.0,
-                training=training,
-            )
-            attn_weights = None
-
-        if sequence_parallel:
-            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
-        else:
-            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
-        return (attn_output, attn_weights) if output_attentions else attn_output
+        return fusion_ops.fusion_flash_attention(
+            query_states,
+            config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            sequence_parallel=sequence_parallel,
+        )
     else:
         # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
         query_states = paddle.transpose(query_states, [0, 2, 1, 3])
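Note: the inline fallback removed in this hunk is, in effect, what the shared `fusion_ops.fusion_flash_attention` helper now centralizes (and it additionally understands `attn_mask_startend_row_indices`). A minimal sketch of that deleted fallback, for reference only; the function name below is hypothetical and not part of the module:

import paddle.nn.functional as F

def _flash_attention_fallback(query_states, config, key_states, value_states,
                              attention_mask=None, output_attentions=False,
                              training=True, sequence_parallel=False):
    # Mirrors the code deleted above: inputs are [bsz, seq_len, num_heads, head_dim],
    # as noted in the comment at the top of this branch.
    bsz, q_len, num_heads, head_dim = query_states.shape
    attn_output = F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        is_causal=attention_mask is None,
        dropout_p=config.attention_dropout if training else 0.0,
        training=training,
    )
    # Sequence parallelism flattens batch and sequence into a single leading dim.
    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, None) if output_attentions else attn_output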
@@ -510,6 +496,7 @@ def forward(
         attention_mask: Optional[paddle.Tensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
         **kwargs,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -574,6 +561,7 @@ def forward(
                 value_states,
                 attention_mask,
                 output_attentions,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 training=self.training,
                 sequence_parallel=self.sequence_parallel,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -586,6 +574,7 @@ def forward(
                 value_states,
                 attention_mask,
                 output_attentions,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 training=self.training,
                 sequence_parallel=self.sequence_parallel,
             )
@@ -640,6 +629,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         past_key_value: Optional[Tuple[paddle.Tensor]] = None,
         use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
         **kwargs,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         """
@@ -677,6 +667,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 use_cache,
+                attn_mask_startend_row_indices,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
@@ -687,6 +678,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )

         if type(outputs) is tuple:
@@ -992,6 +984,7 @@ def recompute_training_full(
         output_attentions: bool,
         past_key_value: Tensor,
         use_cache: bool,
+        attn_mask_startend_row_indices=None,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -1007,6 +1000,7 @@ def custom_forward(*inputs):
             output_attentions,
             past_key_value,
             use_cache,
+            attn_mask_startend_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
         )

@@ -1023,6 +1017,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1062,20 +1057,24 @@ def forward(
             inputs_embeds = ScatterOp.apply(inputs_embeds)

         # embed positions
-        if attention_mask is None:
+        if attn_mask_startend_row_indices is not None:
+            attention_mask = None
+        else:
             # [bs, seq_len]
-            attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+            attention_mask = (
+                paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+                if attention_mask is None
+                else attention_mask
+            )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+            if self.config.use_flash_attention:
+                attention_mask = None if is_casual_mask(attention_mask) else attention_mask

         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
-        if self.config.use_flash_attention:
-            is_casual = is_casual_mask(attention_mask)
-            if is_casual:
-                attention_mask = None
         hidden_states = inputs_embeds

         # decoder layers
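For readability, the mask selection added in the hunk above is equivalent to the following standalone sketch. It assumes the existing `Qwen2Model` helpers `_prepare_decoder_attention_mask` and `is_casual_mask` and the module-level `paddle` import; the helper name itself is hypothetical:

def _select_attention_mask(self, attention_mask, attn_mask_startend_row_indices,
                           batch_size, seq_length, seq_length_with_past,
                           cache_length, dtype):
    # When row-index based masking is supplied, no dense mask is built at all.
    if attn_mask_startend_row_indices is not None:
        return None
    if attention_mask is None:
        attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), cache_length, dtype
    )  # [bs, 1, seq_len, seq_len]
    # A purely causal mask can be handled inside flash attention itself.
    if self.config.use_flash_attention and is_casual_mask(attention_mask):
        return None
    return attention_mask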
@@ -1103,6 +1102,7 @@ def forward(
                     output_attentions,
                     past_key_value,
                     use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1112,6 +1112,7 @@ def forward(
                     output_attentions,
                     past_key_value,
                     use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )

             # NOTE: clear outdate cache after it has been used for memory saving
@@ -1340,6 +1341,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1373,6 +1375,13 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logger.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.qwen2(
             input_ids=input_ids,
@@ -1384,6 +1393,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )

         hidden_states = outputs[0]
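A hedged usage sketch of the new keyword on the public forward. The model name and inputs are illustrative, and only one of `attention_mask` / `attn_mask_startend_row_indices` should be passed; supplying both triggers the warning added above and the dense mask is dropped:

import paddle
from paddlenlp.transformers import AutoTokenizer, Qwen2ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", dtype="bfloat16")

inputs = tokenizer("PaddleNLP makes LLM training easy.", return_tensors="pd")

# Dense-mask path: attention_mask is expanded to [bs, 1, seq_len, seq_len] internally.
out = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    return_dict=True,
)

# Sparse-index path: pass attn_mask_startend_row_indices instead of attention_mask,
# e.g. indices produced by a data collator for packed sequences. The tensor layout
# follows fusion_ops.fusion_flash_attention and is not reproduced here.
# out = model(input_ids=..., attn_mask_startend_row_indices=row_indices, return_dict=True)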