@@ -1022,6 +1022,13 @@ def __post_init__(self):
                 logger.warning("set amp_master_grad to false since amp is disabled.")
                 self.amp_master_grad = False
 
+        def split_parallel_config(parallel_config):
+            if "," in parallel_config:
+                parallel_config = set(parallel_config.split(","))
+            else:
+                parallel_config = set(parallel_config.split(" "))
+            return parallel_config
+
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
 
@@ -1039,10 +1046,7 @@ def __post_init__(self):
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1116,10 +1120,7 @@ def __post_init__(self):
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
             strategy.hybrid_configs = hybrid_configs
 
             if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
+
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():
 
             # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
             if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():
 
             if self.tensor_parallel_degree > 1:
                 mp_optimization = strategy.mp_optimization
-
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
                 elif ShardingOption.FULL_SHARD in self.sharding:
                     sharding.stage = 3
 
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
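
Note: the sketch below illustrates how the new split_parallel_config helper behaves; the function body is copied from the hunk above, while the option names and print calls are illustrative additions, not taken from this diff.

    def split_parallel_config(parallel_config):
        # Comma-separated values take precedence; otherwise fall back to space separation.
        if "," in parallel_config:
            parallel_config = set(parallel_config.split(","))
        else:
            parallel_config = set(parallel_config.split(" "))
        return parallel_config

    # Both spellings produce the same set of options (set order may vary):
    print(split_parallel_config("enable_dp_comm_overlap,enable_delay_scale_loss"))
    print(split_parallel_config("enable_dp_comm_overlap enable_delay_scale_loss"))
    # {'enable_dp_comm_overlap', 'enable_delay_scale_loss'}

    # An empty string yields {''}, which is why every caller keeps the
    # `if len(x) > 0` guard before validating each option.
    print(split_parallel_config(""))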