@@ -52,73 +52,78 @@ def test_job_config_file_cmd_overrides(self):
5252 )
5353 assert config .job .dump_folder == "/tmp/test_tt/"
5454
55- def test_parse_pp_split_points (self ):
56- toml_splits = ["layers.2" , "layers.4" , "layers.6" ]
57- cmdline_splits = ["layers.1" , "layers.3" , "layers.5" ]
58- # no split points specified
59- config_manager = ConfigManager ()
60- config = config_manager .parse_args (
61- [
62- "--job.config_file" ,
63- "./torchtitan/models/llama3/train_configs/debug_model.toml" ,
64- ]
65- )
66- assert config .parallelism .pipeline_parallel_split_points == []
55+ def test_parse_module_fqns_per_model_part (self ):
56+ toml_chunks = [
57+ ["tok_embeddings" , "layers.0" ],
58+ ["layers.1" , "layers.2" ],
59+ ["layers.3" , "norm" , "output" ],
60+ ]
61+ cmdline_chunks = [
62+ ["tok_embeddings" , "layers.0" , "layers.1" ],
63+ ["layers.2" , "layers.3" , "norm" , "output" ],
64+ ]
6765
68- # toml has no split points, but cmdline splits are specified
66+ # no module names specified
6967 config_manager = ConfigManager ()
7068 config = config_manager .parse_args (
7169 [
7270 "--job.config_file" ,
7371 "./torchtitan/models/llama3/train_configs/debug_model.toml" ,
74- "--parallelism.pipeline_parallel_split_points" ,
75- "," .join (cmdline_splits ),
7672 ]
7773 )
78- assert (
79- config .parallelism .pipeline_parallel_split_points == cmdline_splits
80- ), config .parallelism .pipeline_parallel_split_points
74+ assert config .parallelism .module_fqns_per_model_part == []
8175
82- # toml has split points , cmdline does not
76+ # toml has module names , cmdline does not
8377 with tempfile .NamedTemporaryFile () as fp :
8478 with open (fp .name , "wb" ) as f :
8579 tomli_w .dump (
8680 {
8781 "parallelism" : {
88- "pipeline_parallel_split_points " : toml_splits ,
82+ "module_fqns_per_model_part " : toml_chunks ,
8983 }
9084 },
9185 f ,
9286 )
9387 config_manager = ConfigManager ()
9488 config = config_manager .parse_args (["--job.config_file" , fp .name ])
9589 assert (
96- config .parallelism .pipeline_parallel_split_points == toml_splits
97- ), config .parallelism .pipeline_parallel_split_points
90+ config .parallelism .module_fqns_per_model_part == toml_chunks
91+ ), config .parallelism .module_fqns_per_model_part
9892
99- # toml has split points, cmdline overrides them
93+ # test that the field accepts list of lists structure
10094 with tempfile .NamedTemporaryFile () as fp :
10195 with open (fp .name , "wb" ) as f :
10296 tomli_w .dump (
10397 {
10498 "parallelism" : {
105- "pipeline_parallel_split_points " : toml_splits ,
99+ "module_fqns_per_model_part " : cmdline_chunks ,
106100 }
107101 },
108102 f ,
109103 )
110104 config_manager = ConfigManager ()
111- config = config_manager .parse_args (
112- [
113- "--job.config_file" ,
114- fp .name ,
115- "--parallelism.pipeline_parallel_split_points" ,
116- "," .join (cmdline_splits ),
117- ]
118- )
105+ config = config_manager .parse_args (["--job.config_file" , fp .name ])
106+ assert (
107+ config .parallelism .module_fqns_per_model_part == cmdline_chunks
108+ ), config .parallelism .module_fqns_per_model_part
109+
110+ # test empty chunks are handled correctly
111+ empty_chunks = [[], ["tok_embeddings" ], []]
112+ with tempfile .NamedTemporaryFile () as fp :
113+ with open (fp .name , "wb" ) as f :
114+ tomli_w .dump (
115+ {
116+ "parallelism" : {
117+ "module_fqns_per_model_part" : empty_chunks ,
118+ }
119+ },
120+ f ,
121+ )
122+ config_manager = ConfigManager ()
123+ config = config_manager .parse_args (["--job.config_file" , fp .name ])
119124 assert (
120- config .parallelism .pipeline_parallel_split_points == cmdline_splits
121- ), config .parallelism .pipeline_parallel_split_points
125+ config .parallelism .module_fqns_per_model_part == empty_chunks
126+ ), config .parallelism .module_fqns_per_model_part
122127
123128 def test_parse_exclude_from_loading (self ):
124129 toml_splits = ["optimizer" , "dataloader" ]
0 commit comments