Skip to content

Commit 92c209a

Browse files
authored
[LLM] support flash device on static model (#9619)
* [LLM] support flash device on static model * [LLM] adapt pdc sdk
1 parent b30014b commit 92c209a

File tree

3 files changed

+160
-9
lines changed

3 files changed

+160
-9
lines changed

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,10 @@ class TrainingArguments:
866866
default=300,
867867
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
868868
)
869+
pdc_use_flash_device: Optional[bool] = field(
870+
default=False,
871+
metadata={"help": "Use flash device for storage of checkpoints and other usages"},
872+
)
869873

870874
def __post_init__(self):
871875
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))

paddlenlp/utils/downloader.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from collections import OrderedDict
2525
from typing import Optional, Union
2626

27+
import paddle.distributed as dist
2728
import requests
2829
from filelock import FileLock
2930
from huggingface_hub import get_hf_file_metadata, hf_hub_url
@@ -33,7 +34,13 @@
3334
from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
3435
from .fault_tolerance import PDC_DOWNLOAD_ERROR
3536
from .log import logger
36-
from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
37+
from .pdc_sdk import (
38+
FLASH_DEVICE,
39+
PDCErrorCode,
40+
PDCErrorMessageMap,
41+
pdc_flash_device_available,
42+
pdc_tool,
43+
)
3744

3845
__all__ = ["get_weights_path_from_url"]
3946

@@ -487,7 +494,7 @@ def download_from_pdc(remote_path, local_path, timeout):
487494
"""
488495

489496
try:
490-
base_dir, _ = os.path.split(os.path.normpath(remote_path))
497+
base_dir, _ = os.path.split(os.path.normpath(local_path))
491498
if not os.path.exists(base_dir) and base_dir != "":
492499
os.makedirs(base_dir, exist_ok=True)
493500
except Exception as e:
@@ -505,3 +512,81 @@ def download_from_pdc(remote_path, local_path, timeout):
505512
raise RuntimeError(
506513
f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
507514
)
515+
516+
517+
def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_device=False):
    """
    Get static model from PDC. Use flash device if possible.
    This function has to be called after distributed env is initialized in distributed mode.

    Only the process whose ``FLAGS_selected_gpus`` is 0 (the default when the variable is unset)
    checks the flash device, downloads from remote and backs up to flash. In a multi-process run
    every other process waits on a barrier and then returns the same kind of path.

    Args:
        remote_path (`str`):
            remote path url for download
        local_path (`str`):
            local path to place downloaded object
        timeout (`int`):
            max wait time for download
        enable_flash_device (`bool`):
            Whether to use flash device
    Returns:
        str: path to load static model (the flash device copy when it is enabled and present,
        otherwise `local_path`)
    Raises:
        RuntimeError: if the parent directory of `local_path` cannot be prepared, or if
            `download_from_pdc` fails.
        AssertionError: if `local_path` has no usable final path component.
    """
    try:
        base_dir, target_dir = os.path.split(os.path.normpath(local_path))
        if not os.path.exists(base_dir) and base_dir != "":
            os.makedirs(base_dir, exist_ok=True)
    except Exception as e:
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e

    # "" (local_path is a filesystem root), "." (empty local_path) and ".." would make flash_path
    # resolve to the flash device root or its parent instead of a dedicated sub-directory.
    # Raised explicitly (rather than with `assert`) so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if target_dir in ("", ".", ".."):
        raise AssertionError(f"{PDC_DOWNLOAD_ERROR}, illegal local_path: {local_path}.")

    flash_path = os.path.join(FLASH_DEVICE, target_dir)
    persistent_path = local_path

    device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    # Wait for the leader only when there are other processes to wait for. A single process pinned
    # to a non-zero GPU has to do the download itself, otherwise nothing is ever fetched.
    if device_id != 0 and dist.get_world_size() > 1:
        logger.info("Waiting local process 0...")
        dist.barrier()
        return flash_path if (enable_flash_device and os.path.exists(flash_path)) else persistent_path

    # step 1: load from flash device if possible
    need_download_from_remote = True
    need_backup_to_flash = False
    if enable_flash_device and pdc_flash_device_available():
        logger.info(f"flash device is available, checking status on {flash_path}...")
        # skip download SC as default when flash device is available
        need_download_from_remote = False
        if os.path.exists(flash_path) and pdc_tool.pdc_flash_do_check(flash_path) == PDCErrorCode.Success:
            logger.info("Static model checked successfully on flash device, ready to load...")
        else:
            logger.warning(
                "flash device is available but no valid static model found on flash device, need to download from remote."
            )
            need_download_from_remote = True
            need_backup_to_flash = True
    else:
        logger.info("Flash device is not enabled or available, will download static model from remote.")

    # step 2: download from remote if necessary
    if need_download_from_remote:
        logger.info("Beging download static model from remote...")
        download_from_pdc(remote_path, persistent_path, timeout)
        logger.info(f"downloaded static model from remote, path:{persistent_path}")

    # step 3: backup to flash device if flash device is available
    if enable_flash_device and need_backup_to_flash:
        result = pdc_tool.pdc_backup_to_flash_device(persistent_path, flash_path)
        if result == PDCErrorCode.Success:
            logger.info(f"Backup static model to flash device {flash_path} successfully.")
        else:
            # A failed backup is not fatal: the persistent copy is still usable.
            logger.error(f"Backup static model to flash device failed, error details: {PDCErrorMessageMap[result]}.")

    # step 4: return flash path if available, otherwise return persistent path
    if dist.get_world_size() > 1:
        logger.info("Local node process done, waiting other nodes...")
        # Pairs with the barrier the waiting processes entered above.
        dist.barrier()
    if enable_flash_device and os.path.exists(flash_path):
        logger.info(f"static model is ready on flash device, path: {flash_path}")
        return flash_path
    else:
        logger.info(f"static model is only ready on persistent storage, path: {persistent_path}")
        return persistent_path

paddlenlp/utils/pdc_sdk.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import json
1616
import os
1717
import queue
18+
import shutil
1819
import subprocess
1920
import threading
2021
import time
22+
from distutils.dir_util import copy_tree
2123
from enum import Enum
2224
from typing import List
2325

@@ -28,6 +30,13 @@
2830
TRAIN_CONFIG = "/root/paddlejob/workspace/env_run/longjob/train.conf"
2931
TAR_BIN = "tar"
3032

33+
# Root directory of the flash device; can be overridden with the PDC_FLASH_DEVICE environment variable.
FLASH_DEVICE = os.environ.get("PDC_FLASH_DEVICE", "/shared/dev/shm/flash")


def pdc_flash_device_available():
    """Report whether the flash device root (``FLASH_DEVICE``) exists on this host."""
    # TODO(@gexiao): need better check
    flash_root_present = os.path.exists(FLASH_DEVICE)
    return flash_root_present
39+
3140

3241
class PDCErrorCode(Enum):
3342
"""Error Code For PDCTools usage"""
@@ -48,6 +57,7 @@ class PDCErrorCode(Enum):
4857
InvalidArgument = 1503
4958
CommandTimeout = 1504
5059
CheckSumCommandFail = 1505
60+
CopyTreeFailed = 1506
5161

5262
UnknownError = 1999
5363

@@ -491,14 +501,60 @@ def _download_file(self, remote_path: str, local_path: str) -> PDCErrorCode:
491501
raise Exception(f"exec cmd {download_cmd_args} with error: {e}")
492502
return error_code
493503

494-
def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
504+
def _pdc_backup_failed_directory(self, path):
    """Move `path` aside to a sibling named `<name>_failed` so it can be inspected later.

    Nothing happens when `path` does not exist. A `<name>_failed` entry left over from an
    earlier failure is removed before the rename.
    """
    parent_dir, leaf_name = os.path.split(os.path.normpath(path))
    quarantine_path = os.path.join(parent_dir, "{}_failed".format(leaf_name))

    if not os.path.exists(path):
        return

    # Keep only the most recent failed copy.
    if os.path.exists(quarantine_path):
        shutil.rmtree(quarantine_path)
    # Preserve the failed files for debugging instead of deleting them.
    os.rename(path, quarantine_path)
512+
513+
def pdc_backup_to_flash_device(self, persistent_path: str, flash_device_path: str) -> PDCErrorCode:
    """Back up a directory from persistent storage to the flash device and verify the copy.

    The backup runs in three steps: generate a checksum for `persistent_path`, copy the tree to
    `flash_device_path`, then run the checksum verification on the flash copy. When the copy or
    the verification fails, the flash copy is moved aside with `_pdc_backup_failed_directory`
    so it can be inspected.

    Args:
        persistent_path (str): source directory on persistent storage.
        flash_device_path (str): destination directory on the flash device.

    Returns:
        PDCErrorCode: `Success` when the copy was made and verified;
        `LocalPathNotExist` when `persistent_path` is missing;
        `CopyTreeFailed` when copying raised;
        otherwise the error code returned by checksum generation or verification.
    """
    if not os.path.exists(persistent_path):
        logger.error(f"{persistent_path} not exist")
        return PDCErrorCode.LocalPathNotExist

    logger.info("starting backup to flash device...")

    # step 1: generate checksum for recovery
    result = self.pdc_generate_dir_checksum(persistent_path)
    if result != PDCErrorCode.Success:
        logger.error(f"[Error] [pdc_sdk] generating checksum for {persistent_path} failed")
        return result

    # step 2: copy persistent data to flash device
    # shutil.copytree is used instead of distutils' copy_tree: distutils is deprecated, and
    # copy_tree remembers directories it has created, so after a failed copy has been moved
    # aside a retry in the same process would wrongly assume the destination still exists.
    # dirs_exist_ok=True keeps copy_tree's "merge into an existing directory" behaviour.
    try:
        shutil.copytree(persistent_path, flash_device_path, dirs_exist_ok=True)
        logger.info(f"backup {persistent_path} to {flash_device_path} successed.")
    except Exception as e:
        logger.error(f"[Error] [pdc_sdk] copy tree {persistent_path} to {flash_device_path} failed, error: {e}")
        self._pdc_backup_failed_directory(flash_device_path)
        return PDCErrorCode.CopyTreeFailed

    # step 3: do checksum for storage on flash device
    result = self.pdc_flash_do_check(flash_device_path)
    if result == PDCErrorCode.Success:
        return result

    logger.error(f"[Error] [pdc_sdk] checksum failed on {flash_device_path} after copy, backup for debug")
    # No-op if the verification step already moved the directory aside.
    self._pdc_backup_failed_directory(flash_device_path)
    return result
549+
550+
def pdc_generate_dir_checksum(self, path: str) -> PDCErrorCode:
495551
"""
496552
Args
497553
:param localPath:
498554
:return:
499555
"""
500556
if not os.path.exists(path):
501-
logger.error(f"pdc_fc_generate_checksum gi{path} not exist")
557+
logger.error(f"pdc_generate_dir_checksum gi{path} not exist")
502558
return PDCErrorCode.CommandFail
503559
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "generateSum", "-path", f"{path}"]
504560
error_code = PDCErrorCode.Success
@@ -512,14 +568,14 @@ def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
512568
return PDCErrorCode.CheckSumCommandFail
513569
return error_code
514570

515-
def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
571+
def pdc_flash_do_check(self, path: str) -> PDCErrorCode:
516572
"""
517573
Args
518574
:param localPath:
519575
:return:
520576
"""
521577
if not os.path.exists(path):
522-
logger.error(f"pdc_fc_do_check {path} not exist")
578+
logger.error(f"pdc_flash_do_check {path} not exist")
523579
return PDCErrorCode.CommandFail
524580
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "checkSum", "-path", f"{path}"]
525581
error_code = PDCErrorCode.Success
@@ -528,8 +584,12 @@ def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
528584
res, error_code = self._exec_cmd(generate_checksum_args)
529585
if error_code == PDCErrorCode.Success:
530586
logger.info(f"check_sum {path} successfully")
587+
else:
588+
logger.error(f"[Error] [pdc_sdk] check_sum {path} failed, error code: {error_code}")
589+
self._pdc_backup_failed_directory(path)
531590
except Exception as e:
532-
logger.error(f"exec cmd {generate_checksum_args} with error: {e}")
591+
logger.error(f"[Error] [pdc_sdk] exec cmd {generate_checksum_args} with error: {e}")
592+
self._pdc_backup_failed_directory(path)
533593
return PDCErrorCode.CheckSumCommandFail
534594
return error_code
535595

@@ -558,8 +618,10 @@ def _clean_tmp_files(self, tmp_files: List[str]):
558618
PDCErrorCode.AFSToolsNotExist: "afs tools not exist",
559619
PDCErrorCode.TrainConfigNotExist: "train config not exist",
560620
PDCErrorCode.LocalPathNotExist: "local path not exist",
561-
PDCErrorCode.CommandFail: "download command fail",
621+
PDCErrorCode.CommandFail: "pdc agent command fail",
562622
PDCErrorCode.CalculateHashFail: "calculate hash fail",
563623
PDCErrorCode.InvalidArgument: "invalid argument",
564-
PDCErrorCode.CommandTimeout: "command timeout",
624+
PDCErrorCode.CommandTimeout: "pdc agent command timeout",
625+
PDCErrorCode.CheckSumCommandFail: "checksum command fail",
626+
PDCErrorCode.CopyTreeFailed: "copy directory failed",
565627
}

0 commit comments

Comments
 (0)