@@ -1039,10 +1039,10 @@ def __post_init__(self):
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
+                if "," in self.pipeline_parallel_config:
                     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                else:
+                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1116,10 +1116,10 @@ def __post_init__(self):
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
+                if "," in self.tensor_parallel_config:
                     mp_config = set(self.tensor_parallel_config.split(","))
+                else:
+                    mp_config = set(self.tensor_parallel_config.split(" "))
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1225,10 +1225,11 @@ def is_segment_parallel_supported():
             strategy.hybrid_configs = hybrid_configs
 
             if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
+                if "," in self.sharding_parallel_config:
                     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                else:
+                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
+
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1384,10 +1385,10 @@ def is_segment_parallel_supported():
 
             # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
             if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
+                if "," in self.pipeline_parallel_config:
                     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                else:
+                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1437,10 +1438,10 @@ def is_segment_parallel_supported():
             if self.tensor_parallel_degree > 1:
                 mp_optimization = strategy.mp_optimization
 
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
+                if "," in self.tensor_parallel_config:
                     mp_config = set(self.tensor_parallel_config.split(","))
+                else:
+                    mp_config = set(self.tensor_parallel_config.split(" "))
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1473,10 +1474,10 @@ def is_segment_parallel_supported():
                 elif ShardingOption.FULL_SHARD in self.sharding:
                     sharding.stage = 3
 
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
+                if "," in self.sharding_parallel_config:
                     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                else:
+                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
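All six hunks apply the same reordering to how pipeline_parallel_config, tensor_parallel_config, and sharding_parallel_config are tokenized in __post_init__: the option string is now split on "," whenever a comma is present, and splitting on " " becomes the fallback. A minimal standalone sketch of the before/after behavior (the helper names split_config_before and split_config_after are illustrative only, not part of the patch):

    def split_config_before(cfg: str) -> set:
        # Behavior removed by this patch: a space anywhere takes precedence.
        return set(cfg.split(" ")) if " " in cfg else set(cfg.split(","))

    def split_config_after(cfg: str) -> set:
        # Behavior added by this patch: a comma takes precedence, space is the fallback.
        return set(cfg.split(",")) if "," in cfg else set(cfg.split(" "))

    # Pure space-separated and pure comma-separated strings parse the same either way.
    assert split_config_before("a b") == split_config_after("a b") == {"a", "b"}
    assert split_config_before("a,b") == split_config_after("a,b") == {"a", "b"}

    # The two orderings diverge only when both separators appear, e.g. "a, b":
    assert split_config_before("a, b") == {"a,", "b"}   # stray comma sticks to the first token
    assert split_config_after("a, b") == {"a", " b"}    # comma is consumed; a leading space remains on the second token

As the sketch shows, only mixed input such as "a, b" changes behavior; tokens that still carry a stray comma or space would presumably be rejected by the "if x not in [...]" whitelist check that follows each split, so pure comma- or pure space-separated strings remain the safe formats.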