 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 import paddle
 from paddle.distributed import fleet

 from paddlenlp.utils.log import logger
-
-_MAX_DATA_DIM = 64
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_copy_place,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)


 class DummyDataset(paddle.io.Dataset):
@@ -53,6 +56,7 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
+        eval=False,
     ):

         if dataset is None:
@@ -62,12 +66,15 @@ def __init__(
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)

         self._hcg = fleet.get_hybrid_communicate_group()
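+        # The eval flag makes _broadcast_data send batches to the full pipe group so every pp rank sees the data during evaluation.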
+        self.eval = eval

         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
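+            # _pp_data_group links only the first and last pp ranks; _pp_group spans every pp rank and is used for eval broadcasts.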
+            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
+            self._pp_group = None

         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -78,10 +85,6 @@ def __init__(
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

-        # When needed other data types, we can modify dtype_list.
-        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
-        self._data_keys_list, self._data_keys_size = None, None
-
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,
@@ -127,7 +130,6 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")

         for group in parallel_groups:
-            # only first rank and last rank
             ranks = [group[0], group[-1]]
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
@@ -137,127 +139,68 @@ def _init_dataloader_comm_group(self):
     def __iter__(self):
         return self

-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
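+        # Rank 0 of each group wraps the batch with nested_reduce_tensor; the result ("fake_data") is shared via
+        # broadcast_object_list below and later used by nested_empty_tensor on the receiving ranks.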
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
150+ f"Your local rank { paddle .distributed .get_rank ()} are forbidden to have a state_dict."
                     )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                fake_data = [None]
+        if self._pp_group is not None:
+            if process_rank == self._pp_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
159+ f"Your local rank { paddle .distributed .get_rank ()} are forbidden to have a state_dict."
                     )
+                fake_data = [None]
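+        # Share the batch description with the other mp ranks and, when pipeline parallel is on, with the rest of the pipe group.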
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_group.ranks[0],
+                group=self._pp_group,
+            )

-        if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
-                    )
-
-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
-        else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+        fake_data = fake_data[0]
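+        # A None description means the source rank produced no batch, so every rank stops iterating together.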
+        if fake_data is None:
+            raise StopIteration

-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
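+        # Training only feeds the first and last pp ranks (_pp_data_group); eval broadcasts to the whole pipe group.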
+        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
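+        # Receiving ranks allocate empty tensors that match the broadcast description before the tensor broadcast below.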
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if dst_pp_group is not None:
+            if process_rank != dst_pp_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)

-    # Broadcast
-    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if dst_pp_group is not None:
+            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
+        # For pp ranks 1 .. n-1, Paddle needs to receive an empty dict for pipeline parallel.
+        if data is None:
+            data = {}

-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
+        return data

-    return ret
+    def __next__(self):
+        data = None
+        if self._need_data:
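+            # If the wrapped iterator is exhausted or fails, data stays None and _broadcast_data raises StopIteration on every rank.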
+            try:
+                data = next(self._dataloader_iter)
+                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+            except:
+                pass
+        data = self._broadcast_data(data)
+        return data