Skip to content

Commit ae9ddce

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into add_split_param
2 parents 223e089 + 697a4cc commit ae9ddce

File tree

8 files changed

+114
-41
lines changed

8 files changed

+114
-41
lines changed

llm/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,15 @@ PaddleNLP 支持多个主流大模型的 SFT、LoRA、Prefix Tuning 等精调策
115115
样例数据:
116116

117117
```text
118-
{"src": "类型#裙*颜色#蓝色*风格#清新*图案#蝴蝶结", "tgt": "裙身处采用立体蝴蝶结装饰辅以蓝色条带点缀,令衣身造型饱满富有层次的同时为其注入一丝甜美气息。将女孩清新娇俏的一面衬托而出。"}
118+
{"src": "Give three tips for staying healthy.", "tgt": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."}
119119
...
120120
```
121121

122-
为了方便测试,我们也提供了广告生成数据集可以直接使用
122+
为了方便测试,我们也提供了[tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)demo 数据集可以直接使用
123123

124124
```shell
125-
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz
126-
tar -zxvf AdvertiseGen.tar.gz
125+
wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz
126+
tar -xvf alpaca_demo.gz
127127
```
128128

129129
#### 2.2 全参精调:SFT

llm/config/llama/lora_argument.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"gradient_accumulation_steps": 4,
77
"per_device_eval_batch_size": 8,
88
"eval_accumulation_steps":16,
9-
"num_train_epochs": 3,
9+
"num_train_epochs": 1,
1010
"learning_rate": 3e-04,
1111
"warmup_steps": 30,
1212
"logging_steps": 1,

llm/config/llama/sft_argument.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"gradient_accumulation_steps": 2,
77
"per_device_eval_batch_size": 8,
88
"eval_accumulation_steps":16,
9-
"num_train_epochs": 3,
9+
"num_train_epochs": 1,
1010
"learning_rate": 3e-05,
1111
"warmup_steps": 30,
1212
"logging_steps": 1,

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,25 @@ def __init__(self, args):
136136
self._process_master_weight = None
137137
self._process_optimizer_weight = None
138138
self._lock = None
139-
self._shared_save_path = None
140139
self._shared_save_model_flag = None
141140
self._shared_save_master_weight_flag = None
142141
self._shared_save_optimizer_flag = None
143142

144143
if "async_save" in self.args.unified_checkpoint_config:
145144
self._lock = multiprocessing.Lock()
146145
self._shared_save_model_path = multiprocessing.Array("c", 100000)
146+
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
147147
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
148+
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
148149
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
150+
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
149151
self._shared_save_model_flag = multiprocessing.Array("i", 1)
150152
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
151153
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
152154

153-
def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
155+
def _file_save_async_or_sync(
156+
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
157+
):
154158
if is_sync:
155159
for k in list(state_dict.keys()):
156160
if isinstance(state_dict[k], paddle.Tensor):
@@ -165,6 +169,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
165169
meta_dict = self._meta_dict_model
166170
shared_save_flag = self._shared_save_model_flag
167171
shared_save_path = self._shared_save_model_path
172+
shared_save_signal_path = self._shared_save_model_signal_path
168173
if self._process_model_weight is None:
169174
self._process_model_weight = multiprocessing.Process(
170175
target=self._save_file_async_in_process,
@@ -173,12 +178,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
173178
self._shm_model_weight.name,
174179
self._shared_save_model_flag,
175180
self._shared_save_model_path,
181+
self._shared_save_model_signal_path,
176182
self._lock,
177183
state_dict_type,
178184
self.global_rank,
179185
),
180186
)
181187
self._process_model_weight.start()
188+
process = self._process_model_weight
182189
elif state_dict_type == "master_weight":
183190
if self._shm_master_weight is None:
184191
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -187,6 +194,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
187194
meta_dict = self._meta_dict_master_weight
188195
shared_save_flag = self._shared_save_master_weight_flag
189196
shared_save_path = self._shared_save_master_weight_path
197+
shared_save_signal_path = self._shared_save_master_weight_signal_path
190198
if self._process_master_weight is None:
191199
self._process_master_weight = multiprocessing.Process(
192200
target=self._save_file_async_in_process,
@@ -195,6 +203,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
195203
self._shm_master_weight.name,
196204
self._shared_save_master_weight_flag,
197205
self._shared_save_master_weight_path,
206+
self._shared_save_master_weight_signal_path,
198207
self._lock,
199208
"model_weight"
200209
if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -203,6 +212,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
203212
),
204213
)
205214
self._process_master_weight.start()
215+
process = self._process_master_weight
206216
elif state_dict_type == "optimizer_weight":
207217
if self._shm_optimizer_weight is None:
208218
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -211,6 +221,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
211221
meta_dict = self._meta_dict_optim
212222
shared_save_flag = self._shared_save_optimizer_flag
213223
shared_save_path = self._shared_save_optimizer_path
224+
shared_save_signal_path = self._shared_save_optimizer_signal_path
214225
if self._process_optimizer_weight is None:
215226
self._process_optimizer_weight = multiprocessing.Process(
216227
target=self._save_file_async_in_process,
@@ -219,21 +230,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
219230
self._shm_optimizer_weight.name,
220231
self._shared_save_optimizer_flag,
221232
self._shared_save_optimizer_path,
233+
self._shared_save_optimizer_signal_path,
222234
self._lock,
223235
state_dict_type,
224236
self.global_rank,
225237
),
226238
)
227239
self._process_optimizer_weight.start()
240+
process = self._process_optimizer_weight
228241

229242
while True: # wait until no process is saving.
230243
flag_value = shared_save_flag[0]
231244
if flag_value == 0:
232245
break
246+
if not process.is_alive():
247+
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
233248
time.sleep(0.5)
234249
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
235250
# only save model weight or save master weight, we enter this loop.
236251
self._reset_and_update(shared_save_path, path)
252+
self._reset_and_update(shared_save_signal_path, signal_path)
237253
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
238254
with self._lock:
239255
shared_save_flag[0] = 1
@@ -244,6 +260,7 @@ def _save_file_async_in_process(
244260
shm_name,
245261
shared_save_flag,
246262
shared_save_path,
263+
shared_save_signal_path,
247264
lock,
248265
state_dict_type,
249266
global_rank,
@@ -257,11 +274,12 @@ def _save_file_async_in_process(
257274
continue
258275
if flag_value == 1: # need to save
259276
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
277+
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
260278
logger.info(f"Start to async save {path}")
261279
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
262280
safe_save_file(state_dict, path, {"format": "np"})
263281
del state_dict
264-
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
282+
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
265283
paddle.save(global_rank, saved_signal_path)
266284
with lock:
267285
shared_save_flag[0] = 0
@@ -276,7 +294,7 @@ def _reset_and_update(self, shared_array, new_value):
276294
encoded_value = new_value.encode("utf-8")
277295
shared_array[: len(encoded_value)] = encoded_value
278296

279-
def save_unified_checkpoint(self, model, optimizer, output_dir):
297+
def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
280298
"""save unified checkpoint
281299
282300
Args:
@@ -313,6 +331,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
313331

314332
save_directory = output_dir
315333
os.makedirs(save_directory, exist_ok=True)
334+
if signal_dir is not None:
335+
os.makedirs(signal_dir, exist_ok=True) # only for async save
316336

317337
# save model weights
318338
if not skip_save_model_weight:
@@ -325,6 +345,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
325345
self._file_save_async_or_sync(
326346
state_dict,
327347
path=os.path.join(save_directory, shard_file),
348+
signal_path=signal_dir,
328349
is_sync=is_sync_save,
329350
state_dict_type="model_weight",
330351
)
@@ -394,7 +415,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
394415
if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
395416
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)
396417

397-
def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir):
418+
def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
398419
paddle.device.cuda.empty_cache()
399420

400421
# gather global master_weights status.
@@ -446,12 +467,14 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
446467
self._file_save_async_or_sync(
447468
optim_state_dict,
448469
path=os.path.join(output_dir, optimizer_name),
470+
signal_path=signal_dir,
449471
is_sync=is_sync_save,
450472
state_dict_type="optimizer_weight",
451473
)
452474
self._file_save_async_or_sync(
453475
master_weights,
454476
path=os.path.join(output_dir, master_weights_name),
477+
signal_path=signal_dir,
455478
is_sync=is_sync_save,
456479
state_dict_type="master_weight",
457480
)
@@ -501,17 +524,19 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
501524

502525
return returned_optim_state_dict
503526

504-
def save_unified_optimizer(self, model, optimizer, output_dir):
527+
def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
505528
"""save unified optimizer
506529
507530
Args:
508531
model (PretrainedModel): model used to get key mapping.
509532
optimizer (Optimizer): optimizer to save
510533
output_dir (str): Save directory.
534+
signal_dir (str): Asynchronous saving signal directory.
511535
512536
"""
537+
513538
if paddle.distributed.get_world_size() <= 1:
514-
save_single_card_optimizer(model, optimizer, output_dir)
539+
save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal
515540
return
516541

517542
if (
@@ -530,7 +555,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
530555
optim_state_dict.pop("LR_Scheduler")
531556

532557
if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
533-
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir)
558+
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir, signal_dir)
534559
return
535560

536561
# Split into naive optimizer params and master weights.
@@ -547,20 +572,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):
547572
paddle.device.cuda.empty_cache()
548573
save_directory = output_dir
549574
os.makedirs(save_directory, exist_ok=True)
575+
if signal_dir is not None:
576+
os.makedirs(signal_dir, exist_ok=True)
550577

551578
is_sync_save = True
552579
if "async_save" in self.args.unified_checkpoint_config:
553580
is_sync_save = False
554581
self._file_save_async_or_sync(
555582
optim_state_dict,
556583
path=os.path.join(save_directory, shard_optim_file),
584+
signal_path=signal_dir,
557585
is_sync=is_sync_save,
558586
state_dict_type="optimizer_weight",
559587
)
560588
if master_weight_state_dict is not None:
561589
self._file_save_async_or_sync(
562590
master_weight_state_dict,
563591
path=os.path.join(save_directory, shard_master_weight_file),
592+
signal_path=signal_dir,
564593
is_sync=is_sync_save,
565594
state_dict_type="master_weight",
566595
)
@@ -634,14 +663,20 @@ def unlink_shared_memory(self):
634663

635664
if self._shared_save_model_flag is not None:
636665
while self._shared_save_model_flag[0] > 0: # async process is saving
666+
if not self._process_model_weight.is_alive():
667+
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
637668
time.sleep(0.5)
638669
self._shared_save_model_flag[0] = -1
639670
if self._shared_save_master_weight_flag is not None:
640671
while self._shared_save_master_weight_flag[0] > 0:
672+
if not self._process_master_weight.is_alive():
673+
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
641674
time.sleep(0.5)
642675
self._shared_save_master_weight_flag[0] = -1
643676
if self._shared_save_optimizer_flag is not None:
644677
while self._shared_save_optimizer_flag[0] > 0:
678+
if not self._process_optimizer_weight.is_alive():
679+
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
645680
time.sleep(0.5)
646681
self._shared_save_optimizer_flag[0] = -1
647682

@@ -658,7 +693,8 @@ def unlink_shared_memory(self):
658693
self._shm_optimizer_weight.unlink()
659694
self._shm_optimizer_weight = None
660695

661-
dist.barrier()
696+
if paddle.distributed.get_world_size() > 1:
697+
dist.barrier()
662698

663699

664700
def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):

0 commit comments

Comments
 (0)