
Commit 1a21cd5 (parent: 37b1a42)

update register_sequence_parallel_allreduce_hooks
File tree: 3 files changed (+11, -11 lines)

  llm/run_finetune.py
  llm/run_pretrain.py
  paddlenlp/trainer/trainer.py

llm/run_finetune.py (0 additions, 5 deletions)

@@ -58,7 +58,6 @@
     LlamaTokenizer,
     Qwen2ForCausalLM,
     Qwen2ForCausalLMPipe,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer

@@ -231,10 +230,6 @@ def neft_post_hook(module, input, output):
         else:
             raise NotImplementedError("Only support neftune for model with get_input_embeddings")

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
     reft_layers = None

llm/run_pretrain.py (0 additions, 6 deletions)

@@ -41,7 +41,6 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler

@@ -492,11 +491,6 @@ def main():
     else:
         model = model_class.from_config(config, dtype=dtype)

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
-
     if training_args.recompute:
         model.recompute_enable()

paddlenlp/trainer/trainer.py (11 additions, 0 deletions)

@@ -87,6 +87,12 @@
     from ..quantization.quantization_linear import QuantizationLinear
 except:
     QuantizationLinear = None
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
 from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 from ..transformers.model_utils import (
     PretrainedModel,

@@ -428,6 +434,11 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )

+        if args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
+            )
+
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:

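Taken together, the commit moves sequence-parallel hook registration out of the run scripts and into Trainer.__init__. The sketch below restates that pattern as a standalone snippet, assuming only what the diff shows: the guarded import from paddle.distributed.fleet.utils.sequence_parallel_utils and an args object exposing sequence_parallel, gradient_accumulation_steps, and fuse_sequence_parallel_allreduce. The helper name maybe_register_sp_hooks is hypothetical and used only for illustration.

# Sketch only: mirrors the guarded import and registration this commit adds
# to paddlenlp/trainer/trainer.py.
try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        register_sequence_parallel_allreduce_hooks,
    )
except ImportError:  # older Paddle builds may not ship this utility
    register_sequence_parallel_allreduce_hooks = None


def maybe_register_sp_hooks(model, args):
    # `args` is assumed to carry the same fields the Trainer reads in this diff.
    if getattr(args, "sequence_parallel", False) and register_sequence_parallel_allreduce_hooks is not None:
        register_sequence_parallel_allreduce_hooks(
            model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
        )

With this in place, llm/run_finetune.py and llm/run_pretrain.py no longer need their own import or call; enabling sequence_parallel in the training arguments is enough, since the Trainer performs the registration once during initialization.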