
Commit d34af01

Merge branch 'dev_20241007_qwen2_add_flashmask' of github.com:DrownFish19/PaddleNLP into dev_20241007_qwen2_add_flashmask
2 parents: 867ad0b + 0d872f1

File tree

4 files changed: +134 -66 lines

README.md

Lines changed: 13 additions & 13 deletions
@@ -115,19 +115,19 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩
 
 * 大模型预训练、精调(包含 SFT、PEFT 技术)、对齐、量化已支持 LLaMA 系列、Baichuan 系列、Bloom 系列、ChatGLM 系列、Mistral 系列、OPT 系列和 Qwen 系列,【LLM】模型预训练、精调、对齐、量化支持列表如下:
 
-| 模型名称/能力支持 | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
-|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
-| Llama |||||||||
-| Qwen |||||| 🚧 | 🚧 ||
-| Mixtral ||| || 🚧 | 🚧 | 🚧 | 🚧 |
-| Mistral ||| ||| 🚧 | 🚧 ||
-| Baichuan/Baichuan2 |||||| 🚧 |||
-| ChatGLM-6B ||| || 🚧 | 🚧 |||
-| ChatGLM2/ChatGLM3 ||| || 🚧 | 🚧 |||
-| Bloom ||| || 🚧 | 🚧 |||
-| GPT-3 ||| 🚧 | 🚧 | 🚧 | 🚧 | 🚧 ||
-| OPT ||| | 🚧 | 🚧 | 🚧 | 🚧 ||
-| Yuan2 ||| | 🚧 | 🚧 | 🚧 | 🚧 ||
+| 模型名称/能力支持 | Pretrain | SFT | FlashMask | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
+|:------------------:|:--------:|:---:|:---------:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
+| Llama ||| | ||||||
+| Qwen ||| | ||| 🚧 | 🚧 ||
+| Mixtral ||| 🚧 | || 🚧 | 🚧 | 🚧 | 🚧 |
+| Mistral ||| 🚧 | ||| 🚧 | 🚧 ||
+| Baichuan/Baichuan2 ||| | ||| 🚧 |||
+| ChatGLM-6B ||| 🚧 | || 🚧 | 🚧 |||
+| ChatGLM2/ChatGLM3 ||| 🚧 | || 🚧 | 🚧 |||
+| Bloom ||| 🚧 | || 🚧 | 🚧 |||
+| GPT-3 ||| 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 ||
+| OPT ||| 🚧 | | 🚧 | 🚧 | 🚧 | 🚧 ||
+| Yuan2 ||| 🚧 | | 🚧 | 🚧 | 🚧 | 🚧 ||
 ------------------------------------------------------------------------------------------
 
 * [大模型推理](./llm/docs/predict/inference.md)已支持 LLaMA 系列、Qwen 系列、Mistral 系列、ChatGLM 系列、Bloom 系列和 Baichuan 系列,支持 Weight Only INT8及 INT4推理,支持 WAC(权重、激活、Cache KV)进行 INT8、FP8量化的推理,【LLM】模型推理支持列表如下:

llm/run_finetune.py

Lines changed: 4 additions & 1 deletion
@@ -52,6 +52,8 @@
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
     LlamaTokenizer,
+    Qwen2ForCausalLM,
+    Qwen2ForCausalLMPipe,
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
@@ -69,7 +71,7 @@
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "False"
 
-flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
+flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe]
 
 
 def main():
@@ -109,6 +111,7 @@ def main():
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
             LinearConfig.enable_accumulate_steps_opt()
             LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
         except ImportError:
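
With `Qwen2ForCausalLM` and `Qwen2ForCausalLMPipe` added to `flash_mask_support_list`, the fine-tuning entry point can route Qwen2 models through the FlashMask path. Below is a minimal sketch of the kind of gating this list enables; the flag name `use_flash_mask` and the error message are illustrative assumptions, only the class list itself comes from the diff above.

```python
# Hypothetical sketch of how flash_mask_support_list can gate FlashMask fine-tuning.
# The flag name and error text are assumptions; the class list mirrors the diff above.
from paddlenlp.transformers import (
    LlamaForCausalLM,
    LlamaForCausalLMPipe,
    Qwen2ForCausalLM,
    Qwen2ForCausalLMPipe,
)

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe]


def check_flash_mask_support(model, use_flash_mask: bool) -> None:
    """Raise early if FlashMask is requested for an architecture whose modeling
    code cannot consume attn_mask_startend_row_indices."""
    if use_flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
        raise NotImplementedError(f"{type(model).__name__} does not support FlashMask yet.")
```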

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 46 additions & 36 deletions
@@ -37,14 +37,15 @@
 from ..activations import ACT2FN
 from ..conversion_utils import StateDictNameMapping, init_name_mappings
 from ..linear_utils import Linear
+from ..llama import fusion_ops
 from ..model_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ..model_utils import PretrainedModel, register_base_model
-from ..utils import caculate_llm_flops
+from ..utils import caculate_llm_flops, logger
 from .configuration import Qwen2Config
 
 try:
@@ -156,6 +157,7 @@ def scaled_dot_product_attention(
     value_states,
     attention_mask,
     output_attentions,
+    attn_mask_startend_row_indices=None,
     training=True,
     sequence_parallel=False,
 ):
@@ -166,32 +168,16 @@ def scaled_dot_product_attention(
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
 
-        version = paddle.version.full_version
-        if version != "0.0.0" and version <= "2.5.2":
-            attn_output, attn_weights = flash_attention(
-                query_states,
-                key_states,
-                value_states,
-                causal=True,
-                return_softmax=output_attentions,
-            )
-        else:
-            attn_output = F.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-                is_causal=attention_mask is None,
-                dropout_p=config.attention_dropout if training else 0.0,
-                training=training,
-            )
-            attn_weights = None
-
-        if sequence_parallel:
-            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
-        else:
-            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
-        return (attn_output, attn_weights) if output_attentions else attn_output
+        return fusion_ops.fusion_flash_attention(
+            query_states,
+            config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            sequence_parallel=sequence_parallel,
+        )
     else:
         # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
         query_states = paddle.transpose(query_states, [0, 2, 1, 3])
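
The Qwen2-specific flash-attention branch is replaced by the shared `fusion_ops.fusion_flash_attention` helper from the Llama modeling code, which also accepts the FlashMask `attn_mask_startend_row_indices` argument. The sketch below illustrates the kind of dispatch such a fused helper performs; it is not PaddleNLP's actual `fusion_ops` implementation, and the `F.flashmask_attention` call plus the tensor layouts are assumptions made for illustration.

```python
# Illustrative dispatch sketch only (assumptions: Paddle exposes FlashMask as
# paddle.nn.functional.flashmask_attention, and inputs are [bsz, seq_len, nhead, head_dim]).
import paddle
import paddle.nn.functional as F


def fused_flash_attention_sketch(
    query_states,
    config,                             # kept to mirror the call signature above; unused here
    key_states,
    value_states,
    attention_mask=None,
    output_attentions=False,            # kept to mirror the call signature above; unused here
    attn_mask_startend_row_indices=None,
    sequence_parallel=False,
):
    bsz, q_len, num_heads, head_dim = query_states.shape
    if attn_mask_startend_row_indices is not None:
        # FlashMask path: a column-wise row-index encoding stands in for the dense mask,
        # so an O(seq_len^2) attention_mask tensor is never materialized.
        attn_output = F.flashmask_attention(  # assumed API, see lead-in
            query_states,
            key_states,
            value_states,
            startend_row_indices=attn_mask_startend_row_indices,
            causal=True,
        )
    else:
        # Dense-mask / plain causal path, as in the code removed above.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            is_causal=attention_mask is None,
        )
    # Collapse heads back into the hidden dimension expected by the output projection.
    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return attn_output
```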
@@ -510,6 +496,7 @@ def forward(
         attention_mask: Optional[paddle.Tensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
         **kwargs,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -574,6 +561,7 @@ def forward(
                 value_states,
                 attention_mask,
                 output_attentions,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 training=self.training,
                 sequence_parallel=self.sequence_parallel,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -586,6 +574,7 @@ def forward(
                 value_states,
                 attention_mask,
                 output_attentions,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 training=self.training,
                 sequence_parallel=self.sequence_parallel,
             )
@@ -640,6 +629,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         past_key_value: Optional[Tuple[paddle.Tensor]] = None,
         use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
         **kwargs,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         """
@@ -677,6 +667,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 use_cache,
+                attn_mask_startend_row_indices,
                 use_reentrant=self.config.recompute_use_reentrant,
             )
         else:
@@ -687,6 +678,7 @@ def forward(
                 attention_mask,
                 output_attentions,
                 use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )
 
         if type(outputs) is tuple:
@@ -992,6 +984,7 @@ def recompute_training_full(
         output_attentions: bool,
         past_key_value: Tensor,
         use_cache: bool,
+        attn_mask_startend_row_indices=None,
     ):
         def create_custom_forward(module):
             def custom_forward(*inputs):
@@ -1007,6 +1000,7 @@ def custom_forward(*inputs):
             output_attentions,
             past_key_value,
             use_cache,
+            attn_mask_startend_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
         )
 
@@ -1023,6 +1017,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1062,20 +1057,24 @@ def forward(
             inputs_embeds = ScatterOp.apply(inputs_embeds)
 
         # embed positions
-        if attention_mask is None:
+        if attn_mask_startend_row_indices is not None:
+            attention_mask = None
+        else:
             # [bs, seq_len]
-            attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+            attention_mask = (
+                paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+                if attention_mask is None
+                else attention_mask
+            )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+            if self.config.use_flash_attention:
+                attention_mask = None if is_casual_mask(attention_mask) else attention_mask
 
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
-        )  # [bs, 1, seq_len, seq_len]
-        if self.config.use_flash_attention:
-            is_casual = is_casual_mask(attention_mask)
-            if is_casual:
-                attention_mask = None
         hidden_states = inputs_embeds
 
         # decoder layers
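
With FlashMask the decoder receives `attn_mask_startend_row_indices` instead of a dense mask: for each key column it records the query row at which causal masking begins, which is how packed-sequence (document-boundary) masking is expressed without materializing an O(seq_len²) tensor. The snippet below is illustrative only, assuming a per-column "document end" convention; the exact tensor layout expected by PaddleNLP's data collator may differ.

```python
# Illustration of what attn_mask_startend_row_indices encodes for two packed sequences.
# Shapes and conventions here are assumptions for illustration.
import numpy as np

seq_lens = [3, 4]                 # two documents packed into one 7-token sample
total = sum(seq_lens)             # 7

# For each column j, causal masking starts at the end offset of j's document.
startend_row_indices = np.concatenate(
    [np.full(n, end, dtype=np.int32) for n, end in zip(seq_lens, np.cumsum(seq_lens))]
)
print(startend_row_indices)       # [3 3 3 7 7 7 7]

# Equivalent dense mask: True where query row i may attend to key column j.
rows = np.arange(total)[:, None]
cols = np.arange(total)[None, :]
dense = (cols <= rows) & (rows < startend_row_indices[None, :])
# Token 4 (second document) attends to positions 3..4 only, never to the first document.
assert dense[4].tolist() == [False, False, False, True, True, False, False]
```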
@@ -1103,6 +1102,7 @@ def forward(
                     output_attentions,
                     past_key_value,
                     use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1112,6 +1112,7 @@ def forward(
                     output_attentions,
                     past_key_value,
                     use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )
 
             # NOTE: clear outdate cache after it has been used for memory saving
@@ -1340,6 +1341,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1373,6 +1375,13 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logger.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.qwen2(
             input_ids=input_ids,
@@ -1384,6 +1393,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )
 
         hidden_states = outputs[0]
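
Taken together, these changes let a caller drive Qwen2 with FlashMask indices end to end: `Qwen2ForCausalLM.forward` now accepts `attn_mask_startend_row_indices` and prefers it over a dense `attention_mask` (with a warning if both are given). A hedged usage sketch follows; the checkpoint name, the config override via `from_pretrained`, and the index shape/dtype are assumptions for illustration.

```python
# Usage sketch (assumptions: checkpoint name, use_flash_attention passed as a config
# override, and a [batch, 1, seq_len] int32 layout for the FlashMask indices).
import paddle
from paddlenlp.transformers import Qwen2ForCausalLM

model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", use_flash_attention=True)
model.eval()

# Two documents of lengths 3 and 4 packed into one 7-token sample.
input_ids = paddle.randint(0, model.config.vocab_size, shape=[1, 7])
indices = paddle.to_tensor([[[3, 3, 3, 7, 7, 7, 7]]], dtype="int32")

with paddle.no_grad():
    out = model(
        input_ids=input_ids,
        attn_mask_startend_row_indices=indices,  # no dense attention_mask needed
        return_dict=True,
    )
print(out.logits.shape)  # [1, 7, vocab_size]
```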
