Commit cd0afab

[Trainer] add dp_group and exclude_layer params (#4930)

* add dp_group and exclude_layer params
* update
* update

1 parent b5c4840 · commit cd0afab

File tree

1 file changed (+21, -5)
paddlenlp/trainer/trainer.py

Lines changed: 21 additions & 5 deletions
@@ -132,6 +132,10 @@ def paddlenlp_load(path, return_numpy=False):
         return paddle.load(path, return_numpy=return_numpy)
 
 
+def is_dp_group_support_in_group_sharded_parallel():
+    return "dp_group" in set(inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys())
+
+
 __all__ = ["Trainer"]
 
 
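The added helper is a signature-based feature check: it returns True only when the installed Paddle's group_sharded_parallel accepts a dp_group argument. Below is a minimal, self-contained sketch of the same inspect.signature pattern; the group_sharded_parallel defined in the sketch is a hypothetical stand-in (matching the newer signature) so the snippet runs without Paddle installed.

import inspect

# Hypothetical stand-in for paddle.distributed.sharding.group_sharded_parallel,
# used only so this sketch runs without Paddle installed.
def group_sharded_parallel(model, optimizer, level, scaler=None, group=None,
                           offload=False, dp_group=None, exclude_layer=None):
    return model, optimizer, scaler

def accepts_param(func, name):
    # True if `func` has a parameter called `name`.
    return name in inspect.signature(func).parameters

print(accepts_param(group_sharded_parallel, "dp_group"))      # True
print(accepts_param(group_sharded_parallel, "not_a_param"))   # False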
@@ -671,13 +675,12 @@ def train(
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
-
-                    # Manually collect gradients
+                    # Manually collect gradients when group_sharded_parallel can't accept dp_group
                     # Case 1: Use sharding stage 2/3 with dp
                     # Case 2: Use recompute and dp
                     # local_rank != -1 doesn't mean dp in networks.
                     if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
-                        if self.args.dp_degree > 1:
+                        if self.args.dp_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
                             fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
                     if ShardingOption.FULL_SHARD in self.args.sharding:
                         # Why need sync on param again ?
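In words: with sharding stage 2/3 plus data parallelism, gradients still have to be all-reduced across the dp group by hand unless group_sharded_parallel can be handed the dp group directly. A tiny sketch of that decision, with hypothetical stand-in flags in place of the Trainer's real attributes:

# `dp_degree`, `uses_sharding_stage2_or_3`, and `api_accepts_dp_group` are
# hypothetical stand-ins for self.args.dp_degree, the ShardingOption check,
# and the is_dp_group_support_in_group_sharded_parallel() result.
def needs_manual_dp_allreduce(dp_degree, uses_sharding_stage2_or_3, api_accepts_dp_group):
    # Manual all-reduce over the dp group is only required when stage 2/3
    # sharding is active, data parallelism is real (dp_degree > 1), and the
    # sharding wrapper cannot be told about the dp group itself.
    return uses_sharding_stage2_or_3 and dp_degree > 1 and not api_accepts_dp_group

print(needs_manual_dp_allreduce(2, True, False))  # True: fall back to fused_allreduce_gradients
print(needs_manual_dp_allreduce(2, True, True))   # False: group_sharded_parallel handles the dp group
print(needs_manual_dp_allreduce(1, True, False))  # False: no data parallelism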
@@ -1199,7 +1202,7 @@ def _wrap_model(self, model, training=True):
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
             else:
                 # sync params (broadcast) buffers in dp group
-                if self.args.dp_degree > 1:
+                if not is_dp_group_support_in_group_sharded_parallel() and self.args.dp_degree > 1:
                     try:
                         from paddle.fluid.dygraph.parallel import sync_params_buffers
                     except ImportError:
@@ -1220,8 +1223,21 @@ def _wrap_model(self, model, training=True):
 
                 from paddle.distributed.sharding import group_sharded_parallel
 
+                # add dp_group and exclude_layer params
+                # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
+                extra_kwargs = {}
+                if is_dp_group_support_in_group_sharded_parallel():
+                    extra_kwargs["dp_group"] = self.dp_group
+                    extra_kwargs["exclude_layer"] = ["GroupNorm"]
+
                 model, optimizer, _ = group_sharded_parallel(
-                    model, self.optimizer, level=level, scaler=None, group=self.sharding_group, offload=cpu_offload
+                    model,
+                    self.optimizer,
+                    level=level,
+                    scaler=None,
+                    group=self.sharding_group,
+                    offload=cpu_offload,
+                    **extra_kwargs,
                 )
                 self.optimizer = optimizer
 
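The extra_kwargs dict lets one call site work against both older and newer Paddle releases: the optional arguments are only passed when the installed API accepts them. A small sketch of that conditional-kwargs pattern, using hypothetical wrap_old/wrap_new stand-ins rather than the real Paddle API:

import inspect

# Hypothetical stand-ins for an "old" and a "new" wrapper API, mirroring
# group_sharded_parallel before and after dp_group/exclude_layer were added.
# Only names defined in this snippet are used; values are placeholders.
def wrap_old(model, optimizer, level, scaler=None, group=None, offload=False):
    return model, optimizer, scaler

def wrap_new(model, optimizer, level, scaler=None, group=None, offload=False,
             dp_group=None, exclude_layer=None):
    return model, optimizer, scaler

def wrap_with_optional_kwargs(wrap_fn, model, optimizer):
    # Build keyword arguments only for parameters the target actually accepts,
    # then splat them in, as the commit does with **extra_kwargs.
    extra_kwargs = {}
    if "dp_group" in inspect.signature(wrap_fn).parameters:
        extra_kwargs["dp_group"] = "dp_group_placeholder"
        extra_kwargs["exclude_layer"] = ["GroupNorm"]
    return wrap_fn(model, optimizer, level="os_g", **extra_kwargs)

print(wrap_with_optional_kwargs(wrap_old, "model", "optimizer"))
print(wrap_with_optional_kwargs(wrap_new, "model", "optimizer"))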
