Skip to content

Commit f363d2d

Browse files
committed
[LLM] adapt pdc sdk
1 parent 55fc611 commit f363d2d

File tree

2 files changed

+85
-11
lines changed

2 files changed

+85
-11
lines changed

paddlenlp/utils/downloader.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@
3838
FLASH_DEVICE,
3939
PDCErrorCode,
4040
PDCErrorMessageMap,
41-
pdc_backup_to_flash_device,
4241
pdc_flash_device_available,
43-
pdc_flash_do_check,
4442
pdc_tool,
4543
)
4644

@@ -519,6 +517,18 @@ def download_from_pdc(remote_path, local_path, timeout):
519517
def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_device=False):
520518
"""
521519
Get static model from PDC. Use flash device if possible.
520+
This function has to be called after distributed env is initialized in distributed mode.
521+
Args:
522+
remote_path (`str`):
523+
remote path url for download
524+
local_path (`str`):
525+
local path to place downloaded object
526+
timeout (`int`):
527+
max wait time for download
528+
enable_flash_device (`bool`):
529+
Whether to use flash device
530+
Returns:
531+
str: path to load static model
522532
"""
523533
try:
524534
base_dir, target_dir = os.path.split(os.path.normpath(local_path))
@@ -527,6 +537,8 @@ def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_devic
527537
except Exception as e:
528538
raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}")
529539

540+
assert target_dir != ".", f"{PDC_DOWNLOAD_ERROR}, illegal local_path: {local_path}."
541+
530542
flash_path = os.path.join(FLASH_DEVICE, target_dir)
531543
persistent_path = local_path
532544

@@ -543,15 +555,16 @@ def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_devic
543555
logger.info(f"flash device is available, checking status on {flash_path}...")
544556
# skip download SC as default when flash device is available
545557
need_download_from_remote = False
546-
if os.path.exists(flash_path) and pdc_flash_do_check(flash_path):
558+
if os.path.exists(flash_path) and pdc_tool.pdc_flash_do_check(flash_path) == PDCErrorCode.Success:
547559
logger.info("Static model checked successfully on flash device, ready to load...")
548560
else:
549561
logger.warning(
550562
"flash device is available but no valid static model found on flash device, need to download from remote."
551563
)
552-
shutil.rmtree(flash_path)
553564
need_download_from_remote = True
554565
need_backup_to_flash = True
566+
else:
567+
logger.info("Flash device is not enabled or available, will download static model from remote.")
555568

556569
# step 2: download from remote if neccesary
557570
if need_download_from_remote:
@@ -561,12 +574,11 @@ def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_devic
561574

562575
# step 3: backup to flash device if flash device is available
563576
if enable_flash_device and need_backup_to_flash:
564-
result = pdc_backup_to_flash_device(persistent_path, flash_path)
577+
result = pdc_tool.pdc_backup_to_flash_device(persistent_path, flash_path)
565578
if result == PDCErrorCode.Success:
566579
logger.info(f"Backup static model to flash device {flash_path} successfully.")
567580
else:
568581
logger.error(f"Backup static model to flash device failed, error details: {PDCErrorMessageMap[result]}.")
569-
shutil.rmtree(flash_path)
570582

571583
# step 4: return flash path if available, otherwise return persistent path
572584
if dist.get_world_size() > 1:

paddlenlp/utils/pdc_sdk.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import json
1616
import os
1717
import queue
18+
import shutil
1819
import subprocess
1920
import threading
2021
import time
22+
from distutils.dir_util import copy_tree
2123
from enum import Enum
2224
from typing import List
2325

@@ -28,6 +30,13 @@
2830
TRAIN_CONFIG = "/root/paddlejob/workspace/env_run/longjob/train.conf"
2931
TAR_BIN = "tar"
3032

33+
FLASH_DEVICE = os.getenv("PDC_FLASH_DEVICE", "/shared/dev/shm/flash")
34+
35+
36+
def pdc_flash_device_available():
37+
# TODO(@gexiao): need better check
38+
return os.path.exists(FLASH_DEVICE)
39+
3140

3241
class PDCErrorCode(Enum):
3342
"""Error Code For PDCTools usage"""
@@ -48,6 +57,7 @@ class PDCErrorCode(Enum):
4857
InvalidArgument = 1503
4958
CommandTimeout = 1504
5059
CheckSumCommandFail = 1505
60+
CopyTreeFailed = 1506
5161

5262
UnknownError = 1999
5363

@@ -491,14 +501,60 @@ def _download_file(self, remote_path: str, local_path: str) -> PDCErrorCode:
491501
raise Exception(f"exec cmd {download_cmd_args} with error: {e}")
492502
return error_code
493503

494-
def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
504+
def _pdc_backup_failed_directory(self, path):
505+
base_dir, target_path = os.path.split(os.path.normpath(path))
506+
failed_path = os.path.join(base_dir, f"{target_path}_failed")
507+
if os.path.exists(path):
508+
if os.path.exists(failed_path):
509+
shutil.rmtree(failed_path)
510+
# Backup failed files for debug
511+
os.rename(path, failed_path)
512+
513+
def pdc_backup_to_flash_device(self, persistent_path: str, flash_device_path: str) -> PDCErrorCode:
514+
"""backup data to flash device
515+
516+
Args:
517+
persistent_path str: persistent path
518+
flash_device_path str: flash device path
519+
"""
520+
if not os.path.exists(persistent_path):
521+
logger.error(f"{persistent_path} not exist")
522+
return PDCErrorCode.LocalPathNotExist
523+
524+
logger.info("starting backup to flash device...")
525+
526+
# step 1: generate checksum for recovery
527+
result = self.pdc_generate_dir_checksum(persistent_path)
528+
if result != PDCErrorCode.Success:
529+
logger.error(f"[Error] [pdc_sdk] generating checksum for {persistent_path} failed")
530+
return result
531+
532+
# step 2: copy persistent data to flash device
533+
try:
534+
copy_tree(persistent_path, flash_device_path)
535+
logger.info(f"backup {persistent_path} to {flash_device_path} successed.")
536+
except Exception as e:
537+
logger.error(f"[Error] [pdc_sdk] copy tree {persistent_path} to {flash_device_path} failed, error: {e}")
538+
self._pdc_backup_failed_directory(flash_device_path)
539+
return PDCErrorCode.CopyTreeFailed
540+
541+
# step 3: do checksum for storage on flash device
542+
result = self.pdc_flash_do_check(flash_device_path)
543+
if result == PDCErrorCode.Success:
544+
return result
545+
546+
logger.error(f"[Error] [pdc_sdk] checksum failed on {flash_device_path} after copy, backup for debug")
547+
self._pdc_backup_failed_directory(flash_device_path)
548+
return result
549+
550+
def pdc_generate_dir_checksum(self, path: str) -> PDCErrorCode:
495551
"""
496552
Args
497553
:param localPath:
498554
:return:
499555
"""
500556
if not os.path.exists(path):
501-
logger.error(f"pdc_fc_generate_checksum gi{path} not exist")
557+
logger.error(f"pdc_generate_dir_checksum gi{path} not exist")
502558
return PDCErrorCode.CommandFail
503559
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "generateSum", "-path", f"{path}"]
504560
error_code = PDCErrorCode.Success
@@ -512,14 +568,14 @@ def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
512568
return PDCErrorCode.CheckSumCommandFail
513569
return error_code
514570

515-
def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
571+
def pdc_flash_do_check(self, path: str) -> PDCErrorCode:
516572
"""
517573
Args
518574
:param localPath:
519575
:return:
520576
"""
521577
if not os.path.exists(path):
522-
logger.error(f"pdc_fc_do_check {path} not exist")
578+
logger.error(f"pdc_flash_do_check {path} not exist")
523579
return PDCErrorCode.CommandFail
524580
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "checkSum", "-path", f"{path}"]
525581
error_code = PDCErrorCode.Success
@@ -528,8 +584,12 @@ def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
528584
res, error_code = self._exec_cmd(generate_checksum_args)
529585
if error_code == PDCErrorCode.Success:
530586
logger.info(f"check_sum {path} successfully")
587+
else:
588+
logger.error(f"[Error] [pdc_sdk] check_sum {path} failed, error code: {error_code}")
589+
self._pdc_backup_failed_directory(path)
531590
except Exception as e:
532-
logger.error(f"exec cmd {generate_checksum_args} with error: {e}")
591+
logger.error(f"[Error] [pdc_sdk] exec cmd {generate_checksum_args} with error: {e}")
592+
self._pdc_backup_failed_directory(path)
533593
return PDCErrorCode.CheckSumCommandFail
534594
return error_code
535595

@@ -562,4 +622,6 @@ def _clean_tmp_files(self, tmp_files: List[str]):
562622
PDCErrorCode.CalculateHashFail: "calculate hash fail",
563623
PDCErrorCode.InvalidArgument: "invalid argument",
564624
PDCErrorCode.CommandTimeout: "command timeout",
625+
PDCErrorCode.CheckSumCommandFail: "checksum command fail",
626+
PDCErrorCode.CopyTreeFailed: "copy directory failed",
565627
}

0 commit comments

Comments
 (0)