Skip to content

Commit 12ba2fb

Browse files
add more assert
1 parent 343c109 commit 12ba2fb

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/paddle/distributed/auto_parallel/intermediate/pipeline_parallel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,10 @@ def divide_list_indices(n, k):
383383
sublayer_names = [name for name, _ in model.named_sublayers()]
384384
split_spec_dict = split_spec
385385
for key, value in split_spec_dict.items():
386-
assert key in sublayer_names, "wrong split point"
387-
assert value is SplitPoint.END, "wrong split point flag"
386+
assert (
387+
key in sublayer_names
388+
), f"wrong split layer, expected one of {sublayer_names}"
389+
assert value is SplitPoint.END, "not supported split point at now."
388390

389391
if global_spec:
390392
if isinstance(global_spec, str):

0 commit comments

Comments
 (0)