
Commit bc23b8b

haohongxiang and gongenlei authored
[Dygraph] Support sharding stage2/3+dp in GPT-3 model (#2471)
* add sharding+dp
* update
* code style check

Co-authored-by: gongenlei <[email protected]>
1 parent 2cfeadf commit bc23b8b

File tree

1 file changed: +20 -2 lines changed


examples/language_model/gpt-3/dygraph/run_pretrain.py

Lines changed: 20 additions & 2 deletions
@@ -37,6 +37,8 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
+from paddle.fluid.dygraph.parallel import sync_params_buffers
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
 
 # add sharding stage2/3
 from paddle.distributed.sharding import group_sharded_parallel
@@ -151,9 +153,10 @@ def do_train(args):
     dp_rank = hcg.get_data_parallel_rank()
     sharding_rank = hcg.get_sharding_parallel_rank()
 
-    # sharding stage2/3 not support hybrid parallel
+    # sharding stage2/3 not support hybrid parallel now
     if args.sharding_stage in [2, 3]:
-        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"
+        assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
+        dp_group = hcg.get_data_parallel_group()
 
     sharding_size = hcg.get_sharding_parallel_world_size()
     data_world_rank = dp_rank * sharding_size + sharding_rank
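For reference, the dp_group obtained above comes out of fleet's hybrid topology. A minimal sketch of how such a topology could be initialized for sharding stage2/3 plus data parallel follows; the degree values and 8-GPU layout are illustrative assumptions, not taken from this commit.

from paddle.distributed import fleet

# Illustrative 8-GPU layout: 2-way data parallel over 4-way sharding,
# with tensor/pipeline parallel disabled as the new assert requires.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 4,
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
dp_rank = hcg.get_data_parallel_rank()
sharding_rank = hcg.get_sharding_parallel_rank()
dp_group = hcg.get_data_parallel_group()  # reused by the hunks below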
@@ -275,6 +278,11 @@ def do_train(args):
     # wrap sharding stage2/3 and add collective group
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
+        if args.dp_degree > 1:
+            sync_params_buffers(model,
+                                comm_group=dp_group,
+                                src_rank=dp_group.ranks[0])
+
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                      args.sharding_offload)
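wrap_sharding_2_3 is a helper defined elsewhere in run_pretrain.py and is not part of this diff. A plausible sketch of it, built on the group_sharded_parallel API imported in the first hunk, is shown below; the signature and the stage-to-level mapping are assumptions for illustration, not copied from the file.

from paddle.distributed.sharding import group_sharded_parallel

def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload, sharding_stage=2):
    # Hypothetical sketch: stage 2 shards optimizer states and gradients
    # ("os_g"); stage 3 additionally shards the parameters ("p_g_os").
    level = "p_g_os" if sharding_stage == 3 else "os_g"
    return group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level=level,
        scaler=scaler,
        offload=sharding_offload)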
@@ -359,6 +367,16 @@ def do_train(args):
             loss_mbs.backward()
             loss = loss + loss_mbs
 
+        if args.sharding_stage in [2, 3] and args.dp_degree > 1:
+            fused_allreduce_gradients(model.parameters(), hcg)
+            if args.sharding_stage == 3:
+                for p in model.parameters():
+                    if hasattr(p, "bw_storage"):
+                        assert p.grad is None, "This case shouldn't happen."
+                        p.bw_storage.scale_(1.0 / dp_group.nranks)
+                        paddle.distributed.all_reduce(
+                            p.bw_storage, group=dp_group)
+
         if args.use_pure_fp16:
             if args.sharding_stage in [2, 3]:
                 scaler.step(optimizer)
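The scale-then-all_reduce pair in the stage-3 branch is a manual gradient average: under stage 3 the gradients of sharded parameters live in each parameter's bw_storage buffer rather than in p.grad (the assert checks this), so each rank scales its slice by 1/dp_group.nranks and then sum-reduces it over the dp group. A standalone illustration of the same arithmetic is given below; the tensor values are made up, and dp_group is assumed to exist as in the topology sketch above.

import paddle
import paddle.distributed as dist

# Each dp rank holds a local gradient slice; values here are illustrative.
local_grad = paddle.to_tensor([1.0, 2.0, 3.0]) * float(dist.get_rank() + 1)

# Averaging = pre-scale by 1/world_size, then sum-reduce across the group.
local_grad.scale_(1.0 / dp_group.nranks)
dist.all_reduce(local_grad, group=dp_group)
# local_grad now equals the mean of all ranks' slices, matching what
# fused_allreduce_gradients produces for the dense gradients.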
