
Commit a79a9aa

support dict pp split point (PaddlePaddle#71342)

* support dict pp split point
* remove global spec in auto parallel
* delete comment
* add assert
* fix return type bug
* recover global_spec because of pir constraint
* rename
* add more assert
1 parent 1812c87 commit a79a9aa
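
As a minimal sketch of what this commit enables: `split_spec` may now be a dict mapping sublayer names to split points, validated against `model.named_sublayers()`, and a dict spec can now be combined with `global_spec`. The layer names below are hypothetical, and the diff does not show how `split_spec` reaches the `pipeline_parallel(model, optimizer=None, config=None)` entry point, so treat this as the key/value contract only:

```python
from paddle.distributed.auto_parallel.intermediate.pipeline_parallel import (
    SplitPoint,
)

# Hypothetical sublayer names: each key must match an entry reported by
# model.named_sublayers(), and only SplitPoint.END is accepted for now.
split_spec = {
    "decoder.layers.3": SplitPoint.END,  # pipeline stage 0 ends here
    "decoder.layers.7": SplitPoint.END,  # pipeline stage 1 ends here
}
```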

File tree

1 file changed

python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py

Lines changed: 37 additions & 20 deletions
@@ -60,7 +60,7 @@ class SplitPoint(Enum):


 class PipelineParallel(ParallelModel):
-    def __init__(self, model, split_spec, global_spec, pipeline_layers):
+    def __init__(self, model, split_spec, global_spec, pipeline_layers=None):
         super().__init__(model)
         self.split_spec = split_spec
         self.global_spec = global_spec
@@ -81,10 +81,6 @@ def pipeline_parallel_fn(self, model):
         pipeline_stage_num = mesh.get_dim_size("pp")
         assert len(self.split_spec) == pipeline_stage_num - 1

-        name_to_layer = {}
-        for layer_name, layer in model.named_sublayers():
-            name_to_layer[layer_name] = layer
-
         def forward_post_hook(layer, input, output):
             pipeline_stage_index = layer.pipeline_stage_index
             split_point = layer.split_point
@@ -164,8 +160,10 @@ def forward_pre_hook(layer, input):
                     "SplitPoint.BEGINNING is not supported currently"
                 )
             layer.register_forward_pre_hook(forward_pre_hook)
+
         if self.global_spec:
             self.process_global_mesh_layers()
+
         return model

     def process_global_mesh_layers(self):
@@ -196,6 +194,7 @@ def forward_post_hook(layer, input, output):
                     for _ in range(len(g_mesh._shape))
                 ],
             )
+
             if isinstance(output, tuple):
                 global_output = tuple(global_output)
             return global_output
@@ -222,30 +221,45 @@ def forward_pre_hook(layer, args, kwargs):
             new_args = []
             new_kwargs = {}

-            def reshard_tensor_args(t):
-                if is_tensor(t) and t.is_dist() and t.process_mesh == g_mesh:
+            def rshard_not_mesh_match_tensor(arg):
+                cur_pp_mesh = self.get_mesh(pp_idx)
+                if (
+                    arg is not None
+                    and is_tensor(arg)
+                    and arg.is_dist()
+                    and arg.process_mesh != cur_pp_mesh
+                ):
                     return dist.reshard(
-                        t,
-                        self.get_mesh(pp_idx),
+                        arg,
+                        cur_pp_mesh,
                         [dist.Replicate(), dist.Replicate()],
                     )
-                return t
+                return arg

             for arg in args:
-                new_args.append(reshard_tensor_args(arg))
+                new_args.append(rshard_not_mesh_match_tensor(arg))

             for key, arg in kwargs.items():
-                new_kwargs[key] = reshard_tensor_args(arg)
+                new_kwargs[key] = rshard_not_mesh_match_tensor(arg)

-            return (new_args, new_kwargs)
+            return (tuple(new_args), new_kwargs)

+        # wa because of pir in vpp mode send receive bug
         for layer_name in self.global_spec:
             layer = self.get_layer_by_name(layer_name)
             layer.register_forward_post_hook(forward_post_hook)

-        for layer_name in self.pipeline_layers:
-            layer = self.get_layer_by_name(layer_name)
-            layer.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+        if self.pipeline_layers is not None:
+            for layer_name in self.pipeline_layers:
+                layer = self.get_layer_by_name(layer_name)
+                layer.register_forward_pre_hook(
+                    forward_pre_hook, with_kwargs=True
+                )
+        else:
+            for layer in self.name_to_layer.values():
+                layer.register_forward_pre_hook(
+                    forward_pre_hook, with_kwargs=True
+                )


 def pipeline_parallel(model, optimizer=None, config=None):
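
The hunk above replaces `reshard_tensor_args`, which only resharded tensors already placed on the global mesh, with `rshard_not_mesh_match_tensor`, which reshards any dist tensor whose mesh differs from the current pipeline stage's mesh and also guards against `None` arguments. A self-contained sketch of that predicate, assuming a 2-D stage mesh as the diff's double `Replicate` placement implies; `paddle.is_tensor` stands in for the module's local `is_tensor` helper:

```python
import paddle
import paddle.distributed as dist

def reshard_if_mesh_mismatch(arg, cur_pp_mesh):
    # Only dist tensors living on a different mesh are moved; plain tensors,
    # None, and tensors already on cur_pp_mesh pass through unchanged.
    if (
        arg is not None
        and paddle.is_tensor(arg)
        and arg.is_dist()
        and arg.process_mesh != cur_pp_mesh
    ):
        # Fully replicate across both dims of the stage mesh, as in the diff.
        return dist.reshard(
            arg, cur_pp_mesh, [dist.Replicate(), dist.Replicate()]
        )
    return arg
```

Returning `tuple(new_args)` instead of a list is the "fix return type bug" item from the commit message, presumably because a forward-pre-hook is expected to hand positional args back as a tuple rather than a list.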
@@ -366,11 +380,14 @@ def divide_list_indices(n, k):
             ]
         )
     else:
+        sublayer_names = [name for name, _ in model.named_sublayers()]
         split_spec_dict = split_spec
-        if global_spec:
-            raise NotImplementedError(
-                "global_spec should be None if split_spec is a dict"
-            )
+        for key, value in split_spec_dict.items():
+            assert (
+                key in sublayer_names
+            ), f"wrong split layer, expected one of {sublayer_names}"
+            assert value is SplitPoint.END, "not supported split point at now."
+
     if global_spec:
         if isinstance(global_spec, str):
             global_spec = [global_spec]
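
The final hunk drops the old `NotImplementedError` (a dict `split_spec` no longer forbids `global_spec`) and instead validates each entry eagerly. A runnable illustration of the two checks, with a locally mirrored enum (member values invented) and hypothetical sublayer names, since only this file's diff is shown:

```python
from enum import Enum

class SplitPoint(Enum):  # mirrors the enum defined in this file
    BEGINNING = 0
    END = 1

sublayer_names = ["embed", "layers.0", "layers.1", "head"]  # hypothetical

def validate_split_spec(split_spec_dict):
    # Same two asserts the diff adds: keys must name real sublayers and,
    # for now, only SplitPoint.END is a supported split point.
    for key, value in split_spec_dict.items():
        assert (
            key in sublayer_names
        ), f"wrong split layer, expected one of {sublayer_names}"
        assert value is SplitPoint.END, "not supported split point at now."

validate_split_spec({"layers.0": SplitPoint.END})  # passes
# validate_split_spec({"layers.0": SplitPoint.BEGINNING})  # AssertionError
```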
