
Commit 934a757

update register_sequence_parallel_allreduce_hooks
1 parent 37b1a42 commit 934a757

File tree

6 files changed: +12 −30 lines changed


llm/alignment/dpo/run_dpo.py

Lines changed: 0 additions & 5 deletions
@@ -43,7 +43,6 @@
     LlamaForCausalLMPipe,
     Qwen2ForCausalLM,
     Qwen2ForCausalLMPipe,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.trl import (
@@ -154,10 +153,6 @@ def main():
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:

llm/alignment/kto/run_kto.py

Lines changed: 0 additions & 5 deletions
@@ -38,7 +38,6 @@
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
     Qwen2ForCausalLM,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.trl import (
@@ -140,10 +139,6 @@ def main():
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:

llm/alignment/rm/flashmask/run_reward.py

Lines changed: 1 addition & 9 deletions
@@ -35,11 +35,7 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    register_sequence_parallel_allreduce_hooks,
-)
+from paddlenlp.transformers import AutoConfig, AutoTokenizer
 from paddlenlp.utils.log import logger


@@ -126,10 +122,6 @@ def main():
         logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True

-    if model_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:

llm/run_finetune.py

Lines changed: 0 additions & 5 deletions
@@ -58,7 +58,6 @@
     LlamaTokenizer,
     Qwen2ForCausalLM,
     Qwen2ForCausalLMPipe,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
@@ -231,10 +230,6 @@ def neft_post_hook(module, input, output):
     else:
         raise NotImplementedError("Only support neftune for model with get_input_embeddings")

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
     reft_layers = None

llm/run_pretrain.py

Lines changed: 0 additions & 6 deletions
@@ -41,7 +41,6 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
@@ -492,11 +491,6 @@ def main():
     else:
         model = model_class.from_config(config, dtype=dtype)

-    if training_args.sequence_parallel:
-        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
-        )
-
     if training_args.recompute:
         model.recompute_enable()

paddlenlp/trainer/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -87,6 +87,12 @@
     from ..quantization.quantization_linear import QuantizationLinear
 except:
     QuantizationLinear = None
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
 from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 from ..transformers.model_utils import (
     PretrainedModel,
@@ -428,6 +434,11 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )

+        if args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
+            )
+
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
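
Usage note: after this commit, enabling sequence parallelism through the training arguments is enough; Trainer.__init__ registers the allreduce hooks itself, and the run scripts no longer call register_sequence_parallel_allreduce_hooks directly. The snippet below is a minimal sketch of that flow, not code from the commit: the checkpoint path and the way the arguments are constructed here are illustrative, while sequence_parallel, gradient_accumulation_steps and fuse_sequence_parallel_allreduce are the fields the Trainer actually reads in the hunk above.

# Minimal sketch (illustrative, not part of the commit).
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

training_args = TrainingArguments(
    output_dir="./checkpoints",
    gradient_accumulation_steps=4,
    tensor_parallel_degree=2,                # sequence parallel is normally used together with tensor parallelism
    sequence_parallel=True,                  # Trainer.__init__ registers the hooks when this is set
    fuse_sequence_parallel_allreduce=False,  # forwarded unchanged to register_sequence_parallel_allreduce_hooks
)

model = AutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("path/to/model")     # placeholder checkpoint

# No explicit register_sequence_parallel_allreduce_hooks(...) call in the script any more:
# Trainer(...) performs it internally when training_args.sequence_parallel is True.
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)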

0 commit comments
