# See the License for the specific language governing permissions and
# limitations under the License.

+import numpy as np
import paddle
from paddle.distributed import fleet

from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import (
-    nested_broadcast_tensor,
-    nested_copy_place,
-    nested_empty_tensor,
-    nested_reduce_tensor,
-)
+
+_MAX_DATA_DIM = 64


class DummyDataset(paddle.io.Dataset):
@@ -71,10 +68,8 @@ def __init__(
        # Init pp data comm group.
        if self._hcg.get_pipe_parallel_world_size() > 1:
            self._pp_data_group = self._init_dataloader_comm_group()
-            self._pp_group = self._hcg.get_pipe_parallel_group()
        else:
            self._pp_data_group = None
-            self._pp_group = None

        self.mp_group = self._hcg.get_model_parallel_group()
        self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -85,6 +80,10 @@ def __init__(
        sharding_rank = self._hcg.get_sharding_parallel_rank()
        self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

+        # If other data types are needed, extend dtype_list.
+        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
+        self._data_keys_list, self._data_keys_size = None, None
+
        if self._need_data:
            self._dataloader = paddle.io.DataLoader(
                dataset,
@@ -130,7 +129,11 @@ def _init_dataloader_comm_group(self):
        parallel_groups = topo.get_comm_list("pipe")

        for group in parallel_groups:
-            ranks = [group[0], group[-1]]
+            if not self.eval:
+                # For training, only the first and last pipeline ranks need data.
+                ranks = [group[0], group[-1]]
+            else:
+                ranks = group
            comm_group = paddle.distributed.new_group(ranks=ranks)
            if paddle.distributed.get_rank() in ranks:
                parallel_comm_group = comm_group
@@ -139,70 +142,138 @@ def _init_dataloader_comm_group(self):
    def __iter__(self):
        return self

-    def _broadcast_data(self, data):
-        process_rank = paddle.distributed.get_rank()
-        if self.mp_group.nranks > 1:
-            if process_rank == self.mp_src_rank:
-                fake_data = [nested_reduce_tensor(data)]
-            else:
-                if data is not None:
-                    logger.warning(
-                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
-                    )
-                fake_data = [None]
-        if self._pp_group is not None:
-            if process_rank == self._pp_group.ranks[0]:
-                fake_data = [nested_reduce_tensor(data)]
-            else:
-                if data is not None:
-                    logger.warning(
-                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
-                    )
-                fake_data = [None]
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            paddle.distributed.broadcast_object_list(
-                fake_data,
-                src=self.mp_src_rank,
-                group=self.mp_group,
-            )
-        if self._pp_group is not None:
-            paddle.distributed.broadcast_object_list(
-                fake_data,
-                src=self._pp_group.ranks[0],
-                group=self._pp_group,
-            )
-        else:
-            fake_data = [None]
+    def __next__(self):
+        data_keys_size = [0 for i in range(len(self.dtype_list))]
+        stop_flag = False
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+            except:
+                stop_flag = True

-        fake_data = fake_data[0]
-        if fake_data is None:
+            if not stop_flag:
+                data_keys = list(data.keys())
+
+                for key in data_keys:
+                    if data[key].dtype not in self.dtype_list:
+                        raise ValueError(
+                            f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
+                        )
+
+                data_list, data_keys_list = [], []
+                for i, dtype in enumerate(self.dtype_list):
+                    data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
+                    data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
+                data_keys_size = [len(keys) for keys in data_keys_list]
+
+        stop_flag = paddle.to_tensor([stop_flag], dtype="bool")
+        paddle.distributed.all_reduce(stop_flag, op=paddle.distributed.ReduceOp.MAX)
+        if stop_flag.item():
            raise StopIteration

-        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
-        if self.mp_group.nranks > 1:
-            if process_rank != self.mp_src_rank:
-                data = nested_empty_tensor(fake_data)
-        if dst_pp_group is not None:
-            if process_rank != dst_pp_group.ranks[0]:
-                data = nested_empty_tensor(fake_data)
+        # Broadcast data keys size.
+        if self._data_keys_size is None:
+            if self.mp_group.nranks > 1 and self.pp_rank == 0:
+                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
+            if self._pp_data_group is not None:
+                paddle.distributed.broadcast_object_list(
+                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                )
+            self._data_keys_size = data_keys_size
+
+        if not self._need_data:
+            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
+
+        # Broadcast data keys name.
+        if self._data_keys_list is None:
+            if self.mp_group.nranks > 1 and self.pp_rank == 0:
+                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
+            if self._pp_data_group is not None:
+                paddle.distributed.broadcast_object_list(
+                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                )
+            self._data_keys_list = data_keys_list
+
+        # Broadcast data.
+        if not self._need_data:
+            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]

        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
-        if dst_pp_group is not None:
-            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
-        # for pp1 - pp_{n-1}, Paddle need to recevie empty dict for pipeline parallel.
-        if data is None:
-            data = {}
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                    )

-        return data
+        if self._pp_data_group is not None:
+            # Note(daisimng): In the last pp stage, input_ids is not needed.
+            # It will be removed in the future.
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i],
+                        dtype,
+                        self.pp_rank,
+                        self._pp_data_group,
+                        self._pp_data_group.ranks[0],
+                    )

-    def __next__(self):
-        data = None
-        if self._need_data:
-            try:
-                data = next(self._dataloader_iter)
-                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
-            except:
-                pass
-        data = self._broadcast_data(data)
-        return data
+        out_data = {}
+        for keys, datas in zip(self._data_keys_list, data_list):
+            out_data.update([(k, d) for k, d in zip(keys, datas)])
+
+        return out_data
+
+
+def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
+    """
+    Broadcast data from src_rank to all ranks in comm_group.
+    """
+    # Move to GPU and broadcast.
+    size_cpu = []
+    if comm_rank == 0:
+        for data in data_list:
+            size_cpu.append(len(data.shape))
+            size_cpu += data.shape
+    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
+    size_cuda = paddle.to_tensor(size_cpu)
+    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
+
+    size_cpu = size_cuda.tolist()
+    i = 0
+    numel = 0
+    sizes = []
+    while size_cpu[i] > 0:
+        rank = size_cpu[i]
+        this_size = size_cpu[i + 1 : i + 1 + rank]
+        numel += int(np.prod(this_size))
+        sizes.append(this_size)
+        i += rank + 1
+
+    if comm_rank == 0:
+        assert data.dtype == datatype, "input has data type {} which is different than {}".format(
+            data.dtype, datatype
+        )
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+        else:
+            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+
+        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
+    else:
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.empty([numel], dtype=datatype).cuda()
+        else:
+            data_b = paddle.empty([numel], dtype=datatype)
+
+    # Broadcast the flattened data.
+    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+
+    ret = []
+    offset = 0
+    for size in sizes:
+        numel = int(np.prod(size))
+        ret.append(data_b[offset : offset + numel].reshape(size))
+        offset += numel
+
+    return ret
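
A minimal single-process sketch of the packing scheme broadcast_data_list relies on: the source rank writes [ndim, *shape] for each tensor into a fixed header of _MAX_DATA_DIM slots, flattens and concatenates the tensors into one buffer, and every receiver splits the buffer back using that header. The two distributed broadcast calls are omitted here and the tensor shapes are invented, so this only illustrates the pack/unpack halves, not the communication itself.

import numpy as np
import paddle

_MAX_DATA_DIM = 64

# Hypothetical tensors standing in for one dtype bucket of a batch.
tensors = [paddle.ones([2, 3], dtype="float32"), paddle.zeros([4], dtype="float32")]

# Sender side: pack [ndim, *shape] per tensor, pad the header, flatten and concat.
header = []
for t in tensors:
    header.append(len(t.shape))
    header += t.shape
header += [0] * (_MAX_DATA_DIM - len(header))
flat = paddle.concat([t.reshape([-1]) for t in tensors], 0)
# (In broadcast_data_list, `header` and `flat` are what actually get broadcast.)

# Receiver side: recover the shapes from the header, then split the flat buffer.
sizes, i = [], 0
while header[i] > 0:
    ndim = header[i]
    sizes.append(header[i + 1 : i + 1 + ndim])
    i += ndim + 1

ret, offset = [], 0
for size in sizes:
    n = int(np.prod(size))
    ret.append(flat[offset : offset + n].reshape(size))
    offset += n

assert [tuple(t.shape) for t in ret] == [(2, 3), (4,)]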
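
And a rough sketch of the regrouping the new __next__ performs around those broadcasts: tensors from the batch dict are bucketed by dtype (each non-empty bucket would go through broadcast_data_list on the communication groups), and the keyed dict is rebuilt from the key lists on the receiving side. The keys and shapes below are made up for illustration, and everything runs in one process.

import paddle

dtype_list = [paddle.int64, paddle.float32, paddle.int32]

# Hypothetical batch; real keys depend on the dataset and collator.
batch = {
    "input_ids": paddle.ones([2, 8], dtype="int64"),
    "attention_mask": paddle.ones([2, 8], dtype="float32"),
    "labels": paddle.ones([2, 8], dtype="int64"),
}

# Bucket tensors and their keys by dtype, as __next__ does before broadcasting.
data_list, data_keys_list = [], []
for dtype in dtype_list:
    data_list.append([v for v in batch.values() if v.dtype == dtype])
    data_keys_list.append([k for k, v in batch.items() if v.dtype == dtype])

# (Each non-empty data_list[i] would be broadcast with broadcast_data_list here.)

# Rebuild the keyed dict from the per-dtype buckets, mirroring out_data in __next__.
out_data = {}
for keys, datas in zip(data_keys_list, data_list):
    out_data.update(zip(keys, datas))

assert set(out_data) == set(batch)

Caching self._data_keys_size and self._data_keys_list after the first batch means the comparatively slow broadcast_object_list calls happen only once, while the per-batch tensor payload always takes the dense broadcast_data_list path.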