 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# NOTE: import OrderedDict from collections, not typing; `type(x) is typing.OrderedDict`
+# is always False, which would break the dict check in _prepare_pipeline_inputs_func
+from collections import OrderedDict
+
 import paddle
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
@@ -41,32 +44,37 @@
 
 def parse_args(args):
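+    # A pipeline stage receives either a bare tensor or a tuple of up to four
+    # tensors: (hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids).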
     if isinstance(args, tuple):
-        if len(args) == 3:
-            hidden_states, attention_mask, position_ids = args
+        if len(args) == 4:
+            hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
+        elif len(args) == 3:
+            hidden_states, attention_mask, attn_mask_startend_row_indices = args
+            position_ids = None
         elif len(args) == 2:
             hidden_states, attention_mask = args
-            position_ids = None
-        elif len(args) == 1:
-            hidden_states = args
-            attention_mask, position_ids = None, None
+            attn_mask_startend_row_indices, position_ids = None, None
     else:
         hidden_states = args
-        attention_mask, position_ids = None, None
+        attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None
 
     if position_ids is not None:
         position_ids.stop_gradient = True
 
     if attention_mask is not None:
         attention_mask.stop_gradient = True
 
-    return hidden_states, attention_mask, position_ids
+    if attn_mask_startend_row_indices is not None:
+        attn_mask_startend_row_indices.stop_gradient = True
+
+    return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids
 
 
-def return_args(hidden_states, attention_mask=None, position_ids=None):
+def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None):
     ret = (hidden_states,)
 
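+    # clone() so the next stage gets its own copy rather than aliasing this stage's tensor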
     if attention_mask is not None:
         ret += (attention_mask.clone(),)
+    if attn_mask_startend_row_indices is not None:
+        ret += (attn_mask_startend_row_indices.clone(),)
     if position_ids is not None:
         ret += (position_ids.clone(),)
     if len(ret) == 1:
@@ -112,7 +120,7 @@ def forward(self, args):
         Returns:
             _type_: _description_
         """
-        input_ids, attention_mask, position_ids = parse_args(args)
+        input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
         input_embeds = self.embed_tokens(input_ids)
         if self.config.sequence_parallel:
             from paddlenlp.transformers import ScatterOp
@@ -126,6 +134,10 @@ def forward(self, args):
         batch_size, seq_length = input_ids.shape
 
         if attention_mask is not None:
+            assert (
+                attn_mask_startend_row_indices is None
+            ), "attention_mask and attn_mask_startend_row_indices cannot be set at the same time"
+
             attention_mask = Qwen2Model._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
             )
@@ -136,22 +148,34 @@ def forward(self, args):
             attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
             attention_mask.stop_gradient = True
 
-        return return_args(input_embeds, attention_mask, position_ids)
+        return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
 
 
 class Qwen2DecoderLayerPipe(Qwen2DecoderLayer):
     def forward(self, args):
-        hidden_states, attention_mask, position_ids = parse_args(args)
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
 
         has_gradient = not hidden_states.stop_gradient
 
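+        # Tensors travel positionally between pipeline stages, so an optional input
+        # can arrive in the wrong slot; disambiguate by dtype: an int32 tensor in the
+        # attention_mask slot is really attn_mask_startend_row_indices, and an int64
+        # tensor is really position_ids.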
+        if attention_mask is not None and attention_mask.dtype == paddle.int32:
+            attention_mask, attn_mask_startend_row_indices, position_ids = (
+                None,
+                attention_mask,
+                attn_mask_startend_row_indices,
+            )
+        elif attention_mask is not None and attention_mask.dtype == paddle.int64:
+            attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
+        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
+            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
+
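+        # Full-granularity recompute (gradient checkpointing) is only worthwhile
+        # when the activations actually require gradients.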
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
-            if attention_mask is not None:
+            if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
                     hidden_states,
                     position_ids=position_ids,
                     attention_mask=attention_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
                 )
             else:
@@ -160,12 +184,18 @@ def forward(self, args):
                     super().forward,
                     hidden_states,
                     position_ids=position_ids,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
                 )
         else:
-            hidden_states = super().forward(hidden_states, position_ids=position_ids, attention_mask=attention_mask)
+            hidden_states = super().forward(
+                hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            )
 
-        return return_args(hidden_states, attention_mask, position_ids)
+        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
 
 
 class Qwen2RMSNormPipe(nn.Layer):
@@ -174,7 +204,7 @@ def __init__(self, config):
         self.norm = Qwen2RMSNorm(config)
 
     def forward(self, args):
-        hidden_states, attention_mask, position_ids = parse_args(args)
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
         return self.norm(hidden_states)
 
 
@@ -202,6 +232,31 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
 
     # DO NOT add base_model_prefix !!!!
 
+    @classmethod
+    def _prepare_pipeline_inputs_func(cls, inputs):
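+        # Split one batch into the tensors consumed by the first pipeline stage
+        # (ids / masks / positions) and those consumed by the last stage (labels).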
+        first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"]
+        last_stage_keys = ["labels"]
+
+        def get_expected_keys(inputs, keys):
+            ret = tuple([inputs.pop(k) if k in inputs else None for k in keys])
+            if len(ret) == 1:
+                ret = ret[0]
+            return ret
+
+        if type(inputs) is dict or type(inputs) is OrderedDict:
+            return [
+                get_expected_keys(inputs, first_stage_keys),
+                get_expected_keys(inputs, last_stage_keys),
+            ]
+
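+        # Otherwise `inputs` is a list of per-micro-batch dicts: collate each key
+        # across the micro-batches before splitting by stage.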
+        keys = list(inputs[0].keys())
+        inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
+        return [
+            get_expected_keys(inputs_batch, first_stage_keys),
+            get_expected_keys(inputs_batch, last_stage_keys),
+        ]
+
     def __init__(self, config: Qwen2Config):
         self.config = config
 