Merged
170 changes: 134 additions & 36 deletions python/paddle/distributed/checkpoint/load_state_dict.py
@@ -20,6 +20,9 @@
from typing import TYPE_CHECKING
Contributor


Why was TYPE_CHECKING removed?
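(For context: `TYPE_CHECKING` is the standard `typing` flag used to guard imports that are needed only for annotations. A minimal sketch of the usual pattern; the `Group` import path here is illustrative, not necessarily what this file used:)

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so the
    # import adds no startup cost and cannot create an import cycle.
    from paddle.distributed.communication.group import Group


def example(process_group: Group | None = None) -> None:
    """Annotation-only use of Group, made safe by the guard above."""
```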


import paddle
from paddle.base.framework import (
_current_expected_place,
)
from paddle.distributed.communication.group import is_initialized
from paddle.distributed.fleet.utils.log_util import logger

@@ -39,12 +42,12 @@ class ReadItem:
local_tensor_index: LocalTensorIndex
rank: int
dtype: str
cur_offset: tuple[int, ...]
storage_offset: tuple[int, ...]
lengths: tuple[int, ...]
cur_offset: tuple[int]
storage_offset: tuple[int]
lengths: tuple[int]


PATH_TO_CHECKPOINT_FILES: dict[str, tuple[list[str], list[str]]] = {}
PATH_TO_CHECKPOINT_FILES: dict[str, tuple[list, list]] = {}


def get_checkpoint_files(path, use_cache=True):
@@ -136,17 +139,79 @@ def get_rank_to_files(
)

rank_to_files = {}
for rank, local_files in enumerate(global_data_files):
if len(local_files) > 0:
local_files = [
f for f in local_files if f in all_necessary_files[rank]
]
rank_to_files[rank] = local_files
for rank, need_files in enumerate(all_necessary_files):
seen = set()
unique_need_files = [
f for f in need_files if not (f in seen or seen.add(f))
]
rank_to_files[rank] = unique_need_files
logger.debug(f"mapping rank_to_files:{rank_to_files}")
return rank_to_files, missing_keys
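The `seen`-set comprehension above is the usual order-preserving de-duplication idiom: `seen.add(f)` returns `None`, so the `or` only filters out files that were already seen. A tiny standalone check, with illustrative file names:

```python
need_files = ["0_0.distcp", "1_0.distcp", "0_0.distcp"]
seen = set()
# Keep the first occurrence of each file and preserve the original order.
unique_need_files = [f for f in need_files if not (f in seen or seen.add(f))]
assert unique_need_files == ["0_0.distcp", "1_0.distcp"]
```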


def get_local_load_files(rank_to_files):
def get_rank_to_read_files(rank_to_files, rank_to_local_data_files):
cross_node_file_names = []
rank_to_need_files = copy.deepcopy(rank_to_files)
for rank, need_files in rank_to_need_files.items():
local_data_files = rank_to_local_data_files[rank]
file_need_to_remove = []
for file in need_files:
if file not in local_data_files:
file_need_to_remove.append(file)
for file in file_need_to_remove:
need_files.remove(file)
cross_node_file_names += file_need_to_remove

not_read_file_ranks = []
for rank, files in rank_to_need_files.items():
if len(files) == 0:
not_read_file_ranks.append(rank)
for rank in not_read_file_ranks:
rank_to_need_files.pop(rank)

rank_load_files = _get_rank_to_read_files(rank_to_need_files)

for rank in not_read_file_ranks:
rank_load_files[rank] = []

cur_load_files = []
for rank, load_file in rank_load_files.items():
cur_load_files += load_file

unload_files = []
for file in cross_node_file_names:
if file not in cur_load_files:
unload_files.append(file)

file_to_ranks = {}
for rank, files in rank_to_local_data_files.items():
for file in files:
if file not in file_to_ranks:
file_to_ranks[file] = [rank]
else:
file_to_ranks[file].append(rank)

seen = set()
unload_files = [x for x in unload_files if not (x in seen or seen.add(x))]
for file in unload_files:
sub_rank_load_files = {}
for rank in file_to_ranks[file]:
sub_rank_load_files[rank] = rank_load_files[rank]
min_rank = min(
sub_rank_load_files,
key=lambda rank: (len(sub_rank_load_files[rank]), rank),
)
rank_load_files[min_rank].append(file)

cur_rank = paddle.distributed.get_rank()
if cur_rank in rank_load_files:
return rank_load_files[cur_rank]
else:
logger.warning(f"rank:{cur_rank} does not need to load checkpoint")
return []
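For the files that no selected rank would otherwise read (`unload_files`), the assignment above picks, among the ranks whose node actually holds the file, the one with the fewest files already assigned, breaking ties by the lower rank id. A small standalone illustration with made-up names:

```python
rank_load_files = {0: ["a.distcp", "b.distcp"], 1: ["c.distcp"], 2: ["d.distcp"]}
file_to_ranks = {"e.distcp": [0, 2]}  # ranks whose local node holds "e.distcp"

for file, ranks in file_to_ranks.items():
    candidates = {r: rank_load_files[r] for r in ranks}
    # Fewest already-assigned files wins; ties go to the smaller rank id.
    min_rank = min(candidates, key=lambda r: (len(candidates[r]), r))
    rank_load_files[min_rank].append(file)

assert rank_load_files[2] == ["d.distcp", "e.distcp"]
```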


def _get_rank_to_read_files(rank_to_files):
"""
Load files in a load-balanced manner.

@@ -170,7 +235,7 @@ def get_local_load_files(rank_to_files):
if file not in file_to_ranks:
file_to_ranks[file] = []
file_to_ranks[file].append(rank)
rank_to_not_read_files = copy.copy(rank_to_files)
rank_to_not_read_files = copy.deepcopy(rank_to_files)
rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()}
for file, ranks in file_to_ranks.items():
if len(ranks) == 1:
@@ -239,13 +304,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file):
logger.debug(
f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}"
)

cur_rank = paddle.distributed.get_rank()
if cur_rank in rank_to_read_files:
return rank_to_read_files[cur_rank]
else:
logger.warning(f"rank:{cur_rank} does not need to load checkpoint")
return []
return rank_to_read_files


def get_load_infos(metadata_list, local_load_files, process_group, use_dist):
@@ -410,6 +469,7 @@ def load_state_dict(
path: str,
process_group: Group | None = None,
coordinator_rank: int = 0,
offload=False,
) -> None:
"""
Load the state_dict inplace from a checkpoint path.
@@ -419,7 +479,7 @@
path(str): The directory to load checkpoint files.
process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards.
coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.

offload(bool): Whether to offload the checkpoint data from GPU to CPU.
Example:
.. code-block:: python

@@ -486,26 +546,46 @@
if len(rank_to_files) <= 0:
return

local_load_files = get_local_load_files(rank_to_files)
cur_rank = paddle.distributed.get_rank()
global_local_data_files = []
if use_dist:
paddle.distributed.all_gather_object(
global_local_data_files,
{cur_rank: local_data_files},
process_group,
)
else:
global_local_data_files = [{cur_rank: local_data_files}]

source_state_dict = {}
for file in local_load_files:
source_state_dict[file] = paddle.load(os.path.join(path, file))
rank_to_local_data_files = {}
for d in global_local_data_files:
rank_to_local_data_files.update(d)

state_dict_in_cpu = []
for k, v in flat_state_dict.items():
if v.place.is_cpu_place():
state_dict_in_cpu.append(k)
flat_state_dict[k] = v.cuda()
local_load_files = get_rank_to_read_files(
rank_to_files, rank_to_local_data_files
)

_load_state_dict(flat_state_dict, source_state_dict, metadata_list)
source_state_dict = {}
for file in local_load_files:
if offload:
state_dict_numpy = paddle.load(
os.path.join(path, file), return_numpy=True
)
source_state_dict[file] = {
key: paddle.to_tensor(value, place=paddle.CPUPlace())
for key, value in state_dict_numpy.items()
}
else:
source_state_dict[file] = paddle.load(os.path.join(path, file))

for k, v in flat_state_dict.items():
if k in state_dict_in_cpu:
value = state_dict
for key in mapping[k]:
value = value[key]
paddle.assign(v.cpu(), value)
_load_state_dict(
flat_state_dict,
source_state_dict,
metadata_list,
process_group,
coordinator_rank,
offload,
)
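A minimal sketch of the offload read path used above, assuming a single shard file exists under the checkpoint directory (the path and file name are illustrative):

```python
import os

import paddle

path, file = "./checkpoint", "0_0.distcp"  # illustrative shard file
# Read the shard as numpy arrays so nothing is allocated on the accelerator,
# then wrap each array as a CPU tensor; chunks are moved to the device only
# later, inside _load_state_dict, when they are actually assigned.
state_dict_numpy = paddle.load(os.path.join(path, file), return_numpy=True)
shard = {
    key: paddle.to_tensor(value, place=paddle.CPUPlace())
    for key, value in state_dict_numpy.items()
}
```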


def _load_state_dict(
@@ -514,8 +594,15 @@
metadata_list,
process_group=None,
coordinator_rank=0,
offload=False,
) -> None:
with paddle.base.dygraph.guard():

state_dict_in_cpu = {}
for k, v in target_state_dict.items():
if v.place.is_cpu_place():
state_dict_in_cpu[k] = v
target_state_dict[k] = v.cuda()
use_dist = True if paddle.distributed.get_world_size() > 1 else False
local_load_files = list(source_state_dict.keys())
# load_infos: {LocalTensorIndex: (rank, file_name)}: records which file each local tensor is stored in, and which rank loads that file.
@@ -543,6 +630,12 @@ def _load_state_dict(
storage_local_tensor = storage_state_dict[
item.local_tensor_index.tensor_key
]

if offload:
storage_local_tensor = paddle.to_tensor(
storage_local_tensor, place=_current_expected_place()
)

storage_offsets = item.storage_offset
storage_lengths = item.lengths
storage_ends = [
@@ -620,3 +713,8 @@ def _load_state_dict(
tmp_tensor, src=src_rank, group=process_group
)
paddle.assign(tmp_tensor, cur_chunk_tensor)

for k, v in target_state_dict.items():
if k in state_dict_in_cpu:
value = state_dict_in_cpu[k]
paddle.assign(v.cpu(), value)
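For reference, a minimal single-card usage sketch of the new `offload` flag. It assumes `./checkpoint` holds a checkpoint previously written by the matching save API; the keys and shapes are illustrative:

```python
import paddle

from paddle.distributed.checkpoint.load_state_dict import load_state_dict

# Target tensors sit on the default (GPU) place as usual; offload=True only
# changes how the checkpoint shards are read: they are loaded through numpy
# into CPU tensors and copied to the device chunk by chunk during assignment,
# instead of being materialized on the GPU all at once.
state_dict = {
    "linear.weight": paddle.zeros([8, 8], dtype="float32"),
    "linear.bias": paddle.zeros([8], dtype="float32"),
}
load_state_dict(state_dict, path="./checkpoint", offload=True)
```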