
Commit d7d3090

[Cherry-Pick][HybridParallel]Fix pipeline in dygraph (PaddlePaddle#33097)
* [HybridParallel]Fix pipeline in dygraph (PaddlePaddle#33007)
* fix pipeline
* fix mp pp dp
* fix utest of hybrid parallel
* add utest for tuple
* fix utest (PaddlePaddle#33108)
1 parent 8fe6d55 commit d7d3090

14 files changed: +649 −378 lines

python/paddle/distributed/fleet/base/topology.py

Lines changed: 5 additions & 0 deletions
@@ -253,3 +253,8 @@ def get_pipe_parallel_group(self):
     # check parallel group
     def get_check_parallel_group(self):
         return self._check_comm_group
+
+    def get_rank_from_stage(self, stage_id):
+        coord = self._topo.get_coord(self.global_rank)
+        tf = coord._replace(pipe=stage_id)._asdict()
+        return self._topo.get_rank(**tf)
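The new get_rank_from_stage helper maps a pipeline stage id to the global rank that occupies that stage, keeping every other parallel coordinate of the caller fixed. Below is a minimal, self-contained sketch of the same coordinate-replacement idea; the Coord namedtuple and the ToyTopology class are hypothetical stand-ins for Paddle's real topology object, not its API.

```python
from collections import namedtuple

# Hypothetical coordinate layout; only illustrates the
# coord._replace(pipe=stage_id) pattern used in the diff above.
Coord = namedtuple("Coord", ["data", "pipe", "model"])


class ToyTopology:
    """Enumerates ranks in (data, pipe, model) order, last axis fastest."""

    def __init__(self, data, pipe, model):
        self.shape = Coord(data, pipe, model)

    def get_coord(self, rank):
        d, rem = divmod(rank, self.shape.pipe * self.shape.model)
        p, m = divmod(rem, self.shape.model)
        return Coord(d, p, m)

    def get_rank(self, data, pipe, model):
        return (data * self.shape.pipe + pipe) * self.shape.model + model


topo = ToyTopology(data=2, pipe=2, model=2)
global_rank = 5                       # coord = (data=1, pipe=0, model=1)
coord = topo.get_coord(global_rank)

# Same pattern as get_rank_from_stage: replace only the pipe coordinate,
# keep the other axes, and look the new coordinate back up as a rank.
peer = topo.get_rank(**coord._replace(pipe=1)._asdict())
print(peer)                           # 7: the rank holding stage 1 of this pipeline
```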

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 7 additions & 5 deletions
@@ -89,12 +89,14 @@ def __init__(self, optimizer, hcg, strategy):
         self._inner_opt = optimizer
         self._strategy = strategy
         self._hcg = hcg
-        self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
+
+        self._use_dp_mode = (
+            self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL)
+
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
         if isinstance(self._inner_opt._grad_clip,
-                      ClipGradByGlobalNorm) and self._is_mp:
+                      ClipGradByGlobalNorm) and not self._use_dp_mode:
             logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                 "optmizer'grad clip will be changed.")
             self._inner_opt._grad_clip = HybridParallelClipGrad(
@@ -103,7 +105,7 @@ def __init__(self, optimizer, hcg, strategy):
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
-        if self._is_mp and self._need_dp:
+        if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
         self._inner_opt.step()
@@ -119,7 +121,7 @@ def minimize(self,
         parameter_list = parameters if parameters \
             else self._parameter_list
 
-        if self._is_mp and self._need_dp:
+        if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(list(parameter_list), self._hcg)
 
         return self._inner_opt.minimize(loss, startup_program, parameters,
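The renamed flag broadens when the optimizer does the manual gradient allreduce and when it swaps in HybridParallelClipGrad: both now trigger whenever the run is not in pure data-parallel mode but still has a data-parallel group larger than one (e.g. pipeline+DP or tensor+DP), presumably because pure DP is already synchronized by the DataParallel wrapper. A hedged sketch of that decision is shown below; FakeHCG and needs_manual_grad_allreduce are illustrative stand-ins, not Paddle's real hcg API.

```python
# Illustration only: a stand-in for the hybrid communication group (hcg)
# consulted in the diff above; the real object comes from fleet's topology.
class FakeHCG:
    def __init__(self, parallel_mode, dp_world_size):
        self._mode = parallel_mode
        self._dp = dp_world_size

    def get_parallel_mode(self):
        return self._mode

    def get_data_parallel_world_size(self):
        return self._dp


def needs_manual_grad_allreduce(hcg, data_parallel_mode="data_parallel"):
    """Mirrors the condition `not self._use_dp_mode and self._need_dp`:
    allreduce by hand only when DP is combined with another parallelism."""
    use_dp_mode = hcg.get_parallel_mode() == data_parallel_mode
    need_dp = hcg.get_data_parallel_world_size() > 1
    return (not use_dp_mode) and need_dp


# Pure data parallel: gradient sync is assumed to happen elsewhere -> False.
print(needs_manual_grad_allreduce(FakeHCG("data_parallel", 8)))
# Pipeline (or tensor) parallel combined with a DP group of size 2 -> True.
print(needs_manual_grad_allreduce(FakeHCG("pipeline_parallel", 2)))
```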
