Skip to content

Commit 1b0892c

Browse files
committed
Refactor PP splitting
1 parent d282cf2 commit 1b0892c

File tree

5 files changed

+429
-459
lines changed

5 files changed

+429
-459
lines changed

tests/unit_tests/test_job_config.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -52,73 +52,78 @@ def test_job_config_file_cmd_overrides(self):
5252
)
5353
assert config.job.dump_folder == "/tmp/test_tt/"
5454

55-
def test_parse_pp_split_points(self):
56-
toml_splits = ["layers.2", "layers.4", "layers.6"]
57-
cmdline_splits = ["layers.1", "layers.3", "layers.5"]
58-
# no split points specified
59-
config_manager = ConfigManager()
60-
config = config_manager.parse_args(
61-
[
62-
"--job.config_file",
63-
"./torchtitan/models/llama3/train_configs/debug_model.toml",
64-
]
65-
)
66-
assert config.parallelism.pipeline_parallel_split_points == []
55+
def test_parse_module_fqns_per_model_part(self):
56+
toml_chunks = [
57+
["tok_embeddings", "layers.0"],
58+
["layers.1", "layers.2"],
59+
["layers.3", "norm", "output"],
60+
]
61+
cmdline_chunks = [
62+
["tok_embeddings", "layers.0", "layers.1"],
63+
["layers.2", "layers.3", "norm", "output"],
64+
]
6765

68-
# toml has no split points, but cmdline splits are specified
66+
# no module names specified
6967
config_manager = ConfigManager()
7068
config = config_manager.parse_args(
7169
[
7270
"--job.config_file",
7371
"./torchtitan/models/llama3/train_configs/debug_model.toml",
74-
"--parallelism.pipeline_parallel_split_points",
75-
",".join(cmdline_splits),
7672
]
7773
)
78-
assert (
79-
config.parallelism.pipeline_parallel_split_points == cmdline_splits
80-
), config.parallelism.pipeline_parallel_split_points
74+
assert config.parallelism.module_fqns_per_model_part == []
8175

82-
# toml has split points, cmdline does not
76+
# toml has module names, cmdline does not
8377
with tempfile.NamedTemporaryFile() as fp:
8478
with open(fp.name, "wb") as f:
8579
tomli_w.dump(
8680
{
8781
"parallelism": {
88-
"pipeline_parallel_split_points": toml_splits,
82+
"module_fqns_per_model_part": toml_chunks,
8983
}
9084
},
9185
f,
9286
)
9387
config_manager = ConfigManager()
9488
config = config_manager.parse_args(["--job.config_file", fp.name])
9589
assert (
96-
config.parallelism.pipeline_parallel_split_points == toml_splits
97-
), config.parallelism.pipeline_parallel_split_points
90+
config.parallelism.module_fqns_per_model_part == toml_chunks
91+
), config.parallelism.module_fqns_per_model_part
9892

99-
# toml has split points, cmdline overrides them
93+
# test that the field accepts a list-of-lists structure
10094
with tempfile.NamedTemporaryFile() as fp:
10195
with open(fp.name, "wb") as f:
10296
tomli_w.dump(
10397
{
10498
"parallelism": {
105-
"pipeline_parallel_split_points": toml_splits,
99+
"module_fqns_per_model_part": cmdline_chunks,
106100
}
107101
},
108102
f,
109103
)
110104
config_manager = ConfigManager()
111-
config = config_manager.parse_args(
112-
[
113-
"--job.config_file",
114-
fp.name,
115-
"--parallelism.pipeline_parallel_split_points",
116-
",".join(cmdline_splits),
117-
]
118-
)
105+
config = config_manager.parse_args(["--job.config_file", fp.name])
106+
assert (
107+
config.parallelism.module_fqns_per_model_part == cmdline_chunks
108+
), config.parallelism.module_fqns_per_model_part
109+
110+
# test that empty chunks are handled correctly
111+
empty_chunks = [[], ["tok_embeddings"], []]
112+
with tempfile.NamedTemporaryFile() as fp:
113+
with open(fp.name, "wb") as f:
114+
tomli_w.dump(
115+
{
116+
"parallelism": {
117+
"module_fqns_per_model_part": empty_chunks,
118+
}
119+
},
120+
f,
121+
)
122+
config_manager = ConfigManager()
123+
config = config_manager.parse_args(["--job.config_file", fp.name])
119124
assert (
120-
config.parallelism.pipeline_parallel_split_points == cmdline_splits
121-
), config.parallelism.pipeline_parallel_split_points
125+
config.parallelism.module_fqns_per_model_part == empty_chunks
126+
), config.parallelism.module_fqns_per_model_part
122127

123128
def test_parse_exclude_from_loading(self):
124129
toml_splits = ["optimizer", "dataloader"]

torchtitan/config/job_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ class Parallelism:
290290

291291
pipeline_parallel_split_points: list[str] = field(default_factory=list)
292292
"""
293+
DEPRECATED: Use module_fqns_per_model_part instead.
293294
Specify comma-separated names of modules to use as the beginning of a split point.
294295
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
295296
the first containing all the layers up to layers.0,
@@ -299,6 +300,16 @@ class Parallelism:
299300
but currently the split points must be specified manually.
300301
"""
301302

303+
module_fqns_per_model_part: list[list[str]] = field(default_factory=list)
304+
"""
305+
Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
306+
Each inner list represents one model chunk and contains the module names that belong to that chunk.
307+
e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
308+
will create 3 chunks: the first containing tok_embeddings and layers.0,
309+
the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
310+
This provides more explicit control over which modules belong to each chunk compared to split points.
311+
"""
312+
302313
pipeline_parallel_layers_per_stage: int | None = None
303314
"""
304315
The number of layers per (virtual) pipeline stage. If specified, the split points will be

0 commit comments

Comments
 (0)