Commit f0f1d9d

[Distributed] Add async saver for optimizer (#8864)
1 parent 6ca9b2a commit f0f1d9d

File tree: 2 files changed (+131 lines, -91 lines)

paddlenlp/trainer/trainer.py
paddlenlp/trainer/utils/async_save.py
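For context, both the helpers removed from trainer.py below and the new AsyncSaver module rely on the same technique: copy the optimizer state into pinned CPU memory, synchronize the device, then hand serialization to a "spawn" subprocess so the training loop does not block on disk I/O. A minimal, self-contained sketch of that pattern (the tensor dict and path are illustrative, not taken from this commit, and pin_memory assumes a GPU build of Paddle):

# Sketch: asynchronous checkpointing via pinned host memory + a spawn process.
import multiprocessing

import paddle


def _write(state, path):
    # Runs in the child process; only touches the pinned CPU copies.
    paddle.save(state, path)


if __name__ == "__main__":
    gpu_state = {"moment1": paddle.randn([1024, 1024])}
    # Pin the host copies so the device-to-host transfer is fast and the
    # child process never has to touch live GPU tensors.
    cpu_state = {k: v.pin_memory() for k, v in gpu_state.items()}
    paddle.device.synchronize()  # make sure the D2H copies have completed
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=_write, args=(cpu_state, "optimizer.pdopt"))
    p.start()
    # ... training continues while the checkpoint is written ...
    p.join()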

paddlenlp/trainer/trainer.py
Lines changed: 5 additions & 91 deletions

@@ -18,10 +18,8 @@
 
 import collections
 import contextlib
-import copy
 import inspect
 import math
-import multiprocessing
 import os
 import random
 import re
@@ -148,6 +146,7 @@
 )
 from .training_args import TrainingArguments
 from .utils import reshard as reshard_util
+from .utils.async_save import AsyncSaver
 from .utils.helper import ( # nested_truncate,
     broadcast_dp_optimizer,
     broadcast_moe_optimizer,
@@ -190,88 +189,6 @@
 
 __all__ = ["Trainer"]
 
-async_save_queue = []
-g_cpu_optimizer_state_dict = {}
-
-
-def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
-    if isinstance(obj, dict):
-        for k, v in obj.items():
-            if k == "master_weights" and isinstance(v, dict):
-                for kk, vv in v.items():
-                    if isinstance(vv, paddle.Tensor):
-                        vv.name = name_mapping["master_weights"][kk]
-            else:
-                if k in name_mapping and isinstance(v, paddle.Tensor):
-                    v.name = name_mapping[k]
-
-    paddle.save(obj, path, protocol)
-    # dump saved_signal
-    with open(saved_signal_path, mode="w+") as f:
-        f.write("1")
-
-
-def check_exitcode(task):
-    exitcode = task.exitcode
-    if exitcode != 0:
-        print(f"Error: save ckpt process failed with exitcode {exitcode}!!!")
-
-
-def clear_async_save_task_queue():
-    """
-    Wait until all async save tasks are done.
-    """
-    while len(async_save_queue) > 0:
-        task = async_save_queue.pop()
-        if task and task.is_alive():
-            task.join(timeout=60)
-            if task.is_alive():
-                logger.error("Error: save ckpt process timeout!!!")
-                async_save_queue.append(task)
-            else:
-                check_exitcode(task)
-        else:
-            check_exitcode(task)
-
-
-def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
-    global g_cpu_optimizer_state_dict
-    g_cpu_optimizer_state_dict.clear()
-    name_mapping = {"master_weights": {}}
-    for k, v in optimizer_state_dict.items():
-        if k == "master_weights":
-            g_cpu_optimizer_state_dict[k] = {}
-            for kk, vv in v.items():
-                g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
-                name_mapping[k][kk] = vv.name
-        elif k == "LR_Scheduler":
-            g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
-        else:
-            g_cpu_optimizer_state_dict[k] = v.pin_memory()
-            name_mapping[k] = v.name
-    paddle.device.synchronize()
-    clear_async_save_task_queue()
-
-    attempt = 0
-    ctx = multiprocessing.get_context("spawn")
-
-    def start_process():
-        nonlocal attempt
-        try:
-            p = ctx.Process(
-                target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
-            )
-            p.start()
-            return p
-        except Exception as e:
-            print(f"Attempt {attempt + 1} failed with error: {e}")
-            attempt += 1
-            time.sleep(1)
-            return start_process()
-
-    p = start_process()
-    async_save_queue.append(p)
-
 
 class Trainer:
     """
@@ -439,6 +356,8 @@ def __init__(
 
         self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
         self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
+        if self.args.use_async_save:
+            self._async_optimizer_saver = AsyncSaver()
 
         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -2323,9 +2242,6 @@ def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
 
-        if self.args.use_async_save:
-            clear_async_save_task_queue()
-
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
@@ -2368,10 +2284,8 @@ def _save_checkpoint(self, model, metrics=None):
             save_path = os.path.join(output_dir, optimizer_name)
             if self.args.use_async_save:
                 assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                async_save_optimizer(
-                    state_dict,
-                    save_path,
-                    saved_signal_path=saved_signal_path,
+                self._async_optimizer_saver.run(
+                    state_dict, save_path, saved_signal_path=saved_signal_path
                 )
             else:
                 self._save_ckpt_func(state_dict, save_path)
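The saved_signal file is written only after paddle.save returns (and, in the new module below, flushed and fsync'd), so downstream tooling can poll for it to know the asynchronously written checkpoint is complete. A hypothetical consumer sketch, not part of this commit (the function name and timeouts are made up):

# Sketch: wait for the async saver's signal file before reading the checkpoint.
import os
import time


def wait_for_checkpoint(saved_signal_path, timeout=600.0, poll_interval=1.0):
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.exists(saved_signal_path):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"signal file {saved_signal_path} never appeared")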
paddlenlp/trainer/utils/async_save.py (new file)
Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import atexit
+import copy
+import multiprocessing
+import os
+import time
+
+import paddle
+
+from paddlenlp.utils.log import logger
+
+
+def _save_optimizer(obj, name_mapping, path, saved_signal_path, protocol):
+    start_time = time.time()
+    for k, v in obj.items():
+        if k == "master_weights" and isinstance(v, dict):
+            for kk, vv in v.items():
+                if isinstance(vv, paddle.Tensor):
+                    vv.name = name_mapping["master_weights"][kk]
+        else:
+            if k in name_mapping and isinstance(v, paddle.Tensor):
+                v.name = name_mapping[k]
+    paddle.save(obj, path, protocol)
+    # dump saved_signal
+    with open(saved_signal_path, mode="w+") as f:
+        f.write("1")
+        f.flush()
+        os.fsync(f.fileno())
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    logger.info(f"Async save optimizer took {elapsed_time:.6f} seconds to execute.")
+
+
+class AsyncSaver:
+    def __init__(self):
+        self.context = multiprocessing.get_context("spawn")
+        self.cpu_optimizer_state_dict = {}
+        self.pool = self.context.Pool(1)
+        self.result = None
+        self.name_mapping = None
+
+        atexit.register(self.shutdown)
+
+    def run(self, optimizer_state_dict, path, saved_signal_path, protocol=4):
+        logger.info(f"Started saving optimizer_state_dict to {os.path.abspath(path)}.")
+        self._wait_for_previous_result()
+
+        self._reset_state(path, saved_signal_path, protocol)
+        self._process_optimizer_state_dict(optimizer_state_dict)
+
+        self.result = self.pool.apply_async(
+            _save_optimizer,
+            args=(self.cpu_optimizer_state_dict, self.name_mapping, self.path, self.saved_signal_path, self.protocol),
+        )
+
+        logger.info("Finished launching saving optimizer_state_dict process")
+
+    def _wait_for_previous_result(self):
+        if self.result is not None:
+            max_retries = 5
+            for retries in range(max_retries):
+                try:
+                    self.result.get()
+                    break
+                except Exception as e:
+                    if retries == max_retries - 1:
+                        raise RuntimeError(f"Failed after {max_retries} retries during async save.")
+
+                    time.sleep(1 + retries * 2)
+                    logger.warning(f"An error occurred during async save: {e}. Retrying...")
+                    self.result = self.pool.apply_async(
+                        _save_optimizer,
+                        args=(
+                            self.cpu_optimizer_state_dict,
+                            self.name_mapping,
+                            self.path,
+                            self.saved_signal_path,
+                            self.protocol,
+                        ),
+                    )
+
+            if self.result.ready() and not self.result.successful():
+                raise RuntimeError("The previous async save task failed.")
+        else:
+            pass
+
+    def _reset_state(self, path, saved_signal_path, protocol):
+        self.cpu_optimizer_state_dict.clear()
+        self.name_mapping = {"master_weights": {}}
+        self.path = path
+        self.saved_signal_path = saved_signal_path
+        self.protocol = protocol
+
+    def _process_optimizer_state_dict(self, optimizer_state_dict):
+        for k, v in optimizer_state_dict.items():
+            if k == "master_weights":
+                self.cpu_optimizer_state_dict[k] = {}
+                for kk, vv in v.items():
+                    self.cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
+                    self.name_mapping[k][kk] = vv.name
+            elif k == "LR_Scheduler":
+                self.cpu_optimizer_state_dict[k] = copy.deepcopy(v)
+            else:
+                self.cpu_optimizer_state_dict[k] = v.pin_memory()
+                self.name_mapping[k] = v.name
+        paddle.device.synchronize()
+
+    def shutdown(self):
+        self.pool.close()
+        self.pool.join()
+
+    def __del__(self):
+        self.shutdown()
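A minimal usage sketch of the new class outside the Trainer (the model, optimizer, and paths here are illustrative; in this commit only Trainer._save_checkpoint drives AsyncSaver, and pin_memory assumes a GPU build of Paddle):

# Illustrative only: exercise AsyncSaver directly.
import os

import paddle

from paddlenlp.trainer.utils.async_save import AsyncSaver

if __name__ == "__main__":
    linear = paddle.nn.Linear(8, 8)
    opt = paddle.optimizer.AdamW(parameters=linear.parameters())
    linear(paddle.randn([4, 8])).mean().backward()
    opt.step()  # populate the optimizer state

    os.makedirs("checkpoint-100", exist_ok=True)
    saver = AsyncSaver()
    saver.run(
        opt.state_dict(),
        path="checkpoint-100/optimizer.pdopt",
        saved_signal_path="checkpoint-100/saved_signal",
    )
    # Training can continue here; the next run() call (or interpreter exit,
    # via the atexit hook) waits for this save to finish.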
