Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion python/paddle/distributed/checkpoint/save_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.
from __future__ import annotations

import multiprocessing
import os
import time
from typing import TYPE_CHECKING

import paddle
Expand All @@ -30,6 +32,49 @@
from paddle import Tensor
from paddle.distributed.collective import Group

# Pending async checkpoint-save processes (multiprocessing.Process objects);
# drained by clear_async_save_task_queue() before the next save is launched.
async_save_queue = []


def check_exitcode(task):
    """Log an error if a terminated checkpoint-save process exited abnormally.

    Args:
        task: A ``multiprocessing.Process`` used for an async checkpoint
            save. ``task.exitcode`` is ``None`` while the process is still
            running, 0 on success, and non-zero (negative = killed by
            signal) on failure.
    """
    exitcode = task.exitcode
    # Guard against None: a still-running process has no exitcode yet, and
    # the bare `exitcode != 0` check would falsely report it as a failure.
    if exitcode is not None and exitcode != 0:
        logger.error(
            f"Error: save ckpt process failed with exitcode {exitcode}!!!"
        )


def clear_async_save_task_queue():
    """
    Block until every queued async checkpoint-save process has finished.

    Each pending process is joined with a 60-second timeout; a process
    that is still alive afterwards is logged and pushed back onto the
    queue for another round, so the loop only exits once the queue is
    fully drained.
    """
    while async_save_queue:
        pending = async_save_queue.pop()
        if not (pending and pending.is_alive()):
            # Already finished (or falsy entry): just report its exit status.
            check_exitcode(pending)
            continue
        pending.join(timeout=60)
        if pending.is_alive():
            logger.error("Error: save ckpt process timeout!!!")
            async_save_queue.append(pending)
        else:
            check_exitcode(pending)


def copy_dict_to_cpu(nested_dict):
    """
    Recursively copy every paddle.Tensor in a nested dict to the CPU.

    Returns a new dict with the same structure; tensors are replaced by
    CPU copies, nested dicts are copied recursively, and any other value
    is carried over unchanged. The input dict is not modified.
    """
    copied = {}
    for key, value in nested_dict.items():
        if isinstance(value, dict):
            copied[key] = copy_dict_to_cpu(value)
        elif isinstance(value, paddle.Tensor):
            copied[key] = value.cpu()
            # Ensure the device-to-host transfer has completed before the
            # tensor is handed off to the async save process.
            paddle.device.synchronize()
        else:
            copied[key] = value
    return copied


def check_file_name(file_name, process_group):
all_unique_id = []
Expand Down Expand Up @@ -102,6 +147,7 @@ def save_state_dict(
path: str,
process_group: Group | None = None,
coordinator_rank: int = 0,
async_save: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下面的 docstring 部分也同步增加参数说明,另外中文文档也同步修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,单独提交 pr 完善

) -> None:
"""
Save the state_dict of model to path.
Expand Down Expand Up @@ -233,4 +279,32 @@ def save_state_dict(
dedup_tensor(
local_state_dict, local_storage_metadata, metadata.storage_metadata
)
paddle.save(local_state_dict, os.path.join(path, file_name))

if async_save:
cpu_state_dict = copy_dict_to_cpu(local_state_dict)
clear_async_save_task_queue()

attempt = 0
ctx = multiprocessing.get_context("spawn")

def start_process():
nonlocal attempt
try:
p = ctx.Process(
target=paddle.save,
args=(cpu_state_dict, os.path.join(path, file_name)),
)
p.start()
return p
except Exception as e:
logger.error(
f"Attempt {attempt + 1} failed with error: {e}"
)
attempt += 1
time.sleep(1)
return start_process()

p = start_process()
async_save_queue.append(p)
else:
paddle.save(local_state_dict, os.path.join(path, file_name))
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import random
import tempfile
import time
from functools import reduce

import numpy as np
Expand Down Expand Up @@ -167,7 +168,7 @@ def run_dy2static(self, tmp_ckpt_path):
lr_scheduler.step()
if step == 2:
state_dict = dist_model.state_dict()
dist.save_state_dict(state_dict, tmp_ckpt_path)
dist.save_state_dict(state_dict, tmp_ckpt_path, async_save=True)
if step > 2:
numpy_array = np.array(loss)
array_bytes = numpy_array.tobytes()
Expand All @@ -182,6 +183,8 @@ def run_dy2static(self, tmp_ckpt_path):
if step >= 9:
break

time.sleep(10)

loss_after_load = []
for step, inputs in enumerate(dist_loader()):
if step < 2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUp(self):
"acc_step": "1",
"FLAGS_embedding_deterministic": "1",
"FLAGS_cudnn_deterministic": "1",
"FLAGS_enable_pir_api": "1",
}
self._changeable_envs = {
"backend": ["gpu"],
Expand Down