2121
2222class Topology :
2323 def __init__ (
24- self , device_rank , world_size , dp_degree = None , pp_degree = 1 , sharding_degree = 1 , mp_degree = 1 , sep_degree = 1
24+ self ,
25+ device_rank ,
26+ world_size ,
27+ dp_degree = None ,
28+ pp_degree = 1 ,
29+ sharding_degree = 1 ,
30+ mp_degree = 1 ,
31+ sep_degree = 1 ,
32+ order = ["dp" , "pp" , "sharding" , "mp" , "sep" ],
2533 ):
26- arr = np .arange (0 , dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree ).reshape (
27- [dp_degree , pp_degree , sharding_degree , mp_degree , sep_degree ]
28- )
29-
30- dp_rank , pp_rank , sharding_rank , mp_rank , sep_rank = np .where (arr == device_rank )
31- dp_rank = dp_rank [0 ]
32- pp_rank = pp_rank [0 ]
33- sharding_rank = sharding_rank [0 ]
34- mp_rank = mp_rank [0 ]
35- sep_rank = sep_rank [0 ]
36-
37- self .world = GroupInfo (size = world_size , rank = device_rank , world = list (range (0 , world_size )))
34+ assert set (order ) == {"dp" , "pp" , "sharding" , "mp" , "sep" }, f"Illegal order : { order } "
35+ self .order = order
3836
39- sep_world = arr [dp_rank , pp_rank , sharding_rank , mp_rank , :]
40- self .sep_info = GroupInfo (size = len (sep_world ), rank = sep_rank , world = sep_world .tolist ())
37+ degree_map = {
38+ "dp" : dp_degree ,
39+ "pp" : pp_degree ,
40+ "sharding" : sharding_degree ,
41+ "mp" : mp_degree ,
42+ "sep" : sep_degree ,
43+ }
44+ shape = [degree_map [key ] for key in self .order ]
4145
42- mp_world = arr [ dp_rank , pp_rank , sharding_rank , :, sep_rank ]
43- self . mp_info = GroupInfo ( size = len ( mp_world ), rank = mp_rank , world = mp_world . tolist ())
46+ arr = np . arange ( 0 , dp_degree * pp_degree * sharding_degree * mp_degree * sep_degree ). reshape ( shape )
47+ ranks = [ rank [ 0 ] for rank in np . where ( arr == device_rank )]
4448
45- sharding_world = arr [dp_rank , pp_rank , :, mp_rank , sep_rank ]
46- self .sharding_info = GroupInfo (size = len (sharding_world ), rank = sharding_rank , world = sharding_world .tolist ())
47-
48- pp_world = arr [dp_rank , :, sharding_rank , mp_rank , sep_rank ]
49- self .pp_info = GroupInfo (size = len (pp_world ), rank = pp_rank , world = pp_world .tolist ())
50-
51- dp_world = arr [:, pp_rank , sharding_rank , mp_rank , sep_rank ]
52- self .dp_info = GroupInfo (size = len (dp_world ), rank = dp_rank , world = dp_world .tolist ())
49+ self .world = GroupInfo (size = world_size , rank = device_rank , world = list (range (0 , world_size )))
50+ worlds = []
51+ for i in range (len (ranks )):
52+ indexs = tuple (ranks [:i ] + [slice (None )] + ranks [(i + 1 ) :])
53+ worlds .append (arr [indexs ])
54+
55+ for i , key in enumerate (self .order ):
56+ if key == "dp" :
57+ self .dp_info = GroupInfo (size = len (worlds [i ]), rank = ranks [i ], world = worlds [i ].tolist ())
58+ elif key == "pp" :
59+ self .pp_info = GroupInfo (size = len (worlds [i ]), rank = ranks [i ], world = worlds [i ].tolist ())
60+ elif key == "sharding" :
61+ self .sharding_info = GroupInfo (size = len (worlds [i ]), rank = ranks [i ], world = worlds [i ].tolist ())
62+ elif key == "mp" :
63+ self .mp_info = GroupInfo (size = len (worlds [i ]), rank = ranks [i ], world = worlds [i ].tolist ())
64+ elif key == "sep" :
65+ self .sep_info = GroupInfo (size = len (worlds [i ]), rank = ranks [i ], world = worlds [i ].tolist ())
5366
5467 self .is_last = self .pp_info .rank == self .pp_info .size - 1
5568
5669 data_arr = np .arange (0 , dp_degree * sharding_degree ).reshape ([dp_degree , sharding_degree ])
57- data_arr = np . expand_dims ( data_arr , axis = 1 ). repeat ( pp_degree , axis = 1 )
58- data_arr = np . expand_dims ( data_arr , axis = 3 ). repeat ( mp_degree , axis = 3 )
59- data_arr = np .expand_dims (data_arr , axis = 4 ).repeat (sep_degree , axis = 4 )
70+ for i , key in enumerate ( self . order ):
71+ if key != "dp" and key != "sharding" :
72+ data_arr = np .expand_dims (data_arr , axis = i ).repeat (degree_map [ key ] , axis = i )
6073
6174 self .data_info = GroupInfo (
6275 size = int (self .dp_info .size * self .sharding_info .size ),
@@ -68,4 +81,4 @@ def __init__(
6881 self .data_inner_times = self .world .size // self .data_info .size
6982
7083 def __repr__ (self ):
71- return f"dp_info:\n \t { self .dp_info } , \n pp_info:\n \t { self .pp_info } , \n sharding_info:\n \t { self .sharding_info } , \n mp_info:\n \t { self .mp_info } , \n sep_info:\n \t { self .sep_info } \n data_info:\n \t { self .data_info } "
84+ return f"dp_info:\n \t { self .dp_info } , \n pp_info:\n \t { self .pp_info } , \n sharding_info:\n \t { self .sharding_info } , \n mp_info:\n \t { self .mp_info } , \n sep_info:\n \t { self .sep_info } , \n data_info:\n \t { self .data_info } , \n order: \n \t { self . order } "
0 commit comments