25 changes: 23 additions & 2 deletions paddlenlp/trainer/trainer.py
@@ -829,6 +829,13 @@ def train(
# sharding
# stage1. the same as ddp
# stage2. manually collect gradient on dp group

dp_master_grad = (
self.args.world_size > 1 and self.args.amp_master_grad and not self.args.use_hybrid_parallel
)
if dp_master_grad:
is_no_sync = True

if is_no_sync:
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
@@ -884,6 +891,9 @@ def train(
self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)

# pipeline parallel mode, handle gradient merge here
if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss:
for p in model._layers.parameters():
@@ -1536,14 +1546,25 @@ def _wrap_model(self, model, training=True):
else:
model, self.optimizer = decorated

if self.args.world_size == 1:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Multi-gpu training
if self.args.world_size > 1 and not self.args.use_hybrid_parallel:
model = paddle.DataParallel(model)
# Distributed training (should be after fp16 initialization)

if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
Contributor:

The model here has already been wrapped by a DataParallel layer. Please confirm whether wrapping it again with MixPrecisionLayer causes any problems.

Contributor (PR author):

Wrapping it once more here is fine.

assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
in_sharding_parallel_mode = self.sharding is not None
in_tensor_parallel_model = self.args.tensor_parallel_degree > 1
in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1

# Pipeline mode
if in_pipeline_parallel_mode:
@@ -1669,7 +1690,7 @@ def get_expected_keys(inputs, keys):
self.optimizer = optimizer

# pure tensor parallel mode, no pipeline_parallel, no sharding.
if not in_pipeline_parallel_mode and not in_sharding_parallel_mode and in_tensor_parallel_model:
if not in_pipeline_parallel_mode and not in_sharding_parallel_mode and in_tensor_parallel_mode:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use

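For readers following the new dp_master_grad path above, here is a minimal, hedged sketch of how the pieces added in this file fit together in a plain data-parallel run: decorate for AMP, wrap with paddle.DataParallel, then apply MixPrecisionLayer/MixPrecisionOptimizer (the double wrapping the reviewer asked about), and finally run backward under no_sync() followed by a manual fused_allreduce_gradients. The toy Linear model, the bf16 dtype, and the decorate/auto_cast settings are illustrative assumptions, not part of the PR.

```python
# Sketch only (not the PR itself): pure data-parallel training with
# amp_master_grad, mirroring the trainer.py changes above.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils import mix_precision_utils
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

dist.init_parallel_env()  # plain data parallel, no hybrid parallel

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level="O2", dtype="bfloat16"
)

# DataParallel wraps the decorated model first (world_size > 1).
model = paddle.DataParallel(model)

# MixPrecisionLayer attaches FP32 master-grad hooks onto the already-wrapped
# model (per the author's reply above, wrapping the DataParallel model is fine);
# MixPrecisionOptimizer then steps on those master gradients.
mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)

# Master gradients live outside DataParallel's built-in all-reduce, so the
# backward pass runs under no_sync() and gradients are all-reduced manually,
# as in the new dp_master_grad branch above.
x = paddle.randn([4, 16])
with model.no_sync():
    with paddle.amp.auto_cast(level="O2", dtype="bfloat16"):
        loss = model(x).mean()
    loss.backward()
fused_allreduce_gradients(list(model.parameters()), None)
optimizer.step()
optimizer.clear_grad()
```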
9 changes: 0 additions & 9 deletions paddlenlp/trainer/training_args.py
@@ -806,15 +806,6 @@ def __post_init__(self):
self.distributed_dataloader = False

if self.amp_master_grad:
if (
self.pipeline_parallel_degree <= 1
and self.tensor_parallel_degree <= 1
and (not self.sharding or ShardingOption.FULL_SHARD in self.sharding)
):
raise ValueError(
"Temporarily amp master grad only support for tensor/pipeline/sharding"
" (stage 1 and stage 2) parallel. Please set amp_master_grad to False."
)
if not (self.bf16 or self.fp16):
logger.warning("set amp_master_grad to false since amp is disabled.")
self.amp_master_grad = False
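With the check above removed, amp_master_grad is no longer rejected when only data parallelism (or a single card) is in use; only the bf16/fp16 requirement remains. As a rough illustration (assuming PaddleNLP's TrainingArguments fields referenced in this hunk, with a placeholder output_dir), a configuration like the following would previously have raised the ValueError and now passes __post_init__:

```python
# Illustrative configuration only; output_dir is a placeholder path.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    bf16=True,             # amp is enabled, so amp_master_grad is kept
    amp_master_grad=True,  # previously required tensor/pipeline/sharding parallel
    # tensor_parallel_degree / pipeline_parallel_degree / sharding left at
    # their defaults, i.e. no hybrid parallelism.
)
```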