Commit ffc385f
[Trainer] update sequence parallel (PaddlePaddle#9757)
* update emb doc
* update register_sequence_parallel_allreduce_hooks
* update fuse_sequence_parallel_allreduce
1 parent ef6a491 · commit ffc385f

File tree: 8 files changed (+34 / −47 lines)

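In short: the sequence-parallel switches move out of each script's ModelArguments and into the shared TrainingArguments, and the Trainer now registers the allreduce hooks itself. A minimal sketch of what user code looks like after this change (output_dir and the model id are placeholders, not part of the commit):

    # Sketch only; assumes a PaddleNLP build that already contains this commit.
    from paddlenlp.trainer import TrainingArguments
    from paddlenlp.transformers import AutoConfig

    training_args = TrainingArguments(
        output_dir="./checkpoints",               # placeholder
        sequence_parallel=True,                   # new TrainingArguments field (this commit)
        fuse_sequence_parallel_allreduce=True,    # new TrainingArguments field (this commit)
    )

    config = AutoConfig.from_pretrained("some-org/some-llama-model")  # hypothetical model id
    # Scripts now copy the flags from training_args instead of model_args:
    config.sequence_parallel = training_args.sequence_parallel
    config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce
    # No explicit register_sequence_parallel_allreduce_hooks(...) call is needed here;
    # Trainer.__init__ registers the hooks when args.sequence_parallel is set.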

llm/alignment/dpo/run_dpo.py

Lines changed: 0 additions & 5 deletions
@@ -41,7 +41,6 @@
     AutoTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.trl import (
     DPOTrainer,
@@ -147,10 +146,6 @@ def main():
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
 
-    if model_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
-        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:

llm/auto_parallel/gpt-3/run_pretrain_auto.py

Lines changed: 2 additions & 10 deletions
@@ -236,14 +236,6 @@ class ModelArguments:
     hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
     attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})
 
-    sequence_parallel: bool = field(
-        default=False,
-        metadata={"help": "whether to use sequence parallel"},
-    )
-    fuse_sequence_parallel_allreduce: bool = field(
-        default=False,
-        metadata={"help": "whether to use fuse sequence parallel allreduce"},
-    )
     use_fused_rope: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable rope fusion or not."},
@@ -512,8 +504,8 @@ def main():
     config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
     config.virtual_pp_degree = model_args.virtual_pp_degree
-    config.sequence_parallel = model_args.sequence_parallel
-    config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    config.sequence_parallel = training_args.sequence_parallel
+    config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce
     config.use_fused_rope = model_args.use_fused_rope
     config.no_recompute_layers = model_args.no_recompute_layers
     config.pp_recompute_interval = model_args.pp_recompute_interval

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 2 additions & 10 deletions
@@ -255,14 +255,6 @@ class ModelArguments:
             "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
         },
     )
-    sequence_parallel: bool = field(
-        default=False,
-        metadata={"help": "whether to use sequence parallel"},
-    )
-    fuse_sequence_parallel_allreduce: bool = field(
-        default=False,
-        metadata={"help": "whether to use fuse sequence parallel allreduce"},
-    )
     use_fused_rope: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable rope fusion or not."},
@@ -568,8 +560,8 @@ def main():
     config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
     config.virtual_pp_degree = model_args.virtual_pp_degree
-    config.sequence_parallel = model_args.sequence_parallel
-    config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    config.sequence_parallel = training_args.sequence_parallel
+    config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce
     config.use_fused_rope = model_args.use_fused_rope
     config.no_recompute_layers = model_args.no_recompute_layers
     config.pp_recompute_interval = model_args.pp_recompute_interval

llm/auto_parallel/qwen/run_pretrain_3D_auto.py

Lines changed: 2 additions & 10 deletions
@@ -239,14 +239,6 @@ class ModelArguments:
             "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
         },
     )
-    sequence_parallel: bool = field(
-        default=False,
-        metadata={"help": "whether to use sequence parallel"},
-    )
-    fuse_sequence_parallel_allreduce: bool = field(
-        default=False,
-        metadata={"help": "whether to use fuse sequence parallel allreduce"},
-    )
     use_fused_rope: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable rope fusion or not."},
@@ -524,8 +516,8 @@ def main():
     config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
     config.virtual_pp_degree = model_args.virtual_pp_degree
-    config.sequence_parallel = model_args.sequence_parallel
-    config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    config.sequence_parallel = training_args.sequence_parallel
+    config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce
     config.use_fused_rope = model_args.use_fused_rope
     config.no_recompute_layers = model_args.no_recompute_layers
     config.pp_recompute_interval = model_args.pp_recompute_interval

llm/run_finetune.py

Lines changed: 2 additions & 5 deletions
@@ -52,7 +52,6 @@
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
     LlamaTokenizer,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.llm_utils import (
@@ -110,6 +109,7 @@ def main():
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
             LinearConfig.enable_accumulate_steps_opt()
             LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
         except ImportError:
@@ -202,10 +202,7 @@ def neft_post_hook(module, input, output):
             neft_post_hook_handle = model.get_input_embeddings().register_forward_post_hook(neft_post_hook)
         else:
             raise NotImplementedError("Only support neftune for model with get_input_embeddings")
-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
+
     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
     # init chat_template for tokenizer

llm/run_pretrain.py

Lines changed: 0 additions & 6 deletions
@@ -41,7 +41,6 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
@@ -500,11 +499,6 @@ def main():
     else:
         model = model_class.from_config(config, dtype=dtype)
 
-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
-
     if training_args.recompute:
         model.recompute_enable()
paddlenlp/trainer/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -87,6 +87,12 @@
     from ..quantization.quantization_linear import QuantizationLinear
 except:
     QuantizationLinear = None
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
 from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 from ..transformers.model_utils import (
     PretrainedModel,
@@ -408,6 +414,11 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                 "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
             )
 
+        if args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
+            )
+
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
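The guarded import above keeps the trainer importable on Paddle builds that lack sequence_parallel_utils, and the hooks are only registered when sequence parallel is requested. A condensed sketch of the same pattern (the helper name and the explicit error are illustrative; the actual commit registers the hooks inline in Trainer.__init__ and silently skips the failed import):

    try:
        from paddle.distributed.fleet.utils.sequence_parallel_utils import (
            register_sequence_parallel_allreduce_hooks,
        )
    except ImportError:  # older Paddle without sequence_parallel_utils (assumption)
        register_sequence_parallel_allreduce_hooks = None


    def maybe_register_sp_hooks(model, args):
        """Hypothetical helper mirroring the new Trainer.__init__ logic."""
        if getattr(args, "sequence_parallel", False):
            if register_sequence_parallel_allreduce_hooks is None:
                raise RuntimeError("sequence_parallel requested, but this Paddle build lacks the utility")
            register_sequence_parallel_allreduce_hooks(
                model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
            )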

paddlenlp/trainer/training_args.py

Lines changed: 15 additions & 1 deletion
@@ -632,6 +632,13 @@ class TrainingArguments:
             )
         },
     )
+    sequence_parallel: bool = field(
+        default=False,
+        metadata={"help": "Whether to enable sequence parallel."},
+    )
+    fuse_sequence_parallel_allreduce: bool = field(
+        default=False, metadata={"help": "Whether to use fuse sequence parallel allreduce."}
+    )
     sequence_parallel_config: str = field(
         default="",
         metadata={
@@ -1116,10 +1123,17 @@ def __post_init__(self):
                     f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
                 )
 
+            enable_partial_send_recv = "disable_partial_send_recv" not in pipeline_parallel_config
+            if self.sequence_parallel and enable_partial_send_recv:
+                logger.warning(
+                    "When use pipeline parallel and sequence parallel simultaneously, we should turn off partial send recv."
+                )
+                enable_partial_send_recv = False
+
             strategy.pipeline_configs = {
                 "accumulate_steps": self.gradient_accumulation_steps,
                 "micro_batch_size": self.per_device_train_batch_size,
-                "enable_partial_send_recv": "disable_partial_send_recv" not in pipeline_parallel_config,
+                "enable_partial_send_recv": enable_partial_send_recv,
                 "p2p_cache_shape": False if "disable_p2p_cache_shape" in pipeline_parallel_config else True,
                 # "delay_scale_loss": True, Fix ME
             }
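One behavioral note from the second hunk: when pipeline parallel and sequence parallel are combined, __post_init__ now turns off partial send/recv itself and only logs a warning, so users no longer have to pass disable_partial_send_recv explicitly. A hedged construction example (degrees and paths are placeholders; building the fleet strategy requires an actual distributed launch):

    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",        # placeholder
        tensor_parallel_degree=2,          # placeholder degrees; must match the launched ranks
        pipeline_parallel_degree=2,
        sequence_parallel=True,            # new field from this commit
        fuse_sequence_parallel_allreduce=True,
        pipeline_parallel_config="",       # no need to add "disable_partial_send_recv" manually:
                                           # __post_init__ forces enable_partial_send_recv=False
                                           # and emits the warning shown in the diff above.
    )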
