 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
+from paddle.fluid.dygraph.parallel import sync_params_buffers
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

 # add sharding stage2/3
 from paddle.distributed.sharding import group_sharded_parallel
@@ -151,9 +153,10 @@ def do_train(args):
     dp_rank = hcg.get_data_parallel_rank()
     sharding_rank = hcg.get_sharding_parallel_rank()

-    # sharding stage2/3 not support hybrid parallel
+    # sharding stage2/3 does not support hybrid parallel for now
     if args.sharding_stage in [2, 3]:
-        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"
+        assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
+        dp_group = hcg.get_data_parallel_group()

     sharding_size = hcg.get_sharding_parallel_world_size()
     data_world_rank = dp_rank * sharding_size + sharding_rank
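As a quick check of the rank arithmetic in this hunk (the concrete degrees below are made-up values, not defaults from the script): with a sharding group of size 4 (`sharding_size == 4`), the worker with `dp_rank = 1` and `sharding_rank = 2` gets `data_world_rank = 1 * 4 + 2 = 6`, i.e. the dp and sharding dimensions are flattened dp-major into a single per-worker data index.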
@@ -275,6 +278,11 @@ def do_train(args):
     # wrap sharding stage2/3 and add collective group
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in future
     if args.sharding_stage in [2, 3]:
+        if args.dp_degree > 1:
+            sync_params_buffers(model,
+                                comm_group=dp_group,
+                                src_rank=dp_group.ranks[0])
+
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                      args.sharding_offload)
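`wrap_sharding_2_3` is a helper defined elsewhere in this script and is not changed by this hunk. Given the `group_sharded_parallel` import added at the top of the file, a plausible sketch of what such a wrapper does is shown below; the body here is an assumption for illustration, not the file's actual code:

```python
def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload):
    # Sketch: rewrap model/optimizer/scaler with Paddle's group-sharded API,
    # sharding over the sharding group of the hybrid communicate group.
    # `args` is assumed to be the script-level argument namespace.
    group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
    level = "p_g_os" if args.sharding_stage == 3 else "os_g"  # stage 3 vs. stage 2
    return group_sharded_parallel(model=model,
                                  optimizer=optimizer,
                                  level=level,
                                  scaler=scaler,
                                  group=group,
                                  offload=sharding_offload)
```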
@@ -359,6 +367,16 @@ def do_train(args):
                         loss_mbs.backward()
                     loss = loss + loss_mbs

+                if args.sharding_stage in [2, 3] and args.dp_degree > 1:
+                    fused_allreduce_gradients(model.parameters(), hcg)
+                    if args.sharding_stage == 3:
+                        for p in model.parameters():
+                            if hasattr(p, "bw_storage"):
+                                assert p.grad is None, "This case shouldn't happen."
+                                p.bw_storage.scale_(1.0 / dp_group.nranks)
+                                paddle.distributed.all_reduce(
+                                    p.bw_storage, group=dp_group)
+
                 if args.use_pure_fp16:
                     if args.sharding_stage in [2, 3]:
                         scaler.step(optimizer)
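Taken together, these changes give the group-sharded (stage 2/3) path an outer data-parallel dimension: parameter buffers are broadcast across the dp group once before the model is wrapped, and after each step's backward pass the gradients, which under stage 3 live in the fused `p.bw_storage` buffers rather than `p.grad`, are averaged across the dp group before the optimizer runs. A condensed sketch of that ordering; the loop skeleton and the `run_forward_backward` placeholder are illustrative stand-ins for the training loop above, everything else is named as in the diff:

```python
import paddle
from paddle.distributed import fleet
from paddle.fluid.dygraph.parallel import sync_params_buffers
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()

# (1) Before wrapping: make every dp replica start from identical weights.
if args.dp_degree > 1:
    sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                             args.sharding_offload)

for step, batch in enumerate(train_data_loader()):   # training loop, condensed
    loss = run_forward_backward(model, batch)         # placeholder for the micro-batch loop above

    # (2) After backward, before the optimizer step: dp-average the gradients.
    if args.sharding_stage in [2, 3] and args.dp_degree > 1:
        fused_allreduce_gradients(model.parameters(), hcg)
        if args.sharding_stage == 3:
            # Stage 3 keeps the fused gradient in p.bw_storage (p.grad is None),
            # so those buffers are scaled and all-reduced by hand.
            for p in model.parameters():
                if hasattr(p, "bw_storage"):
                    p.bw_storage.scale_(1.0 / dp_group.nranks)
                    paddle.distributed.all_reduce(p.bw_storage, group=dp_group)

    # (3) Optimizer/scaler step exactly as in the loop above
    #     (scaler.step + scaler.update when use_pure_fp16, otherwise optimizer.step()).
```

The key constraint is ordering: the dp all-reduce has to run after all micro-batch backward passes have finished but before `scaler.step`/`optimizer.step`, which is why the new block sits between the accumulation loop and the existing fp16 branch.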