|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import numpy as np |
16 | 15 | import paddle |
17 | 16 | from paddle.distributed import fleet |
18 | 17 |
|
19 | 18 | from paddlenlp.utils.log import logger |
| 19 | +from paddlenlp.utils.nested import ( |
| 20 | + nested_broadcast_tensor, |
| 21 | + nested_copy_place, |
| 22 | + nested_empty_tensor, |
| 23 | + nested_reduce_tensor, |
| 24 | +) |
20 | 25 |
|
21 | 26 | _MAX_DATA_DIM = 64 |
22 | 27 |
|
@@ -78,10 +83,6 @@ def __init__( |
78 | 83 | sharding_rank = self._hcg.get_sharding_parallel_rank() |
79 | 84 | self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0) |
80 | 85 |
|
81 | | - # When needed other data types, we can modify dtype_list. |
82 | | - self.dtype_list = [paddle.int64, paddle.float32, paddle.int32] |
83 | | - self._data_keys_list, self._data_keys_size = None, None |
84 | | - |
85 | 86 | if self._need_data: |
86 | 87 | self._dataloader = paddle.io.DataLoader( |
87 | 88 | dataset, |
@@ -137,127 +138,63 @@ def _init_dataloader_comm_group(self): |
137 | 138 | def __iter__(self): |
138 | 139 | return self |
139 | 140 |
|
140 | | - def __next__(self): |
141 | | - data_keys_size = [0 for i in range(len(self.dtype_list))] |
142 | | - if self._need_data: |
143 | | - data = next(self._dataloader_iter) |
144 | | - data_keys = list(data.keys()) |
145 | | - |
146 | | - for key in data_keys: |
147 | | - if data[key].dtype not in self.dtype_list: |
148 | | - raise ValueError( |
149 | | - f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}" |
150 | | - ) |
151 | | - |
152 | | - data_list, data_keys_list = [], [] |
153 | | - for i, dtype in enumerate(self.dtype_list): |
154 | | - data_list.append([data[key] for key in data_keys if data[key].dtype == dtype]) |
155 | | - data_keys_list.append([key for key in data_keys if data[key].dtype == dtype]) |
156 | | - data_keys_size = [len(keys) for keys in data_keys_list] |
157 | | - |
158 | | - # Broadcast data keys size. |
159 | | - if self._data_keys_size is None: |
160 | | - if self.mp_group.nranks > 1 and self.pp_rank == 0: |
161 | | - paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group) |
162 | | - if self._pp_data_group is not None: |
163 | | - paddle.distributed.broadcast_object_list( |
164 | | - data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group |
165 | | - ) |
166 | | - self._data_keys_size = data_keys_size |
167 | | - |
168 | | - if not self._need_data: |
169 | | - data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size] |
170 | | - |
171 | | - # Broadcast data keys name. |
172 | | - if self._data_keys_list is None: |
173 | | - if self.mp_group.nranks > 1 and self.pp_rank == 0: |
174 | | - paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group) |
175 | | - if self._pp_data_group is not None: |
176 | | - paddle.distributed.broadcast_object_list( |
177 | | - data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group |
178 | | - ) |
179 | | - self._data_keys_list = data_keys_list |
180 | | - |
181 | | - # Broadcast data. |
182 | | - if not self._need_data: |
183 | | - data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size] |
184 | | - |
185 | | - if self.mp_group.nranks > 1 and self.pp_rank == 0: |
186 | | - for i, dtype in enumerate(self.dtype_list): |
187 | | - if self._data_keys_size[i] > 0: |
188 | | - data_list[i] = broadcast_data_list( |
189 | | - data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank |
| 141 | + def _broadcast_data(self, data): |
| 142 | + process_rank = paddle.distributed.get_rank() |
| 143 | + if self.mp_group.nranks > 1: |
| 144 | + if process_rank == self.mp_src_rank: |
| 145 | + fake_data = [nested_reduce_tensor(data)] |
| 146 | + else: |
| 147 | + if data is not None: |
| 148 | + logger.warning( |
| 149 | + f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict." |
190 | 150 | ) |
191 | | - |
| 151 | + fake_data = [None] |
192 | 152 | if self._pp_data_group is not None: |
193 | | - # Note(daisimng): In last stage of pp, we don't need input_ids. |
194 | | - # It will be removed in future. |
195 | | - for i, dtype in enumerate(self.dtype_list): |
196 | | - if self._data_keys_size[i] > 0: |
197 | | - data_list[i] = broadcast_data_list( |
198 | | - data_list[i], |
199 | | - dtype, |
200 | | - self.pp_rank, |
201 | | - self._pp_data_group, |
202 | | - self._pp_data_group.ranks[0], |
| 153 | + if process_rank == self._pp_data_group.ranks[0]: |
| 154 | + fake_data = [nested_reduce_tensor(data)] |
| 155 | + else: |
| 156 | + if data is not None: |
| 157 | + logger.warning( |
| 158 | + f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict." |
203 | 159 | ) |
| 160 | + fake_data = [None] |
| 161 | + if self.mp_group.nranks > 1 and self.pp_rank == 0: |
| 162 | + paddle.distributed.broadcast_object_list( |
| 163 | + fake_data, |
| 164 | + src=self.mp_src_rank, |
| 165 | + group=self.mp_group, |
| 166 | + ) |
| 167 | + if self._pp_data_group is not None: |
| 168 | + paddle.distributed.broadcast_object_list( |
| 169 | + fake_data, |
| 170 | + src=self._pp_data_group.ranks[0], |
| 171 | + group=self._pp_data_group, |
| 172 | + ) |
| 173 | + fake_data = fake_data[0] |
204 | 174 |
|
205 | | - out_data = {} |
206 | | - for keys, datas in zip(self._data_keys_list, data_list): |
207 | | - out_data.update([(k, d) for k, d in zip(keys, datas)]) |
208 | | - |
209 | | - return out_data |
210 | | - |
211 | | - |
212 | | -def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0): |
213 | | - """ |
214 | | - Broadcast data from src_rank to all ranks in comm_group. |
215 | | - """ |
216 | | - # Move to GPU and broadcast. |
217 | | - size_cpu = [] |
218 | | - if comm_rank == 0: |
219 | | - for data in data_list: |
220 | | - size_cpu.append(len(data.shape)) |
221 | | - size_cpu += data.shape |
222 | | - size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu)) |
223 | | - size_cuda = paddle.to_tensor(size_cpu) |
224 | | - paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait() |
225 | | - |
226 | | - size_cpu = size_cuda.tolist() |
227 | | - i = 0 |
228 | | - numel = 0 |
229 | | - sizes = [] |
230 | | - while size_cpu[i] > 0: |
231 | | - rank = size_cpu[i] |
232 | | - this_size = size_cpu[i + 1 : i + 1 + rank] |
233 | | - numel += int(np.prod(this_size)) |
234 | | - sizes.append(this_size) |
235 | | - i += rank + 1 |
236 | | - |
237 | | - if comm_rank == 0: |
238 | | - assert data.dtype == datatype, "input has data type {} which " "is different than {}".format( |
239 | | - data.dtype, datatype |
240 | | - ) |
241 | | - if paddle.is_compiled_with_cuda(): |
242 | | - data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0) |
243 | | - else: |
244 | | - data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0) |
245 | | - |
246 | | - assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list]) |
247 | | - else: |
248 | | - if paddle.is_compiled_with_cuda(): |
249 | | - data_b = paddle.empty([numel], dtype=datatype).cuda() |
250 | | - else: |
251 | | - data_b = paddle.empty([numel], dtype=datatype) |
| 175 | + if self.mp_group.nranks > 1: |
| 176 | + if process_rank != self.mp_src_rank: |
| 177 | + data = nested_empty_tensor(fake_data) |
| 178 | + if self._pp_data_group is not None: |
| 179 | + if process_rank != self._pp_data_group.ranks[0]: |
| 180 | + data = nested_empty_tensor(fake_data) |
| 181 | + data = nested_copy_place(data, place=paddle.framework._current_expected_place()) |
| 182 | + if self.mp_group.nranks > 1 and self.pp_rank == 0: |
| 183 | + data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group) |
| 184 | + if self._pp_data_group is not None: |
| 185 | + data = nested_broadcast_tensor(data, src=self._pp_data_group.ranks[0], group=self._pp_data_group) |
252 | 186 |
|
253 | | - # Broadcast |
254 | | - paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait() |
| 187 | + if data is None: |
| 188 | + raise StopIteration |
255 | 189 |
|
256 | | - ret = [] |
257 | | - offset = 0 |
258 | | - for size in sizes: |
259 | | - numel = int(np.prod(size)) |
260 | | - ret.append(data_b[offset : offset + numel].reshape(size)) |
261 | | - offset += numel |
| 190 | + return data |
262 | 191 |
|
263 | | - return ret |
| 192 | + def __next__(self): |
| 193 | + data = None |
| 194 | + if self._need_data: |
| 195 | + try: |
| 196 | + data = next(self._dataloader_iter) |
| 197 | + except StopIteration: |
| 198 | + pass |
| 199 | + data = self._broadcast_data(data) |
| 200 | + return data |
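Taken together, the refactor replaces the dtype-bucketed flat broadcast (`broadcast_data_list`) with a two-phase scheme: the source rank first ships shape/dtype metadata with `broadcast_object_list`, then every rank allocates matching empty tensors and joins a nested tensor broadcast, and all ranks stop together when the source rank runs out of data. Below is a hypothetical launch-time usage sketch; the class name `DistDataLoader`, its import path, and the constructor arguments beyond `dataset` are assumptions not shown in this diff, and the script is meant to be run with `python -m paddle.distributed.launch` on enough devices for the configured degrees.

```python
import paddle
from paddle.distributed import fleet
from paddlenlp.data import DistDataLoader  # assumed import path


class ToyDataset(paddle.io.Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"input_ids": paddle.to_tensor([idx] * 8, dtype="int64")}


strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 2, "sharding_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# Only the rank with mp_rank == 0 and pp_rank == 0 actually reads the dataset;
# every other rank receives the nested batch through _broadcast_data().
loader = DistDataLoader(ToyDataset(), batch_size=4)  # batch_size kwarg is an assumption
for batch in loader:
    print(paddle.distributed.get_rank(), batch["input_ids"].shape)
```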