Commit f658fa7

fix sharding <100 limitation (#8146)
1 parent 1ef7503 commit f658fa7

File tree

1 file changed: +12 −17 lines changed


paddlenlp/trainer/training_args.py

Lines changed: 12 additions & 17 deletions
@@ -1576,20 +1576,20 @@ def pipeline_parallel_rank(self):
         else:
             return 0

+    def _format_name(self, prefix, rank, degree):
+        size = max(2, len(str(degree)))
+        return f"{prefix}{rank:0>{size}d}"
+
     @property
     def optimizer_name_suffix(self):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_degree < 100, "tensor parallel degree should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
-                assert self.pipeline_parallel_degree < 100, "pipeline parallel degree should be less than 100."
-                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
-                assert self.sharding_parallel_degree < 100, "sharding parallel degree should be less than 100."
-                name.append(f"shard{self.sharding_parallel_rank:0>2d}")
-
+                name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
             return "_".join(name)
         else:
             return None
@@ -1599,11 +1599,9 @@ def weight_name_suffix(self):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_rank < 100, "tensor parallel rank should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
-                assert self.pipeline_parallel_degree < 100, "tensor parallel rank should be less than 100."
-                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             return "_".join(name)

         else:
@@ -1613,20 +1611,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_rank < 100, "tensor parallel rank should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 if pp_id is None:
                     pp_id = self.pipeline_parallel_rank
                 assert isinstance(pp_id, int)
-                assert pp_id < 100, "pp_id should be less than 100."
-                name.append(f"pp{pp_id:0>2d}")
+                name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                assert shard_id < 100, "shard_id should be less than 100."
-                name.append(f"shard{shard_id:0>2d}")
+                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
             return "_".join(name)
         else:
             return None
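
For context, the new _format_name helper zero-pads the rank to at least two digits and widens the padding whenever the parallel degree needs more digits, which is why the hard-coded "< 100" assertions could be dropped. Below is a minimal standalone sketch: the helper body is copied from the diff above, but writing it as a free function and the sample degrees (8, 64, 128, 256) are only for illustration, not taken from the commit.

    # Sketch of the helper introduced in this commit (module-level here for brevity).
    def _format_name(prefix, rank, degree):
        # Pad to at least 2 digits; grow the width when the degree needs more.
        size = max(2, len(str(degree)))
        return f"{prefix}{rank:0>{size}d}"

    # Degrees below 100 keep the old two-digit suffixes unchanged:
    assert _format_name("tp", 3, 8) == "tp03"
    assert _format_name("shard", 15, 64) == "shard15"

    # Degrees of 100 or more now widen the padding instead of failing
    # the removed "should be less than 100" assertions:
    assert _format_name("shard", 7, 128) == "shard007"
    assert _format_name("pp", 123, 256) == "pp123"

Because the width is derived from the degree, all suffixes for a given run stay the same length, so checkpoint file names still sort lexicographically by rank.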
