@@ -1576,20 +1576,20 @@ def pipeline_parallel_rank(self):
         else:
             return 0

+    def _format_name(self, prefix, rank, degree):
+        size = max(2, len(str(degree)))
+        return f"{prefix}{rank:0>{size}d}"
+
     @property
     def optimizer_name_suffix(self):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_degree < 100, "tensor parallel degree should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
-                assert self.pipeline_parallel_degree < 100, "pipeline parallel degree should be less than 100."
-                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
-                assert self.sharding_parallel_degree < 100, "sharding parallel degree should be less than 100."
-                name.append(f"shard{self.sharding_parallel_rank:0>2d}")
-
+                name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
             return "_".join(name)
         else:
             return None
@@ -1599,11 +1599,9 @@ def weight_name_suffix(self):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_rank < 100, "tensor parallel rank should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
-                assert self.pipeline_parallel_degree < 100, "tensor parallel rank should be less than 100."
-                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+                name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             return "_".join(name)

         else:
@@ -1613,20 +1611,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
-                assert self.tensor_parallel_rank < 100, "tensor parallel rank should be less than 100."
-                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+                name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 if pp_id is None:
                     pp_id = self.pipeline_parallel_rank
                 assert isinstance(pp_id, int)
-                assert pp_id < 100, "pp_id should be less than 100."
-                name.append(f"pp{pp_id:0>2d}")
+                name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                assert shard_id < 100, "shard_id should be less than 100."
-                name.append(f"shard{shard_id:0>2d}")
+                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
             return "_".join(name)
         else:
             return None
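
For reference, a minimal standalone sketch of what the new helper produces; the ranks and degrees below are made-up values for illustration only:

    def _format_name(prefix, rank, degree):
        # pad the rank to at least two digits, or to the digit count of the degree
        size = max(2, len(str(degree)))
        return f"{prefix}{rank:0>{size}d}"

    _format_name("tp", 7, 8)    # -> "tp07"  (matches the old fixed two-digit suffix)
    _format_name("pp", 7, 128)  # -> "pp007" (padding widens with the parallel degree)

Because the pad width now scales with the parallel degree, the hard "less than 100" asserts on degrees and ranks can be dropped.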