Skip to content

Commit f457774

Browse files
sneaxiy and wentaoyu
authored and committed
part-2 cherry from: Make FLAGS_force_align_vpp_grad_sum_order default to false (PaddlePaddle#54937)
* Make FLAGS_force_align_vpp_grad_sum_order default to false
* Polish code
1 parent 71fb1a0 commit f457774

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, clip, hcg):
5151
self.not_sharding_stage1 = True
5252
self._vpp_chunk_num = None
5353
self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
54-
os.getenv('FLAGS_force_align_vpp_grad_sum_order', '1')
54+
os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
5555
)
5656

5757
def _get_vpp_chunk_num(self, params_grads):
@@ -220,9 +220,10 @@ def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
220220

221221
@no_grad()
222222
def _dygraph_clip(self, params_grads):
223-
chunk_num = self._get_vpp_chunk_num(params_grads)
224-
if chunk_num > 0 and self._force_align_vpp_grad_sum_order:
225-
return self._vpp_dygraph_clip(params_grads, chunk_num)
223+
if self._force_align_vpp_grad_sum_order:
224+
chunk_num = self._get_vpp_chunk_num(params_grads)
225+
if chunk_num > 0:
226+
return self._vpp_dygraph_clip(params_grads, chunk_num)
226227

227228
sum_square_dist_fp16 = []
228229
sum_square_dist_bf16 = []

0 commit comments

Comments
 (0)