@@ -60,7 +60,7 @@ class SplitPoint(Enum):


 class PipelineParallel(ParallelModel):
-    def __init__(self, model, split_spec, global_spec, pipeline_layers):
+    def __init__(self, model, split_spec, global_spec, pipeline_layers=None):
         super().__init__(model)
         self.split_spec = split_spec
         self.global_spec = global_spec
@@ -81,10 +81,6 @@ def pipeline_parallel_fn(self, model):
         pipeline_stage_num = mesh.get_dim_size("pp")
         assert len(self.split_spec) == pipeline_stage_num - 1

-        name_to_layer = {}
-        for layer_name, layer in model.named_sublayers():
-            name_to_layer[layer_name] = layer
-
         def forward_post_hook(layer, input, output):
             pipeline_stage_index = layer.pipeline_stage_index
             split_point = layer.split_point
@@ -164,8 +160,10 @@ def forward_pre_hook(layer, input):
                     "SplitPoint.BEGINNING is not supported currently"
                 )
             layer.register_forward_pre_hook(forward_pre_hook)
+
         if self.global_spec:
             self.process_global_mesh_layers()
+
         return model

     def process_global_mesh_layers(self):
@@ -196,6 +194,7 @@ def forward_post_hook(layer, input, output):
                     for _ in range(len(g_mesh._shape))
                 ],
             )
+
             if isinstance(output, tuple):
                 global_output = tuple(global_output)
             return global_output
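
Aside: the hook in this hunk reshards the stage output onto the global mesh with one `Replicate` placement per mesh dimension. A minimal sketch of that placement construction, assuming an illustrative 2-D mesh (`dist.ProcessMesh` and `dist.Replicate` are the real `paddle.distributed` types; the ranks and dim names here are made up):

```python
import paddle.distributed as dist

# Hypothetical 2-D mesh over 4 ranks; dim names are illustrative.
g_mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])

# One Replicate() per mesh axis, mirroring the list comprehension
# in the hook above (which reads the private g_mesh._shape).
placements = [dist.Replicate() for _ in range(len(g_mesh.shape))]
```
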
@@ -222,30 +221,45 @@ def forward_pre_hook(layer, args, kwargs):
             new_args = []
             new_kwargs = {}

-            def reshard_tensor_args(t):
-                if is_tensor(t) and t.is_dist() and t.process_mesh == g_mesh:
+            def rshard_not_mesh_match_tensor(arg):
+                cur_pp_mesh = self.get_mesh(pp_idx)
+                if (
+                    arg is not None
+                    and is_tensor(arg)
+                    and arg.is_dist()
+                    and arg.process_mesh != cur_pp_mesh
+                ):
                     return dist.reshard(
-                        t,
-                        self.get_mesh(pp_idx),
+                        arg,
+                        cur_pp_mesh,
                         [dist.Replicate(), dist.Replicate()],
                     )
-                return t
+                return arg

             for arg in args:
-                new_args.append(reshard_tensor_args(arg))
+                new_args.append(rshard_not_mesh_match_tensor(arg))

             for key, arg in kwargs.items():
-                new_kwargs[key] = reshard_tensor_args(arg)
+                new_kwargs[key] = rshard_not_mesh_match_tensor(arg)

-            return (new_args, new_kwargs)
+            return (tuple(new_args), new_kwargs)

+        # workaround for a PIR send/recv bug in VPP (virtual pipeline) mode
         for layer_name in self.global_spec:
             layer = self.get_layer_by_name(layer_name)
             layer.register_forward_post_hook(forward_post_hook)

-        for layer_name in self.pipeline_layers:
-            layer = self.get_layer_by_name(layer_name)
-            layer.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+        if self.pipeline_layers is not None:
+            for layer_name in self.pipeline_layers:
+                layer = self.get_layer_by_name(layer_name)
+                layer.register_forward_pre_hook(
+                    forward_pre_hook, with_kwargs=True
+                )
+        else:
+            for layer in self.name_to_layer.values():
+                layer.register_forward_pre_hook(
+                    forward_pre_hook, with_kwargs=True
+                )


 def pipeline_parallel(model, optimizer=None, config=None):
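
Aside: note the pre-hook now returns `(tuple(new_args), new_kwargs)` rather than a list. A self-contained sketch of that hook shape outside the parallel machinery, assuming `register_forward_pre_hook` accepts `with_kwargs=True` as used above (the `Scale` layer and the `+1` rewrite are purely illustrative):

```python
import paddle

class Scale(paddle.nn.Layer):
    def forward(self, x, factor=1.0):
        return x * factor

def forward_pre_hook(layer, args, kwargs):
    # Rewrite positional and keyword args, then return them as
    # (tuple, dict) -- the same shape the hook above returns.
    new_args = tuple(a + 1 for a in args)
    return (new_args, dict(kwargs))

layer = Scale()
layer.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
out = layer(paddle.to_tensor([1.0]), factor=2.0)  # (1 + 1) * 2 = 4
```
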
@@ -366,11 +380,14 @@ def divide_list_indices(n, k):
             ]
         )
     else:
+        sublayer_names = [name for name, _ in model.named_sublayers()]
         split_spec_dict = split_spec
-        if global_spec:
-            raise NotImplementedError(
-                "global_spec should be None if split_spec is a dict"
-            )
+        for key, value in split_spec_dict.items():
+            assert (
+                key in sublayer_names
+            ), f"Invalid split layer {key}; expected one of {sublayer_names}"
+            assert value is SplitPoint.END, "Only SplitPoint.END is supported currently."
+
     if global_spec:
         if isinstance(global_spec, str):
             global_spec = [global_spec]
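
Aside: with these checks, a dict-form `split_spec` must key on names that appear in `model.named_sublayers()` and map each key to `SplitPoint.END`. A standalone sketch of the validation (the enum values are placeholders; only the member names come from the code above, and the layer names are made up):

```python
from enum import Enum

class SplitPoint(Enum):
    BEGINNING = 0  # placeholder value
    END = 1  # placeholder value

def validate_split_spec(split_spec_dict, sublayer_names):
    # Mirrors the asserts added above: keys must name existing
    # sublayers, and only SplitPoint.END is accepted.
    for key, value in split_spec_dict.items():
        assert (
            key in sublayer_names
        ), f"Invalid split layer {key}; expected one of {sublayer_names}"
        assert value is SplitPoint.END, "Only SplitPoint.END is supported currently."

# Passes: the key exists and maps to SplitPoint.END.
validate_split_spec(
    {"llama.layers.7": SplitPoint.END},
    ["llama.embed_tokens", "llama.layers.7", "llama.layers.15"],
)
```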