
Commit 7187299

fix format (#8878)
1 parent: 4ebec1d

File tree

2 files changed: +131 −91 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 91 deletions
@@ -18,10 +18,8 @@
 
 import collections
 import contextlib
-import copy
 import inspect
 import math
-import multiprocessing
 import os
 import random
 import re
@@ -149,6 +147,7 @@
 )
 from .training_args import TrainingArguments
 from .utils import reshard as reshard_util
+from .utils.async_save import AsyncSaver
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
     broadcast_moe_optimizer,
@@ -191,88 +190,6 @@
 
 __all__ = ["Trainer"]
 
-async_save_queue = []
-g_cpu_optimizer_state_dict = {}
-
-
-def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
-    if isinstance(obj, dict):
-        for k, v in obj.items():
-            if k == "master_weights" and isinstance(v, dict):
-                for kk, vv in v.items():
-                    if isinstance(vv, paddle.Tensor):
-                        vv.name = name_mapping["master_weights"][kk]
-            else:
-                if k in name_mapping and isinstance(v, paddle.Tensor):
-                    v.name = name_mapping[k]
-
-    paddle.save(obj, path, protocol)
-    # dump savd_siganl
-    with open(saved_signal_path, mode="w+") as f:
-        f.write("1")
-
-
-def check_exitcode(task):
-    exitcode = task.exitcode
-    if exitcode != 0:
-        print(f"Error: save ckpt process failed with exitcode {exitcode}!!!")
-
-
-def clear_async_save_task_queue():
-    """
-    wait until all async save task to be done.
-    """
-    while len(async_save_queue) > 0:
-        task = async_save_queue.pop()
-        if task and task.is_alive():
-            task.join(timeout=60)
-            if task.is_alive():
-                logger.error("Error: save ckpt process timeout!!!")
-                async_save_queue.append(task)
-            else:
-                check_exitcode(task)
-        else:
-            check_exitcode(task)
-
-
-def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
-    global g_cpu_optimizer_state_dict
-    g_cpu_optimizer_state_dict.clear()
-    name_mapping = {"master_weights": {}}
-    for k, v in optimizer_state_dict.items():
-        if k == "master_weights":
-            g_cpu_optimizer_state_dict[k] = {}
-            for kk, vv in v.items():
-                g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
-                name_mapping[k][kk] = vv.name
-        elif k == "LR_Scheduler":
-            g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
-        else:
-            g_cpu_optimizer_state_dict[k] = v.pin_memory()
-            name_mapping[k] = v.name
-    paddle.device.synchronize()
-    clear_async_save_task_queue()
-
-    attempt = 0
-    ctx = multiprocessing.get_context("spawn")
-
-    def start_process():
-        nonlocal attempt
-        try:
-            p = ctx.Process(
-                target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
-            )
-            p.start()
-            return p
-        except Exception as e:
-            print(f"Attempt {attempt + 1} failed with error: {e}")
-            attempt += 1
-            time.sleep(1)
-            return start_process()
-
-    p = start_process()
-    async_save_queue.append(p)
-
 
 class Trainer:
     """
@@ -440,6 +357,8 @@ def __init__(
 
         self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
         self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
+        if self.args.use_async_save:
+            self._async_optimizer_saver = AsyncSaver()
 
         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -2308,9 +2227,6 @@ def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         self.runtime_timer.start("checkpoint saving time")
 
-        if self.args.use_async_save:
-            clear_async_save_task_queue()
-
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
@@ -2353,10 +2269,8 @@ def _save_checkpoint(self, model, metrics=None):
             save_path = os.path.join(output_dir, optimizer_name)
             if self.args.use_async_save:
                 assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                async_save_optimizer(
-                    state_dict,
-                    save_path,
-                    saved_signal_path=saved_signal_path,
+                self._async_optimizer_saver.run(
+                    state_dict, save_path, saved_signal_path=saved_signal_path
                 )
             else:
                 self._save_ckpt_func(state_dict, save_path)
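
Net effect on trainer.py: the module-level async_save_queue, _save_func, clear_async_save_task_queue and async_save_optimizer helpers are removed, and the trainer instead owns a single AsyncSaver instance when use_async_save is enabled. A minimal sketch of the resulting optimizer-save path, assuming the names shown in the diff above (state_dict, save_path and saved_signal_path stand in for the values computed inside _save_checkpoint):

import paddle
from paddlenlp.trainer.utils.async_save import AsyncSaver

saver = AsyncSaver()  # built once in Trainer.__init__ when args.use_async_save is set


def save_optimizer_state(state_dict, save_path, saved_signal_path, use_async_save=True):
    if use_async_save:
        # returns quickly; pickling and disk I/O happen in a spawned worker process
        saver.run(state_dict, save_path, saved_signal_path=saved_signal_path)
    else:
        paddle.save(state_dict, save_path)  # synchronous fallback, as before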
paddlenlp/trainer/utils/async_save.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import atexit
+import copy
+import multiprocessing
+import os
+import time
+
+import paddle
+
+from paddlenlp.utils.log import logger
+
+
+def _save_optimizer(obj, name_mapping, path, saved_signal_path, protocol):
+    start_time = time.time()
+    for k, v in obj.items():
+        if k == "master_weights" and isinstance(v, dict):
+            for kk, vv in v.items():
+                if isinstance(vv, paddle.Tensor):
+                    vv.name = name_mapping["master_weights"][kk]
+        else:
+            if k in name_mapping and isinstance(v, paddle.Tensor):
+                v.name = name_mapping[k]
+    paddle.save(obj, path, protocol)
+    # dump saved_signal
+    with open(saved_signal_path, mode="w+") as f:
+        f.write("1")
+        f.flush()
+        os.fsync(f.fileno())
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    logger.info(f"Async save optimizer took {elapsed_time:.6f} seconds to execute.")
+
+
+class AsyncSaver:
+    def __init__(self):
+        self.context = multiprocessing.get_context("spawn")
+        self.cpu_optimizer_state_dict = {}
+        self.pool = self.context.Pool(1)
+        self.result = None
+        self.name_mapping = None
+
+        atexit.register(self.shutdown)
+
+    def run(self, optimizer_state_dict, path, saved_signal_path, protocol=4):
+        logger.info(f"Started saving optimizer_state_dict to {os.path.abspath(path)}.")
+        self._wait_for_previous_result()
+
+        self._reset_state(path, saved_signal_path, protocol)
+        self._process_optimizer_state_dict(optimizer_state_dict)
+
+        self.result = self.pool.apply_async(
+            _save_optimizer,
+            args=(self.cpu_optimizer_state_dict, self.name_mapping, self.path, self.saved_signal_path, self.protocol),
+        )
+
+        logger.info("Finished launching saving optimizer_state_dict process")
+
+    def _wait_for_previous_result(self):
+        if self.result is not None:
+            max_retries = 5
+            for retries in range(max_retries):
+                try:
+                    self.result.get()
+                    break
+                except Exception as e:
+                    if retries == max_retries - 1:
+                        raise RuntimeError(f"Failed after {max_retries} retries during async save.")
+
+                    time.sleep(1 + retries * 2)
+                    logger.warning(f"An error occurred during async save: {e}. Retrying...")
+                    self.result = self.pool.apply_async(
+                        _save_optimizer,
+                        args=(
+                            self.cpu_optimizer_state_dict,
+                            self.name_mapping,
+                            self.path,
+                            self.saved_signal_path,
+                            self.protocol,
+                        ),
+                    )
+
+            if self.result.ready() and not self.result.successful():
+                raise RuntimeError("The previous async save task failed.")
+        else:
+            pass
+
+    def _reset_state(self, path, saved_signal_path, protocol):
+        self.cpu_optimizer_state_dict.clear()
+        self.name_mapping = {"master_weights": {}}
+        self.path = path
+        self.saved_signal_path = saved_signal_path
+        self.protocol = protocol
+
+    def _process_optimizer_state_dict(self, optimizer_state_dict):
+        for k, v in optimizer_state_dict.items():
+            if k == "master_weights":
+                self.cpu_optimizer_state_dict[k] = {}
+                for kk, vv in v.items():
+                    self.cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
+                    self.name_mapping[k][kk] = vv.name
+            elif k == "LR_Scheduler":
+                self.cpu_optimizer_state_dict[k] = copy.deepcopy(v)
+            else:
+                self.cpu_optimizer_state_dict[k] = v.pin_memory()
+                self.name_mapping[k] = v.name
+        paddle.device.synchronize()
+
+    def shutdown(self):
+        self.pool.close()
+        self.pool.join()
+
+    def __del__(self):
+        self.shutdown()
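
For orientation, a hypothetical standalone usage of the new AsyncSaver (not taken from the diff; it assumes a GPU build of Paddle, since _process_optimizer_state_dict moves every optimizer tensor to pinned memory, and the file paths are placeholders):

import paddle
from paddlenlp.trainer.utils.async_save import AsyncSaver

model = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# one training step so the optimizer actually has accumulator tensors to save
loss = model(paddle.randn([4, 8])).mean()
loss.backward()
opt.step()

saver = AsyncSaver()
# run() snapshots the state dict into pinned CPU memory, then hands the actual
# paddle.save() call to a one-worker "spawn" pool and returns immediately.
saver.run(
    opt.state_dict(),
    path="optimizer.pdopt",            # placeholder output path
    saved_signal_path="saved_signal",  # worker writes "1" here once the save lands on disk
)
# A later run() call (or interpreter exit, via atexit) blocks on the pending task,
# resubmitting it up to five times if the previous save raised an exception.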
