89 changes: 38 additions & 51 deletions llm/finetune_generation.py
@@ -137,6 +137,43 @@ def main():
weight_double_quant_block_size=model_args.weight_double_quant_block_size,
)

model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention

model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
model_config.recompute_granularity = model_args.recompute_granularity
model_config.virtual_pp_degree = model_args.virtual_pp_degree
model_config.sequence_parallel = model_args.sequence_parallel
model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
model_config.use_fused_rope = model_args.use_fused_rope

model_config.no_recompute_layers = model_args.no_recompute_layers
model_config.pp_recompute_interval = model_args.pp_recompute_interval
model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
model_config.use_recompute = training_args.recompute

model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
model_config.tensor_parallel_rank = training_args.tensor_parallel_rank

# Config for models that use dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = training_args.tensor_parallel_output
model_config.seq_length = data_args.max_length

if training_args.pipeline_parallel_degree > 1:
if data_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.")
@@ -145,63 +182,13 @@ def main():
if not training_args.autotuner_benchmark:
model = AutoModelForCausalLMPipe.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
use_flash_attention=model_args.use_flash_attention,
dtype=dtype,
config=model_config,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
else:
# NOTE(gongenlei): newly added for autotuner_benchmark
model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
model = AutoModelForCausalLMPipe.from_config(model_config, dtype=dtype)
else:
model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=dtype,
from_aistudio=model_args.from_aistudio,
quantization_config=quantization_config,
)
if hasattr(model_config, "use_flash_attention"):
model_config.use_flash_attention = model_args.use_flash_attention

model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
model_config.recompute_granularity = model_args.recompute_granularity
model_config.virtual_pp_degree = model_args.virtual_pp_degree
model_config.sequence_parallel = model_args.sequence_parallel
model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
model_config.use_fused_rope = model_args.use_fused_rope

model_config.no_recompute_layers = model_args.no_recompute_layers
model_config.pp_recompute_interval = model_args.pp_recompute_interval
model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
model_config.use_recompute = training_args.recompute

model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
model_config.tensor_parallel_rank = training_args.tensor_parallel_rank

# Config for models that use dropout, such as GPT.
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

model_config.sep_parallel_degree = training_args.sep_parallel_degree
model_config.tensor_parallel_output = True
model_config.seq_length = data_args.max_length
if not training_args.autotuner_benchmark:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
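The net effect of this file's change is that the `AutoConfig` is built once, up front, and both the pipeline-parallel and plain branches consume the same `model_config` instead of re-creating it (and re-passing the tensor-parallel arguments) inside each branch. Below is a minimal sketch of the resulting control flow; the single `model_cls` dispatch is a simplification for illustration, and the exact keyword arguments of the collapsed `AutoModelForCausalLM.from_pretrained(...)` call are assumed rather than taken from the diff.

```python
# Sketch only: consolidated config construction followed by one branch on
# pipeline parallelism. Flag assignments are abbreviated; see the diff above
# for the full list copied onto model_config.
model_config = AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    tensor_parallel_output=training_args.tensor_parallel_output,
    tensor_parallel_degree=training_args.tensor_parallel_degree,
    tensor_parallel_rank=training_args.tensor_parallel_rank,
    dtype=dtype,
    from_aistudio=model_args.from_aistudio,
    quantization_config=quantization_config,
)
model_config.use_recompute = training_args.recompute  # ...plus the other flags shown above

# Pipeline-parallel runs use the pipeline model class, otherwise the plain one.
model_cls = AutoModelForCausalLMPipe if training_args.pipeline_parallel_degree > 1 else AutoModelForCausalLM

if not training_args.autotuner_benchmark:
    # Normal path: load pretrained weights with the shared config.
    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        config=model_config,
        from_aistudio=model_args.from_aistudio,
    )
else:
    # Autotuner benchmark path: build the architecture from config only.
    model = model_cls.from_config(model_config, dtype=dtype)
```

Keeping a single config object also makes settings such as `seq_length`, `sep_parallel_degree`, and the recompute options apply consistently whether or not pipeline parallelism is enabled.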
12 changes: 6 additions & 6 deletions paddlenlp/transformers/llama/modeling.py
@@ -1381,9 +1381,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
input_shape, past_key_values_length=past_key_values_length
)
if get_env_device() == "npu":
expanded_attn_mask = expanded_attn_mask.astype("bool")
combined_attention_mask = combined_attention_mask.astype("bool")
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
else:
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
@@ -1394,9 +1394,9 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
if get_env_device() == "npu":
x = paddle.to_tensor(0.0, dtype="float16")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = expanded_attn_mask.astype("float32")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
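Two NPU-specific adjustments to `_prepare_decoder_attention_mask` are visible in these hunks: the expanded and causal masks are cast to bool in a single expression before being combined, and the bool-to-additive conversion now uses float32 intermediates instead of float16, presumably so that `paddle.finfo(dtype).min` stays representable before the final cast back to `dtype` (it would overflow a float16 intermediate when the model dtype is float32 or bfloat16). The following is a standalone sketch of that conversion written as a hypothetical helper, mirroring the pattern in the diff rather than copying the file.

```python
import paddle

def additive_attention_mask(bool_mask, dtype):
    """Convert a boolean keep-mask into an additive attention bias of `dtype`.

    Hypothetical helper: float32 intermediates keep paddle.finfo(dtype).min
    from being clipped before the result is cast back to the model dtype.
    """
    keep = paddle.to_tensor(0.0, dtype="float32")
    drop = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    return paddle.where(bool_mask, keep, drop).astype(dtype)

# Example: additive causal bias for a sequence of length 8 in a bfloat16 model.
causal = paddle.tril(paddle.ones((1, 1, 8, 8), dtype="bool"))
bias = additive_attention_mask(causal, paddle.bfloat16)
```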
6 changes: 6 additions & 0 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -19,6 +19,7 @@
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.utils.tools import get_env_device

from .modeling import (
LlamaConfig,
@@ -153,6 +154,11 @@ def forward(self, args):
attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
)
attention_mask.stop_gradient = True
if get_env_device() == "npu":
attention_mask = attention_mask.astype("bool")
elif get_env_device() == "npu":
attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
attention_mask.stop_gradient = True

if self.config.alibi and attention_mask is None:
attention_mask = LlamaModel._prepare_decoder_attention_mask(
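For the pipeline-parallel path, the same device check is applied inside the embedding stage's `forward`. Read against the upstream file, the `elif get_env_device() == "npu":` is not a duplicate condition: it pairs with an enclosing `if attention_mask is not None:` that sits just above the hunk context, which the flattened diff rendering hides. A plausible reconstruction of the surrounding logic, with the enclosing condition and indentation inferred rather than shown in the diff:

```python
# Inferred structure around this hunk (not a verbatim copy of modeling_pp.py).
if attention_mask is not None:
    attention_mask = LlamaModel._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
    )
    attention_mask.stop_gradient = True
    if get_env_device() == "npu":
        # NPU kernels expect a boolean mask at this point.
        attention_mask = attention_mask.astype("bool")
elif get_env_device() == "npu":
    # No mask was passed in: build an explicit lower-triangular causal bool mask.
    attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
    attention_mask.stop_gradient = True
```

This keeps the NPU behaviour of the pipeline model consistent with the non-pipeline model when no attention mask is provided, using the `get_env_device` import added at the top of the file.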