@@ -89,12 +89,14 @@ def __init__(self, optimizer, hcg, strategy):
         self._inner_opt = optimizer
         self._strategy = strategy
         self._hcg = hcg
-        self._is_mp = (
-            self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
+
+        self._use_dp_mode = (
+            self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL)
+
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

         if isinstance(self._inner_opt._grad_clip,
-                      ClipGradByGlobalNorm) and self._is_mp:
+                      ClipGradByGlobalNorm) and not self._use_dp_mode:
             logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                 "optmizer'grad clip will be changed.")
             self._inner_opt._grad_clip = HybridParallelClipGrad(
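The renamed flag inverts the check: instead of asking "is this tensor parallel?", the wrapper now asks "is this pure data parallel?", so the grad-clip replacement covers every hybrid mode rather than tensor parallel alone. The reason ClipGradByGlobalNorm needs replacing at all is that under tensor parallelism each rank holds only a shard of the parameters, so a locally computed global norm under-counts. Below is a minimal sketch of that idea; `hybrid_global_norm` and `mp_group` are illustrative names, not the actual HybridParallelClipGrad implementation, which also has to avoid double-counting parameters replicated across ranks.

```python
# Illustrative only: why a locally computed global norm is wrong under
# tensor parallelism. `mp_group` is an assumed stand-in for the
# tensor-parallel communication group the hcg would provide.
import paddle
import paddle.distributed as dist

def hybrid_global_norm(grads, mp_group):
    # Each rank only sees its shard of the model-parallel parameters, so a
    # purely local sum of squares under-counts the true global norm.
    local_sq = paddle.add_n([paddle.sum(g * g) for g in grads])
    # All-reduce the squared norm across the tensor-parallel group so every
    # rank clips with the same global value.
    dist.all_reduce(local_sq, group=mp_group)
    return paddle.sqrt(local_sq)
```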
@@ -103,7 +105,7 @@ def __init__(self, optimizer, hcg, strategy):
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
-        if self._is_mp and self._need_dp:
+        if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(
                 list(self._inner_opt._parameter_list), self._hcg)
         self._inner_opt.step()
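With the inverted flag, any non-pure-DP mode that still has a data-parallel dimension (dp world size > 1) gets an explicit gradient all-reduce before the inner optimizer steps, since the hybrid-parallel wrapper does not sync those gradients automatically. The following is a rough, unfused sketch of what `fused_allreduce_gradients` achieves semantically; the real helper coalesces gradients into fused buffers to cut the number of collective calls, and `dp_group` here is an assumed stand-in for the data-parallel group held by the hcg.

```python
# Semantic sketch of fused_allreduce_gradients, without the fusing.
# `dp_group` stands in for the data-parallel group provided by the hcg.
import paddle.distributed as dist

def naive_allreduce_gradients(parameter_list, dp_group):
    nranks = dp_group.nranks
    for p in parameter_list:
        if p.grad is None:
            continue
        dist.all_reduce(p.grad, group=dp_group)  # sum gradients across DP ranks
        p.grad.scale_(1.0 / nranks)              # sum -> mean, in place
```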
@@ -119,7 +121,7 @@ def minimize(self,
         parameter_list = parameters if parameters \
             else self._parameter_list

-        if self._is_mp and self._need_dp:
+        if not self._use_dp_mode and self._need_dp:
             fused_allreduce_gradients(list(parameter_list), self._hcg)

         return self._inner_opt.minimize(loss, startup_program, parameters,
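For context, a hedged usage sketch of how this wrapper is typically reached through the fleet API; the parallel degrees, model, and hyperparameters below are placeholders.

```python
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
# dp_degree > 1 together with mp_degree > 1 exercises both paths in the diff:
# the grad-clip swap and the fused all-reduce before step()/minimize().
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = MyModel()  # placeholder paddle.nn.Layer
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0))

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)  # returns the wrapper above
```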