|
18 | 18 |
|
19 | 19 | import collections |
20 | 20 | import contextlib |
21 | | -import copy |
22 | 21 | import inspect |
23 | 22 | import math |
24 | | -import multiprocessing |
25 | 23 | import os |
26 | 24 | import random |
27 | 25 | import re |
|
148 | 146 | ) |
149 | 147 | from .training_args import TrainingArguments |
150 | 148 | from .utils import reshard as reshard_util |
| 149 | +from .utils.async_save import AsyncSaver |
151 | 150 | from .utils.helper import ( # nested_truncate, |
152 | 151 | broadcast_dp_optimizer, |
153 | 152 | broadcast_moe_optimizer, |
|
190 | 189 |
|
191 | 190 | __all__ = ["Trainer"] |
192 | 191 |
|
193 | | -async_save_queue = [] |
194 | | -g_cpu_optimizer_state_dict = {} |
195 | | - |
196 | | - |
197 | | -def _save_func(obj, name_mapping, path, saved_signal_path, protocol): |
198 | | - if isinstance(obj, dict): |
199 | | - for k, v in obj.items(): |
200 | | - if k == "master_weights" and isinstance(v, dict): |
201 | | - for kk, vv in v.items(): |
202 | | - if isinstance(vv, paddle.Tensor): |
203 | | - vv.name = name_mapping["master_weights"][kk] |
204 | | - else: |
205 | | - if k in name_mapping and isinstance(v, paddle.Tensor): |
206 | | - v.name = name_mapping[k] |
207 | | - |
208 | | - paddle.save(obj, path, protocol) |
209 | | - # write the saved_signal file to mark the save as complete |
210 | | - with open(saved_signal_path, mode="w+") as f: |
211 | | - f.write("1") |
212 | | - |
213 | | - |
214 | | -def check_exitcode(task): |
215 | | - exitcode = task.exitcode |
216 | | - if exitcode != 0: |
217 | | - print(f"Error: save ckpt process failed with exitcode {exitcode}!!!") |
218 | | - |
219 | | - |
220 | | -def clear_async_save_task_queue(): |
221 | | - """ |
222 | | - Wait until all pending async save tasks are done. |
223 | | - """ |
224 | | - while len(async_save_queue) > 0: |
225 | | - task = async_save_queue.pop() |
226 | | - if task and task.is_alive(): |
227 | | - task.join(timeout=60) |
228 | | - if task.is_alive(): |
229 | | - logger.error("Error: save ckpt process timeout!!!") |
230 | | - async_save_queue.append(task) |
231 | | - else: |
232 | | - check_exitcode(task) |
233 | | - else: |
234 | | - check_exitcode(task) |
235 | | - |
236 | | - |
237 | | -def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4): |
238 | | - global g_cpu_optimizer_state_dict |
239 | | - g_cpu_optimizer_state_dict.clear() |
240 | | - name_mapping = {"master_weights": {}} |
241 | | - for k, v in optimizer_state_dict.items(): |
242 | | - if k == "master_weights": |
243 | | - g_cpu_optimizer_state_dict[k] = {} |
244 | | - for kk, vv in v.items(): |
245 | | - g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory() |
246 | | - name_mapping[k][kk] = vv.name |
247 | | - elif k == "LR_Scheduler": |
248 | | - g_cpu_optimizer_state_dict[k] = copy.deepcopy(v) |
249 | | - else: |
250 | | - g_cpu_optimizer_state_dict[k] = v.pin_memory() |
251 | | - name_mapping[k] = v.name |
252 | | - paddle.device.synchronize() |
253 | | - clear_async_save_task_queue() |
254 | | - |
255 | | - attempt = 0 |
256 | | - ctx = multiprocessing.get_context("spawn") |
257 | | - |
258 | | - def start_process(): |
259 | | - nonlocal attempt |
260 | | - try: |
261 | | - p = ctx.Process( |
262 | | - target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol) |
263 | | - ) |
264 | | - p.start() |
265 | | - return p |
266 | | - except Exception as e: |
267 | | - print(f"Attempt {attempt + 1} failed with error: {e}") |
268 | | - attempt += 1 |
269 | | - time.sleep(1) |
270 | | - return start_process() |
271 | | - |
272 | | - p = start_process() |
273 | | - async_save_queue.append(p) |
274 | | - |
275 | 192 |
|
276 | 193 | class Trainer: |
277 | 194 | """ |
@@ -439,6 +356,8 @@ def __init__( |
439 | 356 |
|
440 | 357 | self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save |
441 | 358 | self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load |
| 359 | + if self.args.use_async_save: |
| 360 | + self._async_optimizer_saver = AsyncSaver() |
442 | 361 |
|
443 | 362 | if args.max_steps > 0: |
444 | 363 | logger.info("max_steps is given, it will override any value given in num_train_epochs") |
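For context, the branch added above is driven by the existing `use_async_save` flag on `TrainingArguments`. A hedged usage sketch; the import path and `output_dir` value are illustrative, and `model` stands in for whatever model the trainer wraps:

```python
# Illustrative only: turning on the AsyncSaver path added in __init__ above.
from paddlenlp.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="checkpoints",
    use_async_save=True,  # optimizer state is written by a background process
)
trainer = Trainer(model=model, args=args)  # model: any trainable model instance
```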
@@ -2323,9 +2242,6 @@ def _save_checkpoint(self, model, metrics=None): |
2323 | 2242 | # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" |
2324 | 2243 | self.runtime_timer.start("checkpoint saving time") |
2325 | 2244 |
|
2326 | | - if self.args.use_async_save: |
2327 | | - clear_async_save_task_queue() |
2328 | | - |
2329 | 2245 | # Save model checkpoint |
2330 | 2246 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" |
2331 | 2247 |
|
@@ -2368,10 +2284,8 @@ def _save_checkpoint(self, model, metrics=None): |
2368 | 2284 | save_path = os.path.join(output_dir, optimizer_name) |
2369 | 2285 | if self.args.use_async_save: |
2370 | 2286 | assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "FLAG_LLM_PDC is not supported with use_async_save" |
2371 | | - async_save_optimizer( |
2372 | | - state_dict, |
2373 | | - save_path, |
2374 | | - saved_signal_path=saved_signal_path, |
| 2287 | + self._async_optimizer_saver.run( |
| 2288 | + state_dict, save_path, saved_signal_path=saved_signal_path |
2375 | 2289 | ) |
2376 | 2290 | else: |
2377 | 2291 | self._save_ckpt_func(state_dict, save_path) |
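`_save_func` (removed above, now assumed to live in `.utils.async_save`) marks completion by writing "1" to `saved_signal_path`. Any consumer of an asynchronously saved checkpoint therefore has to wait for that marker before loading. A minimal sketch; `wait_for_saved_signal` and the file names are hypothetical, not part of this diff:

```python
import os
import time

import paddle


def wait_for_saved_signal(saved_signal_path, timeout=600, poll_interval=0.5):
    """Hypothetical helper: block until the background save process has
    written its completion marker to saved_signal_path."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        # The marker file is created only after paddle.save returns in the
        # child process, so its existence means the file on disk is complete.
        if os.path.exists(saved_signal_path):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"no saved signal at {saved_signal_path} after {timeout}s")


# Usage (paths illustrative):
# wait_for_saved_signal("checkpoint-100/optimizer.pdopt.saved_signal")
# state_dict = paddle.load("checkpoint-100/optimizer.pdopt")
```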
|