@@ -132,6 +132,10 @@ def paddlenlp_load(path, return_numpy=False):
     return paddle.load(path, return_numpy=return_numpy)


+def is_dp_group_support_in_group_sharded_parallel():
+    return "dp_group" in set(inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys())
+
+
 __all__ = ["Trainer"]

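The helper added above is a signature-based feature check: it inspects group_sharded_parallel and only reports support when the installed Paddle exposes a dp_group parameter. Below is a minimal, self-contained sketch of that pattern; train_step and supports_kwarg are illustrative names, not part of the PR.

import inspect

def train_step(model, batch, dp_group=None):
    """Hypothetical stand-in for paddle.distributed.sharding.group_sharded_parallel."""
    return model

def supports_kwarg(func, name):
    # Same idea as is_dp_group_support_in_group_sharded_parallel(): look the
    # keyword up in the callable's signature instead of parsing version strings.
    return name in inspect.signature(func).parameters

print(supports_kwarg(train_step, "dp_group"))   # True
print(supports_kwarg(train_step, "mp_group"))   # False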
@@ -671,13 +675,12 @@ def train(
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
-
-                    # Maunally collect gradients
+                    # Manually collect gradients when group_sharded_parallel can't accept dp_group
                     # Case 1: Use sharding stage 2/3 with dp
                     # Case 2: Use recompute and dp
                     # local_rank != -1 don't means dp in networks.
                     if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
-                        if self.args.dp_degree > 1:
+                        if self.args.dp_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
                             fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
                             if ShardingOption.FULL_SHARD in self.args.sharding:
                                 # Why need sync on parm again ?
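The comments in this hunk encode a small decision rule: the manual fused_allreduce_gradients call is only needed when sharding stage 2/3 runs together with data parallelism and the installed group_sharded_parallel cannot take dp_group itself. A hedged sketch of that rule with illustrative names only (no Paddle calls):

def needs_manual_grad_allreduce(sharding_stage, dp_degree, dp_group_supported):
    # Only sharding stage 2/3 (i.e. not SHARD_OP / stage 1) leaves gradients
    # un-reduced across the data-parallel group.
    if sharding_stage not in ("stage2", "stage3"):
        return False
    # No data-parallel replicas, nothing to all-reduce.
    if dp_degree <= 1:
        return False
    # If group_sharded_parallel accepts dp_group, Paddle performs the dp
    # all-reduce internally and the manual call would be redundant.
    return not dp_group_supported

assert needs_manual_grad_allreduce("stage2", 2, dp_group_supported=False)
assert not needs_manual_grad_allreduce("stage2", 2, dp_group_supported=True)
assert not needs_manual_grad_allreduce("stage1", 2, dp_group_supported=False)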
@@ -1199,7 +1202,7 @@ def _wrap_model(self, model, training=True):
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
             else:
                 # sync params (broadcast) buffers in dp group
-                if self.args.dp_degree > 1:
+                if not is_dp_group_support_in_group_sharded_parallel() and self.args.dp_degree > 1:
                     try:
                         from paddle.fluid.dygraph.parallel import sync_params_buffers
                     except ImportError:
@@ -1220,8 +1223,21 @@ def _wrap_model(self, model, training=True):

                 from paddle.distributed.sharding import group_sharded_parallel

+                # add dp_group and exclude_layer params
+                # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
+                extra_kwargs = {}
+                if is_dp_group_support_in_group_sharded_parallel():
+                    extra_kwargs["dp_group"] = self.dp_group
+                    extra_kwargs["exclude_layer"] = ["GroupNorm"]
+
                 model, optimizer, _ = group_sharded_parallel(
-                    model, self.optimizer, level=level, scaler=None, group=self.sharding_group, offload=cpu_offload
+                    model,
+                    self.optimizer,
+                    level=level,
+                    scaler=None,
+                    group=self.sharding_group,
+                    offload=cpu_offload,
+                    **extra_kwargs,
                 )
                 self.optimizer = optimizer

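The extra_kwargs dict keeps the call backward compatible: the new arguments are only splatted into group_sharded_parallel when the feature check passes, so older Paddle releases see the original call shape. A small sketch of that conditional-kwargs pattern; wrap_model and build_extra_kwargs are hypothetical stand-ins, not the real API.

def wrap_model(model, optimizer, level, scaler=None, group=None, offload=False,
               dp_group=None, exclude_layer=None):
    """Hypothetical stand-in for group_sharded_parallel."""
    return model, optimizer, scaler

def build_extra_kwargs(dp_group_supported, dp_group):
    # Collect optional arguments in a dict so the call site stays valid on
    # Paddle versions whose group_sharded_parallel lacks these parameters.
    extra_kwargs = {}
    if dp_group_supported:
        extra_kwargs["dp_group"] = dp_group
        extra_kwargs["exclude_layer"] = ["GroupNorm"]
    return extra_kwargs

extra_kwargs = build_extra_kwargs(dp_group_supported=True, dp_group="dp-group-handle")
model, optimizer, _ = wrap_model("model", "optimizer", level="os_g", **extra_kwargs)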