Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,10 @@ class TrainingArguments:
default=300,
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
)
# Opt-in switch, disabled by default. Per its help text it enables use of a
# flash device for storing checkpoints "and other usages"; the code that
# consumes this flag is not visible in this hunk.
pdc_use_flash_device: Optional[bool] = field(
default=False,
metadata={"help": "Use flash device for storage of checkpoints and other usages"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
Expand Down
89 changes: 87 additions & 2 deletions paddlenlp/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from collections import OrderedDict
from typing import Optional, Union

import paddle.distributed as dist
import requests
from filelock import FileLock
from huggingface_hub import get_hf_file_metadata, hf_hub_url
Expand All @@ -33,7 +34,13 @@
from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
from .fault_tolerance import PDC_DOWNLOAD_ERROR
from .log import logger
from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
from .pdc_sdk import (
FLASH_DEVICE,
PDCErrorCode,
PDCErrorMessageMap,
pdc_flash_device_available,
pdc_tool,
)

__all__ = ["get_weights_path_from_url"]

Expand Down Expand Up @@ -487,7 +494,7 @@ def download_from_pdc(remote_path, local_path, timeout):
"""

try:
base_dir, _ = os.path.split(os.path.normpath(remote_path))
base_dir, _ = os.path.split(os.path.normpath(local_path))
if not os.path.exists(base_dir) and base_dir != "":
os.makedirs(base_dir, exist_ok=True)
except Exception as e:
Expand All @@ -505,3 +512,81 @@ def download_from_pdc(remote_path, local_path, timeout):
raise RuntimeError(
f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
)


def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_device=False):
    """
    Get static model from PDC. Use flash device if possible.
    This function has to be called after distributed env is initialized in distributed mode.

    Only the process whose ``FLAGS_selected_gpus`` is 0 (or unset) checks the flash
    device, downloads and backs up. Every other process just waits on a barrier and
    then picks whichever location exists.

    Args:
        remote_path (`str`):
            remote path url for download
        local_path (`str`):
            local path to place downloaded object
        timeout (`int`):
            max wait time for download
        enable_flash_device (`bool`):
            Whether to use flash device
    Returns:
        str: path to load static model
    Raises:
        RuntimeError: if the parent directory of `local_path` cannot be prepared.
            Errors raised by `download_from_pdc` propagate unchanged.
        AssertionError: if `local_path` has no usable final path component.
    """
    try:
        base_dir, target_dir = os.path.split(os.path.normpath(local_path))
        if not os.path.exists(base_dir) and base_dir != "":
            os.makedirs(base_dir, exist_ok=True)
    except Exception as e:
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}")

    # Raise explicitly instead of using `assert`, which is stripped under `python -O`.
    # The exception type is kept as AssertionError for backward compatibility.
    # "" and ".." are rejected as well: joined with FLASH_DEVICE they would point at
    # the flash device root or its parent rather than a per-model directory.
    if target_dir in ("", ".", ".."):
        raise AssertionError(f"{PDC_DOWNLOAD_ERROR}, illegal local_path: {local_path}.")

    flash_path = os.path.join(FLASH_DEVICE, target_dir)
    persistent_path = local_path

    device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    if device_id != 0:
        # Non-leader process: wait for the leader's barrier, then pick a path.
        logger.info("Waiting local process 0...")
        dist.barrier()
        return flash_path if (enable_flash_device and os.path.exists(flash_path)) else persistent_path

    # step 1: load from flash device if possible
    need_download_from_remote = True
    need_backup_to_flash = False
    if enable_flash_device and pdc_flash_device_available():
        logger.info(f"flash device is available, checking status on {flash_path}...")
        # skip download SC as default when flash device is available
        need_download_from_remote = False
        if os.path.exists(flash_path) and pdc_tool.pdc_flash_do_check(flash_path) == PDCErrorCode.Success:
            logger.info("Static model checked successfully on flash device, ready to load...")
        else:
            logger.warning(
                "flash device is available but no valid static model found on flash device, need to download from remote."
            )
            need_download_from_remote = True
            need_backup_to_flash = True
    else:
        logger.info("Flash device is not enabled or available, will download static model from remote.")

    # step 2: download from remote if necessary
    if need_download_from_remote:
        logger.info("Beging download static model from remote...")
        download_from_pdc(remote_path, persistent_path, timeout)
        logger.info(f"downloaded static model from remote, path:{persistent_path}")

    # step 3: backup to flash device if flash device is available
    if enable_flash_device and need_backup_to_flash:
        result = pdc_tool.pdc_backup_to_flash_device(persistent_path, flash_path)
        if result == PDCErrorCode.Success:
            logger.info(f"Backup static model to flash device {flash_path} successfully.")
        else:
            # A failed backup is not fatal: the persistent copy is still usable.
            logger.error(f"Backup static model to flash device failed, error details: {PDCErrorMessageMap[result]}.")

    # step 4: return flash path if available, otherwise return persistent path
    if dist.get_world_size() > 1:
        # Pairs with the barrier the non-leader processes are waiting on above.
        logger.info("Local node process done, waiting other nodes...")
        dist.barrier()
    if enable_flash_device and os.path.exists(flash_path):
        logger.info(f"static model is ready on flash device, path: {flash_path}")
        return flash_path
    else:
        logger.info(f"static model is only ready on persistent storage, path: {persistent_path}")
        return persistent_path
76 changes: 69 additions & 7 deletions paddlenlp/utils/pdc_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import json
import os
import queue
import shutil
import subprocess
import threading
import time
from distutils.dir_util import copy_tree
from enum import Enum
from typing import List

Expand All @@ -28,6 +30,13 @@
TRAIN_CONFIG = "/root/paddlejob/workspace/env_run/longjob/train.conf"
TAR_BIN = "tar"

# Location of the flash device; overridable through the PDC_FLASH_DEVICE
# environment variable.
FLASH_DEVICE = os.getenv("PDC_FLASH_DEVICE", "/shared/dev/shm/flash")


def pdc_flash_device_available():
    """Report whether the flash device path is present.

    Returns:
        bool: True when the path stored in ``FLASH_DEVICE`` exists on the
        local filesystem, False otherwise.
    """
    # TODO(@gexiao): need better check
    flash_root_present = os.path.exists(FLASH_DEVICE)
    return flash_root_present


class PDCErrorCode(Enum):
"""Error Code For PDCTools usage"""
Expand All @@ -48,6 +57,7 @@ class PDCErrorCode(Enum):
InvalidArgument = 1503
CommandTimeout = 1504
CheckSumCommandFail = 1505
CopyTreeFailed = 1506

UnknownError = 1999

Expand Down Expand Up @@ -491,14 +501,60 @@ def _download_file(self, remote_path: str, local_path: str) -> PDCErrorCode:
raise Exception(f"exec cmd {download_cmd_args} with error: {e}")
return error_code

def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
def _pdc_backup_failed_directory(self, path):
base_dir, target_path = os.path.split(os.path.normpath(path))
failed_path = os.path.join(base_dir, f"{target_path}_failed")
if os.path.exists(path):
if os.path.exists(failed_path):
shutil.rmtree(failed_path)
# Backup failed files for debug
os.rename(path, failed_path)

def pdc_backup_to_flash_device(self, persistent_path: str, flash_device_path: str) -> PDCErrorCode:
    """Back up data from persistent storage to the flash device.

    Generates a checksum for `persistent_path`, copies the tree to
    `flash_device_path`, then verifies the copy. If the copy or the
    verification fails, the flash directory is renamed to ``<name>_failed``
    so it can be inspected later.

    Args:
        persistent_path (str): source directory on persistent storage.
        flash_device_path (str): destination directory on the flash device.

    Returns:
        PDCErrorCode: ``Success`` when the copy exists and passed verification.
        ``LocalPathNotExist`` when the source is missing, ``CopyTreeFailed``
        when copying raised, otherwise the error code returned by the
        checksum generation or verification step.
    """
    if not os.path.exists(persistent_path):
        logger.error(f"{persistent_path} not exist")
        return PDCErrorCode.LocalPathNotExist

    logger.info("starting backup to flash device...")

    # step 1: generate checksum for recovery
    result = self.pdc_generate_dir_checksum(persistent_path)
    if result != PDCErrorCode.Success:
        logger.error(f"[Error] [pdc_sdk] generating checksum for {persistent_path} failed")
        return result

    # step 2: copy persistent data to flash device
    # shutil.copytree replaces distutils' copy_tree. distutils is deprecated
    # (PEP 632, removed in Python 3.12), and copy_tree caches the directories
    # it has created, so a later copy in the same process can fail after the
    # target was renamed away. dirs_exist_ok=True keeps the merge-into-existing
    # behaviour.
    try:
        shutil.copytree(persistent_path, flash_device_path, dirs_exist_ok=True)
        logger.info(f"backup {persistent_path} to {flash_device_path} successed.")
    except Exception as e:
        logger.error(f"[Error] [pdc_sdk] copy tree {persistent_path} to {flash_device_path} failed, error: {e}")
        self._pdc_backup_failed_directory(flash_device_path)
        return PDCErrorCode.CopyTreeFailed

    # step 3: do checksum for storage on flash device
    result = self.pdc_flash_do_check(flash_device_path)
    if result == PDCErrorCode.Success:
        return result

    logger.error(f"[Error] [pdc_sdk] checksum failed on {flash_device_path} after copy, backup for debug")
    self._pdc_backup_failed_directory(flash_device_path)
    return result

def pdc_generate_dir_checksum(self, path: str) -> PDCErrorCode:
"""
Args
:param localPath:
:return:
"""
if not os.path.exists(path):
logger.error(f"pdc_fc_generate_checksum gi{path} not exist")
logger.error(f"pdc_generate_dir_checksum gi{path} not exist")
return PDCErrorCode.CommandFail
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "generateSum", "-path", f"{path}"]
error_code = PDCErrorCode.Success
Expand All @@ -512,14 +568,14 @@ def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
return PDCErrorCode.CheckSumCommandFail
return error_code

def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
def pdc_flash_do_check(self, path: str) -> PDCErrorCode:
"""
Args
:param localPath:
:return:
"""
if not os.path.exists(path):
logger.error(f"pdc_fc_do_check {path} not exist")
logger.error(f"pdc_flash_do_check {path} not exist")
return PDCErrorCode.CommandFail
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "checkSum", "-path", f"{path}"]
error_code = PDCErrorCode.Success
Expand All @@ -528,8 +584,12 @@ def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
res, error_code = self._exec_cmd(generate_checksum_args)
if error_code == PDCErrorCode.Success:
logger.info(f"check_sum {path} successfully")
else:
logger.error(f"[Error] [pdc_sdk] check_sum {path} failed, error code: {error_code}")
self._pdc_backup_failed_directory(path)
except Exception as e:
logger.error(f"exec cmd {generate_checksum_args} with error: {e}")
logger.error(f"[Error] [pdc_sdk] exec cmd {generate_checksum_args} with error: {e}")
self._pdc_backup_failed_directory(path)
return PDCErrorCode.CheckSumCommandFail
return error_code

Expand Down Expand Up @@ -558,8 +618,10 @@ def _clean_tmp_files(self, tmp_files: List[str]):
PDCErrorCode.AFSToolsNotExist: "afs tools not exist",
PDCErrorCode.TrainConfigNotExist: "train config not exist",
PDCErrorCode.LocalPathNotExist: "local path not exist",
PDCErrorCode.CommandFail: "download command fail",
PDCErrorCode.CommandFail: "pdc agent command fail",
PDCErrorCode.CalculateHashFail: "calculate hash fail",
PDCErrorCode.InvalidArgument: "invalid argument",
PDCErrorCode.CommandTimeout: "command timeout",
PDCErrorCode.CommandTimeout: "pdc agent command timeout",
PDCErrorCode.CheckSumCommandFail: "checksum command fail",
PDCErrorCode.CopyTreeFailed: "copy directory failed",
}
Loading