|
18 | 18 |
|
19 | 19 | import collections |
20 | 20 | import contextlib |
21 | | -import copy |
22 | 21 | import inspect |
23 | 22 | import math |
24 | | -import multiprocessing |
25 | 23 | import os |
26 | 24 | import random |
27 | 25 | import re |
|
149 | 147 | ) |
150 | 148 | from .training_args import TrainingArguments |
151 | 149 | from .utils import reshard as reshard_util |
| 150 | +from .utils.async_save import AsyncSaver |
152 | 151 | from .utils.helper import ( # nested_truncate, |
153 | 152 | broadcast_dp_optimizer, |
154 | 153 | broadcast_moe_optimizer, |
|
191 | 190 |
|
192 | 191 | __all__ = ["Trainer"] |
193 | 192 |
|
# Module-level queue of in-flight async checkpoint-save child processes;
# drained by clear_async_save_task_queue().
async_save_queue = []
# Pinned-CPU-memory staging copy of the optimizer state currently being
# saved; reused (cleared) on every async_save_optimizer() call.
g_cpu_optimizer_state_dict = {}
196 | | - |
197 | | - |
def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
    """Restore original tensor names on ``obj`` and persist it with ``paddle.save``.

    Runs inside a spawned child process.  The pinned-memory copies built by
    ``async_save_optimizer`` carry fresh tensor names, so the original names
    are re-applied from ``name_mapping`` before saving.  After a successful
    save, "1" is written to ``saved_signal_path`` so the parent can detect
    completion.

    Args:
        obj: state dict to save; may contain a nested "master_weights" dict.
        name_mapping: original tensor names, keyed like ``obj`` (with a
            nested dict under "master_weights").
        path: destination file for ``paddle.save``.
        saved_signal_path: signal file touched once the save has finished.
        protocol: pickle protocol forwarded to ``paddle.save``.
    """
    if isinstance(obj, dict):
        for k, v in obj.items():
            if k == "master_weights" and isinstance(v, dict):
                # Master weights live one level deeper than the other entries.
                for kk, vv in v.items():
                    if isinstance(vv, paddle.Tensor):
                        vv.name = name_mapping["master_weights"][kk]
            else:
                if k in name_mapping and isinstance(v, paddle.Tensor):
                    v.name = name_mapping[k]

    paddle.save(obj, path, protocol)
    # Dump the saved signal; plain write mode is sufficient (the original
    # "w+" read/write mode was unnecessary and the comment had a typo).
    with open(saved_signal_path, mode="w") as f:
        f.write("1")
213 | | - |
214 | | - |
def check_exitcode(task):
    """Report an error if the finished save-checkpoint process ``task`` failed.

    Expects ``task`` to have already exited; ``exitcode`` is ``None`` while a
    process is still running, which would also be reported here.
    """
    exitcode = task.exitcode
    if exitcode != 0:
        # Route through the module logger instead of a bare print, matching
        # the error reporting used elsewhere in this file.
        logger.error(f"Error: save ckpt process failed with exitcode {exitcode}!!!")
219 | | - |
220 | | - |
def clear_async_save_task_queue():
    """Block until every queued async checkpoint-save task has completed.

    Pops tasks off the module-level ``async_save_queue``; a task that is
    still alive after a 60-second join is logged and re-queued so the loop
    keeps waiting for it.  Finished tasks have their exit codes checked.
    """
    while async_save_queue:
        task = async_save_queue.pop()
        if not (task and task.is_alive()):
            # Already finished (or falsy): just verify its exit status.
            check_exitcode(task)
            continue
        task.join(timeout=60)
        if not task.is_alive():
            check_exitcode(task)
            continue
        # Still running after the grace period: report and re-queue so we
        # continue waiting on it in a later iteration.
        logger.error("Error: save ckpt process timeout!!!")
        async_save_queue.append(task)
236 | | - |
237 | | - |
def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
    """Save ``optimizer_state_dict`` to ``path`` asynchronously in a child process.

    Stages every tensor into pinned CPU memory in the module-level
    ``g_cpu_optimizer_state_dict`` (recording original tensor names, since
    ``pin_memory()`` copies get fresh names), waits for any previous async
    save to finish, then spawns a "spawn"-context process running
    ``_save_func``.  The child writes ``saved_signal_path`` on success.

    Args:
        optimizer_state_dict: optimizer state; tensor values plus the special
            "master_weights" (dict of tensors) and "LR_Scheduler" entries.
        path: destination file for ``paddle.save``.
        saved_signal_path: signal file the child touches when done.
        protocol: pickle protocol forwarded to ``paddle.save``.

    Raises:
        RuntimeError: if the save process cannot be started after several
            attempts.  (The previous implementation retried forever via
            unbounded recursion.)
    """
    global g_cpu_optimizer_state_dict
    g_cpu_optimizer_state_dict.clear()
    name_mapping = {"master_weights": {}}
    for k, v in optimizer_state_dict.items():
        if k == "master_weights":
            g_cpu_optimizer_state_dict[k] = {}
            for kk, vv in v.items():
                g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
                name_mapping[k][kk] = vv.name
        elif k == "LR_Scheduler":
            # Plain Python state; deep-copy so the child sees a stable snapshot.
            g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
        else:
            g_cpu_optimizer_state_dict[k] = v.pin_memory()
            name_mapping[k] = v.name
    # Ensure the device-to-pinned-host copies are complete before the child
    # process pickles the staged state.
    paddle.device.synchronize()
    clear_async_save_task_queue()

    ctx = multiprocessing.get_context("spawn")
    max_attempts = 5
    p = None
    for attempt in range(max_attempts):
        try:
            p = ctx.Process(
                target=_save_func,
                args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol),
            )
            p.start()
            break
        except Exception as e:
            logger.error(f"Attempt {attempt + 1} failed with error: {e}")
            time.sleep(1)
    else:
        # Bounded retry instead of the original infinite recursion.
        raise RuntimeError(f"Failed to start async save process after {max_attempts} attempts")

    async_save_queue.append(p)
275 | | - |
276 | 193 |
|
277 | 194 | class Trainer: |
278 | 195 | """ |
@@ -440,6 +357,8 @@ def __init__( |
440 | 357 |
|
441 | 358 | self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save |
442 | 359 | self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load |
| 360 | + if self.args.use_async_save: |
| 361 | + self._async_optimizer_saver = AsyncSaver() |
443 | 362 |
|
444 | 363 | if args.max_steps > 0: |
445 | 364 | logger.info("max_steps is given, it will override any value given in num_train_epochs") |
@@ -2308,9 +2227,6 @@ def _save_checkpoint(self, model, metrics=None): |
2308 | 2227 | # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" |
2309 | 2228 | self.runtime_timer.start("checkpoint saving time") |
2310 | 2229 |
|
2311 | | - if self.args.use_async_save: |
2312 | | - clear_async_save_task_queue() |
2313 | | - |
2314 | 2230 | # Save model checkpoint |
2315 | 2231 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" |
2316 | 2232 |
|
@@ -2353,10 +2269,8 @@ def _save_checkpoint(self, model, metrics=None): |
2353 | 2269 | save_path = os.path.join(output_dir, optimizer_name) |
2354 | 2270 | if self.args.use_async_save: |
2355 | 2271 | assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" |
2356 | | - async_save_optimizer( |
2357 | | - state_dict, |
2358 | | - save_path, |
2359 | | - saved_signal_path=saved_signal_path, |
| 2272 | + self._async_optimizer_saver.run( |
| 2273 | + state_dict, save_path, saved_signal_path=saved_signal_path |
2360 | 2274 | ) |
2361 | 2275 | else: |
2362 | 2276 | self._save_ckpt_func(state_dict, save_path) |
|
0 commit comments