Merged

Changes from all commits

21 commits
5451d31 [Unified checkpoint] update optimizer async save signal (DesmonDay, Aug 21, 2024)
4f0b61a Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (gongel, Sep 4, 2024)
68470aa Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Sep 6, 2024)
15e83e2 update paddlepaddle (DesmonDay, Sep 6, 2024)
6837b2f split param (DesmonDay, Sep 10, 2024)
633d742 add save for split param (DesmonDay, Sep 24, 2024)
55186d7 Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 9, 2024)
b6aa309 fix save split_param (DesmonDay, Oct 10, 2024)
bf5d72b add load uc split_param (DesmonDay, Oct 11, 2024)
9fdaae2 update uc files (DesmonDay, Oct 14, 2024)
9a210db Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 14, 2024)
19071ef update uc files (DesmonDay, Oct 14, 2024)
223e089 Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 15, 2024)
ae9ddce Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 16, 2024)
4ab0df1 Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 22, 2024)
cbbc074 update split_param loading (DesmonDay, Oct 24, 2024)
7678fad mkdir unified_checkpoint directory (DesmonDay, Oct 25, 2024)
238888d rename file (DesmonDay, Oct 25, 2024)
780040e Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i… (DesmonDay, Oct 25, 2024)
b219ba6 update async handler (DesmonDay, Oct 25, 2024)
dbd13df update files (DesmonDay, Oct 25, 2024)
2,569 changes: 0 additions & 2,569 deletions paddlenlp/trainer/plugins/unified_checkpoint.py

This file was deleted.

6 changes: 1 addition & 5 deletions paddlenlp/trainer/trainer.py
@@ -113,7 +113,6 @@
from .argparser import strtobool
from .integrations import get_reporting_integration_callbacks
from .plugins.timer import RuntimeTimer, get_timers, set_timers
from .plugins.unified_checkpoint import UnifiedCheckpointHandler
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
@@ -144,6 +143,7 @@
speed_metrics,
)
from .training_args import TrainingArguments
from .unified_checkpoint import UnifiedCheckpointHandler
from .utils import reshard as reshard_util
from .utils.async_save import AsyncSaver
from .utils.helper import ( # nested_truncate,
@@ -598,7 +598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
if use_unified_checkpoint:
self.unified_checkpoint_handler.load_unified_checkpoint(
self.model,
self.optimizer,
resume_from_checkpoint,
)
logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
@@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
if self.args.unified_checkpoint:
self.unified_checkpoint_handler.load_unified_checkpoint(
self.model,
self.optimizer,
self.state.best_model_checkpoint,
)
if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self):
if self.args.unified_checkpoint:
self.unified_checkpoint_handler.load_unified_checkpoint(
self.model,
self.optimizer,
self.state.best_model_checkpoint,
)
if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -2775,7 +2772,6 @@ def _load_optimizer_and_scheduler(self, checkpoint):
opt_state_dict = None
else:
opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
args=self.args,
model=self.model,
optimizer=self.optimizer,
resume_from_checkpoint=checkpoint,
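For orientation, a brief sketch of the updated call sites implied by the hunks above; argument lists are shown only as far as this diff reveals them, so anything beyond that is an assumption.

# Sketch inferred from the hunks above: the optimizer argument is dropped from
# load_unified_checkpoint, and args= is dropped from load_unified_optimizer.
self.unified_checkpoint_handler.load_unified_checkpoint(
    self.model,
    resume_from_checkpoint,
)
opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
    model=self.model,
    optimizer=self.optimizer,
    resume_from_checkpoint=checkpoint,
    # any further keyword arguments are not visible in this diff
)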
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1402,6 +1402,12 @@
f"but got logging_steps={self.logging_steps}."
)

if "split_param" in sharding_parallel_config:
assert self.sharding == [ShardingOption.SHARD_OP], "Only sharding stage1 support split_param."
assert (
self.amp_master_grad
), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."

fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

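As a usage note, here is a hedged example of trainer arguments that would satisfy the new checks. The field names come from the code above (`sharding`, `sharding_parallel_config`, `amp_master_grad`), but the exact value spellings are assumptions rather than something verified in this PR.

# Sketch only: a configuration that passes the new split_param assertions.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",                       # split_param requires sharding stage1 (ShardingOption.SHARD_OP)
    sharding_parallel_config="split_param",  # enables the option guarded by the assertions above
    amp_master_grad=True,                    # must be True when split_param is set
)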
15 changes: 15 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .unified_checkpoint import UnifiedCheckpointHandler

Contributor

Could you add `__all__`?

Contributor Author

done
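A minimal sketch of how the package entry point might look once this comment is addressed; the exact contents of `__all__` are an assumption, not something shown in this diff.

# paddlenlp/trainer/unified_checkpoint/__init__.py (sketch; __all__ contents assumed)
from .unified_checkpoint import UnifiedCheckpointHandler

__all__ = ["UnifiedCheckpointHandler"]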

250 changes: 250 additions & 0 deletions paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -0,0 +1,250 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Asynchronous unified checkpoint handler."""

import multiprocessing
import os
import time
from multiprocessing import shared_memory

import paddle
import paddle.distributed as dist

from paddlenlp.transformers.utils import is_safetensors_available
from paddlenlp.utils.log import logger

if is_safetensors_available():
from safetensors.numpy import save_file as safe_save_file

from .shared_memory_utils import (
_read_state_dict_from_shm,
_traverse_copy_to_shm,
create_meta_dict,
)

__all__ = ["AsyncCheckpointHandler"]


class AsyncCheckpointHandler:
def __init__(self, args):
# Mainly for asynchronous saving.
self.args = args
self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

self._shm_model_weight = None
self._shm_master_weight = None
self._shm_optimizer_weight = None
self._meta_dict_model = None
self._meta_dict_master_weight = None
self._meta_dict_optim = None
self._process_model_weight = None
self._process_master_weight = None
self._process_optimizer_weight = None
self._lock = None
self._shared_save_model_flag = None
self._shared_save_master_weight_flag = None
self._shared_save_optimizer_flag = None

if "async_save" in self.args.unified_checkpoint_config:
self._lock = multiprocessing.Lock()
self._shared_save_model_path = multiprocessing.Array("c", 100000)
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_model_flag = multiprocessing.Array("i", 1)
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
):
if is_sync:
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
state_dict[k] = state_dict.pop(k).cpu().numpy()
safe_save_file(state_dict, path, metadata={"format": "np"})
else:
if state_dict_type == "model_weight":
if self._shm_model_weight is None:
self._meta_dict_model, buffer_size = create_meta_dict(state_dict)
self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
shm_state_dict = self._shm_model_weight
meta_dict = self._meta_dict_model
shared_save_flag = self._shared_save_model_flag
shared_save_path = self._shared_save_model_path
shared_save_signal_path = self._shared_save_model_signal_path
if self._process_model_weight is None:
self._process_model_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
args=(
meta_dict,
self._shm_model_weight.name,
self._shared_save_model_flag,
self._shared_save_model_path,
self._shared_save_model_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_model_weight.start()
process = self._process_model_weight
elif state_dict_type == "master_weight":
if self._shm_master_weight is None:
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
shm_state_dict = self._shm_master_weight
meta_dict = self._meta_dict_master_weight
shared_save_flag = self._shared_save_master_weight_flag
shared_save_path = self._shared_save_master_weight_path
shared_save_signal_path = self._shared_save_master_weight_signal_path
if self._process_master_weight is None:
self._process_master_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
args=(
meta_dict,
self._shm_master_weight.name,
self._shared_save_master_weight_flag,
self._shared_save_master_weight_path,
self._shared_save_master_weight_signal_path,
self._lock,
"model_weight"
if "skip_save_model_weight" in self.args.unified_checkpoint_config
else state_dict_type,
self.global_rank,
),
)
self._process_master_weight.start()
process = self._process_master_weight
elif state_dict_type == "optimizer_weight":
if self._shm_optimizer_weight is None:
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
shm_state_dict = self._shm_optimizer_weight
meta_dict = self._meta_dict_optim
shared_save_flag = self._shared_save_optimizer_flag
shared_save_path = self._shared_save_optimizer_path
shared_save_signal_path = self._shared_save_optimizer_signal_path
if self._process_optimizer_weight is None:
self._process_optimizer_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
args=(
meta_dict,
self._shm_optimizer_weight.name,
self._shared_save_optimizer_flag,
self._shared_save_optimizer_path,
self._shared_save_optimizer_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_optimizer_weight.start()
process = self._process_optimizer_weight

while True: # wait until no process is saving.
flag_value = shared_save_flag[0]
if flag_value == 0:
break
if not process.is_alive():
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
time.sleep(0.5)
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
# only save model weight or save master weight, we enter this loop.
self._reset_and_update(shared_save_path, path)
self._reset_and_update(shared_save_signal_path, signal_path)
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
with self._lock:
shared_save_flag[0] = 1

def _save_file_async_in_process(
self,
meta_dict,
shm_name,
shared_save_flag,
shared_save_path,
shared_save_signal_path,
lock,
state_dict_type,
global_rank,
):
shm = shared_memory.SharedMemory(name=shm_name)
while True:
flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value.
if flag_value == -1: # stop process
break
if flag_value == 0: # nothing to save
continue
if flag_value == 1: # need to save
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
paddle.save(global_rank, saved_signal_path)
with lock:
shared_save_flag[0] = 0
time.sleep(0.5)
shm.close()

def _reset_and_update(self, shared_array, new_value):
# clear array
for i in range(len(shared_array)):
shared_array[i] = b"\0"
# update array
encoded_value = new_value.encode("utf-8")
shared_array[: len(encoded_value)] = encoded_value

def unlink_shared_memory(self):
if not ("async_save" in self.args.unified_checkpoint_config):
return

if self._shared_save_model_flag is not None:
while self._shared_save_model_flag[0] > 0: # async process is saving
if not self._process_model_weight.is_alive():
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_model_flag[0] = -1
if self._shared_save_master_weight_flag is not None:
while self._shared_save_master_weight_flag[0] > 0:
if not self._process_master_weight.is_alive():
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_master_weight_flag[0] = -1
if self._shared_save_optimizer_flag is not None:
while self._shared_save_optimizer_flag[0] > 0:
if not self._process_optimizer_weight.is_alive():
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_optimizer_flag[0] = -1

if self._shm_model_weight is not None:
self._shm_model_weight.close()
self._shm_model_weight.unlink()
self._shm_model_weight = None
if self._shm_master_weight is not None:
self._shm_master_weight.close()
self._shm_master_weight.unlink()
self._shm_master_weight = None
if self._shm_optimizer_weight is not None:
self._shm_optimizer_weight.close()
self._shm_optimizer_weight.unlink()
self._shm_optimizer_weight = None

if paddle.distributed.get_world_size() > 1:
dist.barrier()
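The async save process above signals completion by writing a per-rank marker file named `.{state_dict_type}.done.{global_rank}` under `signal_path`. Below is a hedged sketch of how a consumer might poll for these markers; the helper itself is hypothetical and not part of this PR.

# Hypothetical helper: wait until every rank's async save has dropped its done-marker.
import os
import time

def wait_for_done_signals(signal_path, state_dict_type, world_size, poll_interval=0.5):
    expected = [
        os.path.join(signal_path, f".{state_dict_type}.done.{rank}")
        for rank in range(world_size)
    ]
    while not all(os.path.exists(p) for p in expected):
        time.sleep(poll_interval)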