Commit dabda13

add second try
1 parent 84b4bf7 commit dabda13

File tree

1 file changed: +140 -69 lines changed

paddlenlp/data/dist_dataloader.py

Lines changed: 140 additions & 69 deletions
@@ -12,16 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import paddle
 from paddle.distributed import fleet
 
 from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import (
-    nested_broadcast_tensor,
-    nested_copy_place,
-    nested_empty_tensor,
-    nested_reduce_tensor,
-)
+
+_MAX_DATA_DIM = 64
 
 
 class DummyDataset(paddle.io.Dataset):
@@ -71,10 +68,8 @@ def __init__(
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
-            self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
             self._pp_data_group = None
-            self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -85,6 +80,10 @@ def __init__(
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)
 
+        # When needed other data types, we can modify dtype_list.
+        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
+        self._data_keys_list, self._data_keys_size = None, None
+
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,
@@ -130,7 +129,11 @@ def _init_dataloader_comm_group(self):
         parallel_groups = topo.get_comm_list("pipe")
 
         for group in parallel_groups:
-            ranks = [group[0], group[-1]]
+            if not self.eval:
+                # only first rank and last rank
+                ranks = [group[0], group[-1]]
+            else:
+                ranks = group
             comm_group = paddle.distributed.new_group(ranks=ranks)
             if paddle.distributed.get_rank() in ranks:
                 parallel_comm_group = comm_group
@@ -139,70 +142,138 @@ def _init_dataloader_comm_group(self):
     def __iter__(self):
         return self
 
-    def _broadcast_data(self, data):
-        process_rank = paddle.distributed.get_rank()
-        if self.mp_group.nranks > 1:
-            if process_rank == self.mp_src_rank:
-                fake_data = [nested_reduce_tensor(data)]
-            else:
-                if data is not None:
-                    logger.warning(
-                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
-                    )
-                fake_data = [None]
-        if self._pp_group is not None:
-            if process_rank == self._pp_group.ranks[0]:
-                fake_data = [nested_reduce_tensor(data)]
-            else:
-                if data is not None:
-                    logger.warning(
-                        f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
-                    )
-                fake_data = [None]
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            paddle.distributed.broadcast_object_list(
-                fake_data,
-                src=self.mp_src_rank,
-                group=self.mp_group,
-            )
-        if self._pp_group is not None:
-            paddle.distributed.broadcast_object_list(
-                fake_data,
-                src=self._pp_group.ranks[0],
-                group=self._pp_group,
-            )
-        else:
-            fake_data = [None]
+    def __next__(self):
+        data_keys_size = [0 for i in range(len(self.dtype_list))]
+        stop_flag = False
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+            except:
+                stop_flag = True
 
-        fake_data = fake_data[0]
-        if fake_data is None:
+        if not stop_flag:
+            data_keys = list(data.keys())
+
+            for key in data_keys:
+                if data[key].dtype not in self.dtype_list:
+                    raise ValueError(
+                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
+                    )
+
+            data_list, data_keys_list = [], []
+            for i, dtype in enumerate(self.dtype_list):
+                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
+                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
+            data_keys_size = [len(keys) for keys in data_keys_list]
+
+        stop_flag = paddle.to_tensor([stop_flag], dtype="bool")
+        paddle.distributed.all_reduce(stop_flag, op=paddle.distributed.ReduceOp.MAX)
+        if stop_flag.item():
             raise StopIteration
 
-        dst_pp_group = self._pp_group if self.eval else self._pp_data_group
-        if self.mp_group.nranks > 1:
-            if process_rank != self.mp_src_rank:
-                data = nested_empty_tensor(fake_data)
-        if dst_pp_group is not None:
-            if process_rank != dst_pp_group.ranks[0]:
-                data = nested_empty_tensor(fake_data)
+        # Broadcast data keys size.
+        if self._data_keys_size is None:
+            if self.mp_group.nranks > 1 and self.pp_rank == 0:
+                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
+            if self._pp_data_group is not None:
+                paddle.distributed.broadcast_object_list(
+                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                )
+            self._data_keys_size = data_keys_size
+
+        if not self._need_data:
+            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
+
+        # Broadcast data keys name.
+        if self._data_keys_list is None:
+            if self.mp_group.nranks > 1 and self.pp_rank == 0:
+                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
+            if self._pp_data_group is not None:
+                paddle.distributed.broadcast_object_list(
+                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                )
+            self._data_keys_list = data_keys_list
+
+        # Broadcast data.
+        if not self._need_data:
+            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
 
         if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
-        if dst_pp_group is not None:
-            data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
-        # for pp1 - pp_{n-1}, Paddle need to recevie empty dict for pipeline parallel.
-        if data is None:
-            data = {}
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                    )
 
-        return data
+        if self._pp_data_group is not None:
+            # Note(daisimng): In last stage of pp, we don't need input_ids.
+            # It will be removed in future.
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i],
+                        dtype,
+                        self.pp_rank,
+                        self._pp_data_group,
+                        self._pp_data_group.ranks[0],
+                    )
 
-    def __next__(self):
-        data = None
-        if self._need_data:
-            try:
-                data = next(self._dataloader_iter)
-                data = nested_copy_place(data, place=paddle.framework._current_expected_place())
-            except:
-                pass
-        data = self._broadcast_data(data)
-        return data
+        out_data = {}
+        for keys, datas in zip(self._data_keys_list, data_list):
+            out_data.update([(k, d) for k, d in zip(keys, datas)])
+
+        return out_data
+
+
+def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
+    """
+    Broadcast data from src_rank to all ranks in comm_group.
+    """
+    # Move to GPU and broadcast.
+    size_cpu = []
+    if comm_rank == 0:
+        for data in data_list:
+            size_cpu.append(len(data.shape))
+            size_cpu += data.shape
+    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
+    size_cuda = paddle.to_tensor(size_cpu)
+    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
+
+    size_cpu = size_cuda.tolist()
+    i = 0
+    numel = 0
+    sizes = []
+    while size_cpu[i] > 0:
+        rank = size_cpu[i]
+        this_size = size_cpu[i + 1 : i + 1 + rank]
+        numel += int(np.prod(this_size))
+        sizes.append(this_size)
+        i += rank + 1
+
+    if comm_rank == 0:
+        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
+            data.dtype, datatype
+        )
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+        else:
+            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+
+        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
+    else:
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.empty([numel], dtype=datatype).cuda()
+        else:
+            data_b = paddle.empty([numel], dtype=datatype)
+
+    # Broadcast
+    paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+
+    ret = []
+    offset = 0
+    for size in sizes:
+        numel = int(np.prod(size))
+        ret.append(data_b[offset : offset + numel].reshape(size))
+        offset += numel
+
+    return ret
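The new __next__ above regroups the batch by dtype before broadcasting: tensors and their keys are split into one bucket per entry of dtype_list, each non-empty bucket is broadcast with broadcast_data_list, and receiving ranks rebuild the original dict by zipping the cached key lists with the received tensors. Below is a minimal single-process sketch of that regrouping and reassembly; it makes no distributed calls, and the batch contents are made-up examples rather than anything from this commit.

import paddle

dtype_list = [paddle.int64, paddle.float32, paddle.int32]
batch = {
    "input_ids": paddle.to_tensor([[1, 2, 3]], dtype="int64"),
    "labels": paddle.to_tensor([[1, 2, 3]], dtype="int64"),
    "attention_mask": paddle.to_tensor([[1.0, 1.0, 0.0]], dtype="float32"),
}

# One bucket of tensors (and of keys) per dtype, mirroring data_list and
# data_keys_list in __next__.
data_list = [[v for v in batch.values() if v.dtype == dt] for dt in dtype_list]
keys_list = [[k for k, v in batch.items() if v.dtype == dt] for dt in dtype_list]

# After each bucket has been broadcast, receiving ranks rebuild the mapping.
out_data = {}
for keys, datas in zip(keys_list, data_list):
    out_data.update(zip(keys, datas))
assert set(out_data) == set(batch)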

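broadcast_data_list avoids broadcast_object_list for the tensors themselves: it first broadcasts a fixed-width shape header of length _MAX_DATA_DIM, where each tensor contributes its ndim followed by its dims and the rest is zero-padded, then broadcasts one flattened buffer per dtype and slices it back into the original shapes. A minimal single-process sketch of just the header encode/decode follows; the function names are illustrative and not part of the commit.

import numpy as np

_MAX_DATA_DIM = 64  # same fixed header width as in the diff


def encode_sizes(shapes):
    # [ndim, dim0, dim1, ...] per tensor, zero-padded to _MAX_DATA_DIM.
    header = []
    for shape in shapes:
        header.append(len(shape))
        header += list(shape)
    return header + [0] * (_MAX_DATA_DIM - len(header))


def decode_sizes(header):
    # Mirrors the while-loop in broadcast_data_list: read ndim, then that many dims.
    sizes, i = [], 0
    while header[i] > 0:
        ndim = header[i]
        sizes.append(header[i + 1 : i + 1 + ndim])
        i += ndim + 1
    return sizes


shapes = [[8, 1024], [8, 1024, 1024]]
assert decode_sizes(encode_sizes(shapes)) == shapes
assert sum(int(np.prod(s)) for s in shapes) == 8 * 1024 + 8 * 1024 * 1024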