Commit a1378a3

[DistDataloader] Update implementation, add nested.py

1 parent edc04f3 commit a1378a3
File tree: 4 files changed (+148, -190 lines)


paddlenlp/data/dist_dataloader.py

Lines changed: 59 additions & 122 deletions
@@ -12,11 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 import paddle
 from paddle.distributed import fleet

 from paddlenlp.utils.log import logger
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_copy_place,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)

 _MAX_DATA_DIM = 64

@@ -78,10 +83,6 @@ def __init__(
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

-        # When needed other data types, we can modify dtype_list.
-        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
-        self._data_keys_list, self._data_keys_size = None, None
-
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
                 dataset,

@@ -137,127 +138,63 @@ def _init_dataloader_comm_group(self):
     def __iter__(self):
         return self

-    def __next__(self):
-        data_keys_size = [0 for i in range(len(self.dtype_list))]
-        if self._need_data:
-            data = next(self._dataloader_iter)
-            data_keys = list(data.keys())
-
-            for key in data_keys:
-                if data[key].dtype not in self.dtype_list:
-                    raise ValueError(
-                        f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
-                    )
-
-            data_list, data_keys_list = [], []
-            for i, dtype in enumerate(self.dtype_list):
-                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
-                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
-            data_keys_size = [len(keys) for keys in data_keys_list]
-
-        # Broadcast data keys size.
-        if self._data_keys_size is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_size = data_keys_size
-
-        if not self._need_data:
-            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        # Broadcast data keys name.
-        if self._data_keys_list is None:
-            if self.mp_group.nranks > 1 and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
-            if self._pp_data_group is not None:
-                paddle.distributed.broadcast_object_list(
-                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-            self._data_keys_list = data_keys_list
-
-        # Broadcast data.
-        if not self._need_data:
-            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
-
-        if self.mp_group.nranks > 1 and self.pp_rank == 0:
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+    def _broadcast_data(self, data):
+        process_rank = paddle.distributed.get_rank()
+        if self.mp_group.nranks > 1:
+            if process_rank == self.mp_src_rank:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
                     )
-
+                fake_data = [None]
         if self._pp_data_group is not None:
-            # Note(daisimng): In last stage of pp, we don't need input_ids.
-            # It will be removed in future.
-            for i, dtype in enumerate(self.dtype_list):
-                if self._data_keys_size[i] > 0:
-                    data_list[i] = broadcast_data_list(
-                        data_list[i],
-                        dtype,
-                        self.pp_rank,
-                        self._pp_data_group,
-                        self._pp_data_group.ranks[0],
+            if process_rank == self._pp_data_group.ranks[0]:
+                fake_data = [nested_reduce_tensor(data)]
+            else:
+                if data is not None:
+                    logger.warning(
+                        f"Your local rank {paddle.distributed.get_rank()} is forbidden to have a state_dict."
                     )
+                fake_data = [None]
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self.mp_src_rank,
+                group=self.mp_group,
+            )
+        if self._pp_data_group is not None:
+            paddle.distributed.broadcast_object_list(
+                fake_data,
+                src=self._pp_data_group.ranks[0],
+                group=self._pp_data_group,
+            )
+        fake_data = fake_data[0]

-        out_data = {}
-        for keys, datas in zip(self._data_keys_list, data_list):
-            out_data.update([(k, d) for k, d in zip(keys, datas)])
-
-        return out_data
-
-
-def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
-    """
-    Broadcast data from src_rank to all ranks in comm_group.
-    """
-    # Move to GPU and broadcast.
-    size_cpu = []
-    if comm_rank == 0:
-        for data in data_list:
-            size_cpu.append(len(data.shape))
-            size_cpu += data.shape
-    size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
-    size_cuda = paddle.to_tensor(size_cpu)
-    paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()
-
-    size_cpu = size_cuda.tolist()
-    i = 0
-    numel = 0
-    sizes = []
-    while size_cpu[i] > 0:
-        rank = size_cpu[i]
-        this_size = size_cpu[i + 1 : i + 1 + rank]
-        numel += int(np.prod(this_size))
-        sizes.append(this_size)
-        i += rank + 1
-
-    if comm_rank == 0:
-        assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
-            data.dtype, datatype
-        )
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
-        else:
-            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
-
-        assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
-    else:
-        if paddle.is_compiled_with_cuda():
-            data_b = paddle.empty([numel], dtype=datatype).cuda()
-        else:
-            data_b = paddle.empty([numel], dtype=datatype)
+        if self.mp_group.nranks > 1:
+            if process_rank != self.mp_src_rank:
+                data = nested_empty_tensor(fake_data)
+        if self._pp_data_group is not None:
+            if process_rank != self._pp_data_group.ranks[0]:
+                data = nested_empty_tensor(fake_data)
+        data = nested_copy_place(data, place=paddle.framework._current_expected_place())
+        if self.mp_group.nranks > 1 and self.pp_rank == 0:
+            data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
+        if self._pp_data_group is not None:
+            data = nested_broadcast_tensor(data, src=self._pp_data_group.ranks[0], group=self._pp_data_group)

-        # Broadcast
-        paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
+        if data is None:
+            raise StopIteration

-    ret = []
-    offset = 0
-    for size in sizes:
-        numel = int(np.prod(size))
-        ret.append(data_b[offset : offset + numel].reshape(size))
-        offset += numel
+        return data

-    return ret
+    def __next__(self):
+        data = None
+        if self._need_data:
+            try:
+                data = next(self._dataloader_iter)
+            except:
+                pass
+        data = self._broadcast_data(data)
+        return data
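In short, the rewritten loader drops the dtype-bucketed `broadcast_data_list` in favor of a two-phase scheme: the source rank first ships lightweight metadata (`nested_reduce_tensor` turns each tensor into a shape/dtype/name holder) through `broadcast_object_list`; the other ranks then allocate matching empty tensors with `nested_empty_tensor` and receive the payload via `nested_broadcast_tensor`. Below is a minimal sketch of that pattern for a single communication group; the `broadcast_batch` helper and its arguments are illustrative, not part of the commit.

import paddle
import paddle.distributed as dist
from paddlenlp.utils.nested import (
    nested_broadcast_tensor,
    nested_copy_place,
    nested_empty_tensor,
    nested_reduce_tensor,
)


def broadcast_batch(batch, src, group):
    # Phase 1: the source rank reduces every tensor leaf to (shape, dtype, name)
    # metadata and broadcasts it as a plain Python object.
    holder = [nested_reduce_tensor(batch) if dist.get_rank() == src else None]
    dist.broadcast_object_list(holder, src=src, group=group)

    # Phase 2: non-source ranks allocate empty tensors from that metadata,
    # every rank moves its batch to the expected device, and the payload is
    # broadcast tensor by tensor.
    if dist.get_rank() != src:
        batch = nested_empty_tensor(holder[0])
    batch = nested_copy_place(batch, place=paddle.framework._current_expected_place())
    return nested_broadcast_tensor(batch, src=src, group=group)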

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 1 addition & 20 deletions
@@ -62,6 +62,7 @@
     SAFE_WEIGHTS_NAME,
 )
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.nested import nested_copy, nested_copy_place

 if is_safetensors_available():
     from safetensors import safe_open

@@ -1876,26 +1877,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
     return new_actions


-def nested_copy(inputs):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy(inputs[key])
-        return outputs
-    return inputs
-
-
-def nested_copy_place(inputs, place=None, blocking=False):
-    if isinstance(inputs, dict):
-        outputs = {}
-        for key in list(inputs.keys()):
-            outputs[key] = nested_copy_place(inputs[key], place, blocking)
-        return outputs
-    if isinstance(inputs, paddle.Tensor):
-        inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
-    return inputs
-
-
 def flatten_list(nested_list):
     flattened_list = []
     for item in nested_list:
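The only change here is that `nested_copy` and `nested_copy_place` now come from `paddlenlp.utils.nested` instead of being defined locally; their behavior, shown in the removed lines above, is unchanged. A short usage sketch with illustrative data (not taken from the commit):

import paddle
from paddlenlp.utils.nested import nested_copy_place

# A nested state dict with a tensor leaf (illustrative data).
state_dict = {"linear": {"weight": paddle.randn([4, 4])}}

# Recursively copy every tensor to CPU; tensors already on the target
# place are returned as-is (see the removed implementation above).
cpu_state = nested_copy_place(state_dict, place=paddle.CPUPlace(), blocking=True)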

paddlenlp/trainer/utils/helper.py

Lines changed: 5 additions & 48 deletions
@@ -16,8 +16,6 @@
 # This file is modified from
 # https://github.com/huggingface/transformers/blob/main/src/transformers

-import collections
-import copy
 import os
 from typing import Any, Optional

@@ -27,6 +25,11 @@
 from paddle.distributed import fleet

 from paddlenlp.utils.log import logger
+from paddlenlp.utils.nested import (
+    nested_broadcast_tensor,
+    nested_empty_tensor,
+    nested_reduce_tensor,
+)

 __all__ = [
     "distributed_concat",

@@ -180,52 +183,6 @@ def distributed_file(filename):
     return filename


-TensorHolder = collections.namedtuple("TensorHolder", ["shape", "dtype", "name"])
-
-
-def nested_reduce_tensor(tensor):
-    if isinstance(tensor, dict):
-        # copy tensor since it will be inplace modified dict
-        tensor = copy.copy(tensor)
-        for key in list(tensor.keys()):
-            tensor[key] = nested_reduce_tensor(tensor[key])
-    if isinstance(tensor, (tuple, list)):
-        return type(tensor)(nested_reduce_tensor(t) for t in tensor)
-
-    if isinstance(tensor, paddle.Tensor):
-        return TensorHolder(tensor.shape, tensor.dtype, tensor.name)
-
-    return tensor
-
-
-def nested_empty_tensor(tensor):
-    if isinstance(tensor, dict):
-        for key in list(tensor.keys()):
-            tensor[key] = nested_empty_tensor(tensor[key])
-    if isinstance(tensor, list):
-        return type(tensor)(nested_empty_tensor(t) for t in tensor)
-
-    # TensorHolder is tuple
-    if isinstance(tensor, TensorHolder):
-        t = paddle.empty(tensor.shape, dtype=tensor.dtype, name=tensor.name)
-        t.name = tensor.name
-        return t
-
-    return tensor
-
-
-def nested_broadcast_tensor(tensor, src=0, group=None):
-    if isinstance(tensor, dict):
-        for key in list(tensor.keys()):
-            tensor[key] = nested_broadcast_tensor(tensor[key], src=src, group=group)
-    if isinstance(tensor, list):
-        return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)
-
-    if isinstance(tensor, paddle.Tensor):
-        paddle.distributed.broadcast(tensor, src=src, group=group, sync_op=True)
-    return tensor
-
-
 def broadcast_dp_optimizer(state_dict):
     if paddle.distributed.get_world_size() <= 1:
         return state_dict
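The fourth file in this commit, paddlenlp/utils/nested.py, is not rendered above. Judging from the imports added in the three diffs and the helpers deleted from helper.py and unified_checkpoint.py, it presumably consolidates those functions into one module. A rough reconstruction, not the literal new file:

import collections
import copy

import paddle

TensorHolder = collections.namedtuple("TensorHolder", ["shape", "dtype", "name"])


def nested_reduce_tensor(tensor):
    # Replace every tensor in a nested dict/list/tuple with a lightweight
    # (shape, dtype, name) holder that can travel via broadcast_object_list.
    if isinstance(tensor, dict):
        tensor = copy.copy(tensor)  # avoid mutating the caller's dict in place
        for key in list(tensor.keys()):
            tensor[key] = nested_reduce_tensor(tensor[key])
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(nested_reduce_tensor(t) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        return TensorHolder(tensor.shape, tensor.dtype, tensor.name)
    return tensor


def nested_empty_tensor(tensor):
    # Materialize empty tensors from TensorHolder metadata on receiving ranks.
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_empty_tensor(tensor[key])
    if isinstance(tensor, list):
        return type(tensor)(nested_empty_tensor(t) for t in tensor)
    if isinstance(tensor, TensorHolder):
        t = paddle.empty(tensor.shape, dtype=tensor.dtype, name=tensor.name)
        t.name = tensor.name
        return t
    return tensor


def nested_broadcast_tensor(tensor, src=0, group=None):
    # Broadcast every tensor leaf from src to the other ranks in group.
    if isinstance(tensor, dict):
        for key in list(tensor.keys()):
            tensor[key] = nested_broadcast_tensor(tensor[key], src=src, group=group)
    if isinstance(tensor, list):
        return type(tensor)(nested_broadcast_tensor(t, src=src, group=group) for t in tensor)
    if isinstance(tensor, paddle.Tensor):
        paddle.distributed.broadcast(tensor, src=src, group=group, sync_op=True)
    return tensor


def nested_copy(inputs):
    # Rebuild the nested dict structure; non-dict leaves are returned as-is.
    if isinstance(inputs, dict):
        return {key: nested_copy(inputs[key]) for key in list(inputs.keys())}
    return inputs


def nested_copy_place(inputs, place=None, blocking=False):
    # Copy every tensor leaf to the requested place, skipping tensors already there.
    if isinstance(inputs, dict):
        return {key: nested_copy_place(inputs[key], place, blocking) for key in list(inputs.keys())}
    if isinstance(inputs, paddle.Tensor):
        inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
    return inputs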
