
Commit 0d872f1

add flashmask in modeling_pp
1 parent 90f1c66 commit 0d872f1


paddlenlp/transformers/qwen2/modeling_pp.py

Lines changed: 71 additions & 16 deletions
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+from typing import OrderedDict
+
 import paddle
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
@@ -41,32 +44,37 @@
 
 def parse_args(args):
     if isinstance(args, tuple):
-        if len(args) == 3:
-            hidden_states, attention_mask, position_ids = args
+        if len(args) == 4:
+            hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
+        elif len(args) == 3:
+            hidden_states, attention_mask, attn_mask_startend_row_indices = args
+            position_ids = None
         elif len(args) == 2:
             hidden_states, attention_mask = args
-            position_ids = None
-        elif len(args) == 1:
-            hidden_states = args
-            attention_mask, position_ids = None, None
+            attn_mask_startend_row_indices, position_ids = None, None
     else:
         hidden_states = args
-        attention_mask, position_ids = None, None
+        attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None
 
     if position_ids is not None:
         position_ids.stop_gradient = True
 
     if attention_mask is not None:
         attention_mask.stop_gradient = True
 
-    return hidden_states, attention_mask, position_ids
+    if attn_mask_startend_row_indices is not None:
+        attn_mask_startend_row_indices.stop_gradient = True
+
+    return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids
 
 
-def return_args(hidden_states, attention_mask=None, position_ids=None):
+def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None):
     ret = (hidden_states,)
 
     if attention_mask is not None:
         ret += (attention_mask.clone(),)
+    if attn_mask_startend_row_indices is not None:
+        ret += (attn_mask_startend_row_indices.clone(),)
     if position_ids is not None:
         ret += (position_ids.clone(),)
     if len(ret) == 1:
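
Note on the convention above: pipeline stages exchange a plain tuple, return_args appends only the entries that are not None, and parse_args restores them to named variables and marks them with stop_gradient. Below is a minimal sketch of that round trip, not part of this commit; the tensor shapes are hypothetical, and only the int32 dtype of the row-index tensor is taken from the dtype checks added later in this diff.

# Sketch (not part of this commit): round-tripping the new 4-slot convention.
from paddlenlp.transformers.qwen2.modeling_pp import parse_args, return_args
import paddle

hidden_states = paddle.randn([2, 16, 64])
attention_mask = paddle.ones([2, 1, 16, 16], dtype="bool")
attn_mask_startend_row_indices = paddle.full([2, 1, 16], 16, dtype="int32")  # assumed shape
position_ids = paddle.arange(16).unsqueeze(0).tile([2, 1])

args = return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
h, mask, row_idx, pos = parse_args(args)  # all four come back; stop_gradient is set on mask/row_idx/pos

# When some entries are None they are simply dropped from the tuple, so a downstream
# stage cannot rely on position alone; see the dtype checks in Qwen2DecoderLayerPipe below.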
@@ -112,7 +120,7 @@ def forward(self, args):
         Returns:
             _type_: _description_
         """
-        input_ids, attention_mask, position_ids = parse_args(args)
+        input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
         input_embeds = self.embed_tokens(input_ids)
         if self.config.sequence_parallel:
             from paddlenlp.transformers import ScatterOp
@@ -126,6 +134,10 @@ def forward(self, args):
         batch_size, seq_length = input_ids.shape
 
         if attention_mask is not None:
+            assert (
+                attn_mask_startend_row_indices is None
+            ), "attention_mask and attn_mask_startend_row_indices can not be set at same time"
+
             attention_mask = Qwen2Model._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
             )
@@ -136,22 +148,34 @@ def forward(self, args):
             attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
             attention_mask.stop_gradient = True
 
-        return return_args(input_embeds, attention_mask, position_ids)
+        return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
 
 
 class Qwen2DecoderLayerPipe(Qwen2DecoderLayer):
     def forward(self, args):
-        hidden_states, attention_mask, position_ids = parse_args(args)
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
 
         has_gradient = not hidden_states.stop_gradient
 
+        if attention_mask is not None and attention_mask.dtype == paddle.int32:
+            attention_mask, attn_mask_startend_row_indices, position_ids = (
+                None,
+                attention_mask,
+                attn_mask_startend_row_indices,
+            )
+        elif attention_mask is not None and attention_mask.dtype == paddle.int64:
+            attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
+        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
+            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
+
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
-            if attention_mask is not None:
+            if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
                     hidden_states,
                     position_ids=position_ids,
                     attention_mask=attention_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
                 )
             else:
@@ -160,12 +184,18 @@ def forward(self, args):
                     super().forward,
                     hidden_states,
                     position_ids=position_ids,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
                 )
         else:
-            hidden_states = super().forward(hidden_states, position_ids=position_ids, attention_mask=attention_mask)
+            hidden_states = super().forward(
+                hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            )
 
-        return return_args(hidden_states, attention_mask, position_ids)
+        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
 
 
 class Qwen2RMSNormPipe(nn.Layer):
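
The int32/int64 checks at the top of Qwen2DecoderLayerPipe.forward exist because return_args drops None entries, so a later stage can receive the remaining tensors shifted into earlier slots; the layer then re-sorts them by dtype (int32 in the mask slot is treated as the FlashMask row indices, int64 as position_ids, anything else stays as a dense attention mask). The sketch below reproduces that routing in isolation; the helper name route_pipe_args is hypothetical and not part of the commit.

# Sketch (hypothetical helper, mirrors the dtype routing added above).
import paddle

def route_pipe_args(attention_mask, attn_mask_startend_row_indices, position_ids):
    # An int32 tensor that arrived in the attention_mask slot is really the
    # FlashMask start/end row indices; an int64 tensor is really position_ids.
    if attention_mask is not None and attention_mask.dtype == paddle.int32:
        return None, attention_mask, attn_mask_startend_row_indices
    if attention_mask is not None and attention_mask.dtype == paddle.int64:
        return None, None, attention_mask
    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
        return attention_mask, None, attn_mask_startend_row_indices
    return attention_mask, attn_mask_startend_row_indices, position_ids

row_idx = paddle.full([2, 1, 16], 16, dtype="int32")  # assumed FlashMask index shape
mask, row, pos = route_pipe_args(row_idx, None, None)
assert mask is None and row is row_idx and pos is None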
@@ -174,7 +204,7 @@ def __init__(self, config):
         self.norm = Qwen2RMSNorm(config)
 
     def forward(self, args):
-        hidden_states, attention_mask, position_ids = parse_args(args)
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
         return self.norm(hidden_states)
 
 
@@ -202,6 +232,31 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
 
     # DONOT Add base_model_prefix !!!!
 
+    @classmethod
+    def _prepare_pipeline_inputs_func(cls, inputs):
+
+        first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"]
+        last_stage_keys = ["labels"]
+
+        def get_expected_keys(inputs, keys):
+            ret = tuple([inputs.pop(k) if k in inputs else None for k in keys])
+            if len(ret) == 1:
+                ret = ret[0]
+            return ret
+
+        if type(inputs) is dict or type(inputs) is OrderedDict:
+            return [
+                get_expected_keys(inputs, first_stage_keys),
+                get_expected_keys(inputs, last_stage_keys),
+            ]
+
+        keys = list(inputs[0].keys())
+        inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
+        return [
+            get_expected_keys(inputs_batch, first_stage_keys),
+            get_expected_keys(inputs_batch, last_stage_keys),
+        ]
+
     def __init__(self, config: Qwen2Config):
         self.config = config
 
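
For reference, the new _prepare_pipeline_inputs_func regroups either a single feature dict or a list of per-micro-batch dicts into the tuple consumed by the first pipeline stage (which now includes attn_mask_startend_row_indices) and the labels consumed by the last stage. A usage sketch with made-up tensors follows; it assumes the classmethod is called directly, without constructing the model.

# Sketch (hypothetical data): splitting micro-batch dicts into pipeline stage inputs.
from paddlenlp.transformers.qwen2.modeling_pp import Qwen2ForCausalLMPipe
import paddle

micro_batches = [
    {
        "input_ids": paddle.to_tensor([[1, 2, 3, 4]]),
        "attn_mask_startend_row_indices": paddle.to_tensor([[[4, 4, 4, 4]]], dtype="int32"),
        "labels": paddle.to_tensor([[2, 3, 4, 5]]),
    },
]

first_stage, last_stage = Qwen2ForCausalLMPipe._prepare_pipeline_inputs_func(micro_batches)
# first_stage == ([input_ids], None, [row_indices], None)  # absent keys become None
# last_stage  == [labels]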