Commit 092029f

fix npu sft ckpt load bug and no FA bug
1 parent 05acad5 commit 092029f

File tree: 3 files changed, +50 / -57 lines


llm/finetune_generation.py

Lines changed: 38 additions & 51 deletions
@@ -137,6 +137,43 @@ def main():
             weight_double_quant_block_size=model_args.weight_double_quant_block_size,
         )
 
+    model_config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        tensor_parallel_output=training_args.tensor_parallel_output,
+        tensor_parallel_degree=training_args.tensor_parallel_degree,
+        tensor_parallel_rank=training_args.tensor_parallel_rank,
+        dtype=dtype,
+        from_aistudio=model_args.from_aistudio,
+        quantization_config=quantization_config,
+    )
+    if hasattr(model_config, "use_flash_attention"):
+        model_config.use_flash_attention = model_args.use_flash_attention
+
+    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
+    model_config.recompute_granularity = model_args.recompute_granularity
+    model_config.virtual_pp_degree = model_args.virtual_pp_degree
+    model_config.sequence_parallel = model_args.sequence_parallel
+    model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    model_config.use_fused_rope = model_args.use_fused_rope
+
+    model_config.no_recompute_layers = model_args.no_recompute_layers
+    model_config.pp_recompute_interval = model_args.pp_recompute_interval
+    model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
+    model_config.use_recompute = training_args.recompute
+
+    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
+    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
+
+    # Config for model using dropout, such as GPT.
+    model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+
+    model_config.sep_parallel_degree = training_args.sep_parallel_degree
+    model_config.tensor_parallel_output = True
+    model_config.seq_length = data_args.max_length
+
     if training_args.pipeline_parallel_degree > 1:
         if data_args.eval_with_do_generation and training_args.do_eval:
             raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
@@ -145,63 +182,13 @@ def main():
         if not training_args.autotuner_benchmark:
             model = AutoModelForCausalLMPipe.from_pretrained(
                 model_args.model_name_or_path,
-                tensor_parallel_output=training_args.tensor_parallel_output,
-                tensor_parallel_degree=training_args.tensor_parallel_degree,
-                tensor_parallel_rank=training_args.tensor_parallel_rank,
-                use_flash_attention=model_args.use_flash_attention,
-                dtype=dtype,
+                config=model_config,
                 from_aistudio=model_args.from_aistudio,
-                quantization_config=quantization_config,
             )
         else:
             # NOTE(gongenlei): new add autotuner_benchmark
-            model_config = AutoConfig.from_pretrained(
-                model_args.model_name_or_path,
-                tensor_parallel_output=training_args.tensor_parallel_output,
-                tensor_parallel_degree=training_args.tensor_parallel_degree,
-                tensor_parallel_rank=training_args.tensor_parallel_rank,
-                dtype=dtype,
-                from_aistudio=model_args.from_aistudio,
-                quantization_config=quantization_config,
-            )
             model = AutoModelForCausalLMPipe.from_config(model_config, dtype=dtype)
     else:
-        model_config = AutoConfig.from_pretrained(
-            model_args.model_name_or_path,
-            tensor_parallel_output=training_args.tensor_parallel_output,
-            tensor_parallel_degree=training_args.tensor_parallel_degree,
-            tensor_parallel_rank=training_args.tensor_parallel_rank,
-            dtype=dtype,
-            from_aistudio=model_args.from_aistudio,
-            quantization_config=quantization_config,
-        )
-        if hasattr(model_config, "use_flash_attention"):
-            model_config.use_flash_attention = model_args.use_flash_attention
-
-        model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
-        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
-        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
-        model_config.recompute_granularity = model_args.recompute_granularity
-        model_config.virtual_pp_degree = model_args.virtual_pp_degree
-        model_config.sequence_parallel = model_args.sequence_parallel
-        model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
-        model_config.use_fused_rope = model_args.use_fused_rope
-
-        model_config.no_recompute_layers = model_args.no_recompute_layers
-        model_config.pp_recompute_interval = model_args.pp_recompute_interval
-        model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
-        model_config.use_recompute = training_args.recompute
-
-        model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
-        model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
-
-        # Config for model using dropout, such as GPT.
-        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
-        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
-
-        model_config.sep_parallel_degree = training_args.sep_parallel_degree
-        model_config.tensor_parallel_output = True
-        model_config.seq_length = data_args.max_length
         if not training_args.autotuner_benchmark:
             model = AutoModelForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
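
The refactor above builds a single AutoConfig before the branching, fills in all parallelism and fusion flags once, and hands it to from_pretrained via config=, so the pipeline-parallel, autotuner-benchmark, and plain paths all load the checkpoint with identical settings (this is what makes the NPU SFT checkpoint load consistent). A minimal sketch of that pattern, with a hypothetical checkpoint name and illustrative flag values rather than the script's own arguments:

# Minimal sketch of the "configure once, then load with config=" pattern used above.
# The checkpoint name and flag values are illustrative, not taken from the commit.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

model_name = "facebook/llama-7b"  # hypothetical example checkpoint
model_config = AutoConfig.from_pretrained(model_name, dtype="float16")
if hasattr(model_config, "use_flash_attention"):
    model_config.use_flash_attention = False  # e.g. disable FA where no kernel is available
model_config.tensor_parallel_output = True
model_config.seq_length = 2048

# Every loading branch now sees the same, fully populated config.
model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)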

paddlenlp/transformers/llama/modeling.py

Lines changed: 6 additions & 6 deletions
@@ -1381,9 +1381,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                         input_shape, past_key_values_length=past_key_values_length
                     )
                     if get_env_device() == "npu":
-                        expanded_attn_mask = expanded_attn_mask.astype("bool")
-                        combined_attention_mask = combined_attention_mask.astype("bool")
-                        expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+                        expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
+                    else:
+                        expanded_attn_mask = expanded_attn_mask & combined_attention_mask
             # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
             elif len(attention_mask.shape) == 3:
                 expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
@@ -1394,9 +1394,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
         if get_env_device() == "npu":
-            x = paddle.to_tensor(0.0, dtype="float16")
-            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
-            expanded_attn_mask = expanded_attn_mask.astype("float16")
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = expanded_attn_mask.astype("float32")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         elif get_env_device() == "xpu":
             x = paddle.to_tensor(0.0, dtype=dtype)
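
Both hunks above change only the NPU branch: the 2-D mask path now keeps the combined mask in bool, and the bool-to-float conversion goes through float32 instead of float16 before casting back to the model's compute dtype. A standalone sketch of that conversion, with the sequence length and compute dtype assumed for illustration:

# Sketch of the bool -> additive float mask conversion used in the NPU branch above.
# seq_len and the compute dtype are assumed, not taken from the commit.
import paddle

dtype = "float16"  # assumed compute dtype of the model
seq_len = 4

# Causal bool mask: True where attention is allowed, False elsewhere.
mask = paddle.tril(paddle.ones((seq_len, seq_len), dtype="bool"))

# Map True -> 0.0 and False -> the most negative finite value, built in float32 first
# (as the updated NPU branch does), then cast to the compute dtype.
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(paddle.float16).min, dtype="float32")
additive_mask = paddle.where(mask, x, y).astype(dtype)
print(additive_mask)  # 0.0 on and below the diagonal, -65504.0 above it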

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,7 @@
 from paddle.distributed.fleet.utils import recompute
 
 from paddlenlp.transformers.model_utils import PipelinePretrainedModel
+from paddlenlp.utils.tools import get_env_device
 
 from .modeling import (
     LlamaConfig,
@@ -153,6 +154,11 @@ def forward(self, args):
                 attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
             )
             attention_mask.stop_gradient = True
+            if get_env_device() == "npu":
+                attention_mask = attention_mask.astype("bool")
+        elif get_env_device() == "npu":
+            attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
+            attention_mask.stop_gradient = True
 
         if self.config.alibi and attention_mask is None:
             attention_mask = LlamaModel._prepare_decoder_attention_mask(
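
With indentation restored, the first added check sits inside the enclosing "if attention_mask is not None:" block, while the added elif handles the case where no mask was provided at all: on NPU the pipeline embedding layer then builds an explicit lower-triangular bool causal mask itself. A small sketch of that fallback, with seq_length assumed for illustration:

# Sketch of the NPU fallback added above: when no attention_mask reaches the
# pipeline embedding layer, construct a causal bool mask explicitly.
# seq_length is assumed, not taken from the commit.
import paddle

seq_length = 8
attention_mask = None  # e.g. the dataloader did not pass one through

if attention_mask is None:  # mirrors the added elif branch for NPU
    attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
    attention_mask.stop_gradient = True

print(attention_mask.shape)  # [8, 8]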

0 commit comments