paddlenlp/trainer/trainer.py (26 changes: 21 additions & 5 deletions)
@@ -132,6 +132,10 @@ def paddlenlp_load(path, return_numpy=False):
return paddle.load(path, return_numpy=return_numpy)


def is_dp_group_support_in_group_sharded_parallel():
return "dp_group" in set(inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys())


__all__ = ["Trainer"]
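The helper above probes the installed Paddle's `group_sharded_parallel` signature at runtime rather than pinning a framework version. A minimal, self-contained sketch of the same feature-detection pattern is shown below; `accepts_kwarg` and `fake_group_sharded_parallel` are hypothetical names used only for illustration, not part of this PR:

```python
import inspect

def accepts_kwarg(func, name):
    # True if `func` declares a parameter with this exact name.
    return name in inspect.signature(func).parameters

# Hypothetical stand-in for paddle.distributed.sharding.group_sharded_parallel.
def fake_group_sharded_parallel(model, optimizer, level, scaler=None, group=None, dp_group=None):
    pass

print(accepts_kwarg(fake_group_sharded_parallel, "dp_group"))  # True -> pass dp_group through
print(accepts_kwarg(fake_group_sharded_parallel, "offload"))   # False -> fall back to the old path
```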


@@ -671,13 +675,12 @@ def train(
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):

# Maunally collect gradients
# Manually collect gradients when group_sharded_parallel can't accept dp_group
# Case 1: Use sharding stage 2/3 with dp
# Case 2: Use recompute and dp
# local_rank != -1 doesn't necessarily mean dp is used in the network.
if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
if self.args.dp_degree > 1:
if self.args.dp_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
if ShardingOption.FULL_SHARD in self.args.sharding:
# Why is a sync on params needed again?
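For context, the manual allreduce above is only needed when a data-parallel group exists outside the sharding group and the installed `group_sharded_parallel` cannot be told about it. A plain-Python sketch of that decision follows; the predicate name and its arguments are illustrative, not the Trainer's actual API:

```python
def needs_manual_dp_allreduce(uses_sharding, shard_op_only, dp_degree, dp_group_supported):
    # Sharding stage 2/3 reduces gradients only inside the sharding group, so a
    # separate dp group must be allreduced by hand (fused_allreduce_gradients)
    # unless group_sharded_parallel accepts dp_group and handles it internally.
    return uses_sharding and not shard_op_only and dp_degree > 1 and not dp_group_supported

print(needs_manual_dp_allreduce(True, False, 2, False))  # True  -> collect gradients manually
print(needs_manual_dp_allreduce(True, False, 2, True))   # False -> framework reduces across dp
```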
@@ -1199,7 +1202,7 @@ def _wrap_model(self, model, training=True):
self.optimizer = fleet.distributed_optimizer(self.optimizer)
else:
# sync params (broadcast) buffers in dp group
if self.args.dp_degree > 1:
if not is_dp_group_support_in_group_sharded_parallel() and self.args.dp_degree > 1:
try:
from paddle.fluid.dygraph.parallel import sync_params_buffers
except ImportError:
@@ -1220,8 +1223,21 @@

from paddle.distributed.sharding import group_sharded_parallel

# add dp_group and exclude_layer params
# https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
extra_kwargs = {}
if is_dp_group_support_in_group_sharded_parallel():
extra_kwargs["dp_group"] = self.dp_group
extra_kwargs["exclude_layer"] = ["GroupNorm"]

model, optimizer, _ = group_sharded_parallel(
model, self.optimizer, level=level, scaler=None, group=self.sharding_group, offload=cpu_offload
model,
self.optimizer,
level=level,
scaler=None,
group=self.sharding_group,
offload=cpu_offload,
**extra_kwargs,
)
self.optimizer = optimizer
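The conditional `extra_kwargs` above acts as a version-compatibility shim. A generalized, hypothetical variant of the same idea (not part of this PR) drops any keyword the callee does not declare:

```python
import inspect

def call_with_supported_kwargs(func, *args, **kwargs):
    # Keep only keyword arguments that appear in the callee's signature, so newer
    # options (e.g. dp_group, exclude_layer) degrade gracefully on older framework
    # versions. Assumes `func` does not itself take **kwargs.
    supported = inspect.signature(func).parameters
    filtered = {k: v for k, v in kwargs.items() if k in supported}
    return func(*args, **filtered)

# Usage sketch with a dummy callee:
def dummy_wrap(model, optimizer, level, dp_group=None):
    return model, optimizer

call_with_supported_kwargs(dummy_wrap, "model", "opt", level="p_g_os", dp_group="dp", exclude_layer=["GroupNorm"])
```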
