Skip to content

Support PAI compat #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions swift/llm/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ def llm_dpo(args: DPOArguments) -> str:
if is_master():
images_dir = os.path.join(args.output_dir, 'images')
logger.info(f'images_dir: {images_dir}')
tb_dir = os.path.join(args.output_dir, 'runs')
plot_images(images_dir, tb_dir, ['train/loss'], 0.9)
plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
if args.push_to_hub:
trainer._add_patterns_to_gitignore(['images/'])
trainer.push_to_hub()
Expand Down
3 changes: 1 addition & 2 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
if is_master():
images_dir = os.path.join(args.output_dir, 'images')
logger.info(f'images_dir: {images_dir}')
tb_dir = os.path.join(args.output_dir, 'runs')
plot_images(images_dir, tb_dir, ['train/loss'], 0.9)
plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
if args.push_to_hub:
trainer._add_patterns_to_gitignore(['images/'])
trainer.push_to_hub()
Expand Down
23 changes: 21 additions & 2 deletions swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from swift import get_logger
from swift.hub import HubApi, ModelScopeConfig
from swift.utils import (add_version_to_work_dir, broadcast_string,
get_dist_setting, is_dist, is_master, is_mp)
get_dist_setting, get_pai_tensorboard_dir, is_dist,
is_master, is_mp, is_pai_training_job)
from .dataset import (DATASET_MAPPING, get_custom_dataset, get_dataset,
register_dataset)
from .model import (MODEL_MAPPING, dtype_mapping,
Expand Down Expand Up @@ -52,7 +53,7 @@ class SftArguments:
f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}"
})
output_dir: str = 'output'
add_output_dir_suffix: bool = True
add_output_dir_suffix: Optional[bool] = None
ddp_backend: Literal['nccl', 'gloo', 'mpi', 'ccl'] = 'nccl'

seed: int = 42
Expand Down Expand Up @@ -214,6 +215,8 @@ def prepare_target_modules(self, target_modules):

def __post_init__(self) -> None:
handle_compatibility(self)
if is_pai_training_job():
handle_pai_compat(self)
ds_config_folder = os.path.join(__file__, '..', '..', 'ds_config')
if self.deepspeed_config_path == 'default-zero2':
self.deepspeed_config_path = os.path.abspath(
Expand Down Expand Up @@ -270,6 +273,8 @@ def __post_init__(self) -> None:
if not dist.is_initialized():
dist.init_process_group(backend=self.ddp_backend)

if self.add_output_dir_suffix is None:
self.add_output_dir_suffix = True
if self.add_output_dir_suffix:
self.output_dir = os.path.join(self.output_dir, self.model_type)
self.output_dir = add_version_to_work_dir(self.output_dir)
Expand Down Expand Up @@ -906,3 +911,17 @@ def handle_dataset_mixture(args: SftArguments, train_dataset,
return concatenate_datasets([train_dataset, mixed_dataset])
else:
return train_dataset


def handle_pai_compat(args: SftArguments) -> None:
    """Apply PAI (Platform for AI) job defaults to arguments left unset.

    Fills ``args.logging_dir`` from the PAI tensorboard directory and turns
    off the output-dir suffix, touching only fields the user left as ``None``.
    Must only be called inside a PAI training job.
    """
    assert is_pai_training_job() is True
    logger.info('Handle pai compat...')
    tb_dir = get_pai_tensorboard_dir()
    # Respect an explicitly provided logging_dir; only fill in the PAI default.
    if tb_dir is not None and args.logging_dir is None:
        args.logging_dir = tb_dir
        logger.info(f'Setting args.logging_dir: {args.logging_dir}')
    # On PAI the output path is managed externally, so no version suffix.
    if args.add_output_dir_suffix is None:
        args.add_output_dir_suffix = False
        logger.info(
            f'Setting args.add_output_dir_suffix: {args.add_output_dir_suffix}'
        )
3 changes: 2 additions & 1 deletion swift/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
get_model_info, is_ddp_plus_mp, is_dist,
is_local_master, is_master, is_mp, is_on_same_device,
seed_everything, show_layers, time_synchronize)
from .utils import (add_version_to_work_dir, check_json_format, lower_bound,
from .utils import (add_version_to_work_dir, check_json_format,
get_pai_tensorboard_dir, is_pai_training_job, lower_bound,
parse_args, read_multi_line, test_time, upper_bound)
20 changes: 18 additions & 2 deletions swift/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime as dt
import os
import re
import sys
import time
from typing import (Any, Callable, List, Mapping, Optional, Sequence, Tuple,
Type, TypeVar)
Expand Down Expand Up @@ -67,8 +68,15 @@ def add_version_to_work_dir(work_dir: str) -> str:
def parse_args(class_type: Type[_T],
               argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
    """Parse command-line arguments into an instance of ``class_type``.

    If the first argument ends in ``.json``, the dataclass is populated from
    that JSON file and the rest of ``argv`` is returned unparsed; otherwise a
    regular CLI parse is performed, returning unknown strings as remainder.
    """
    parser = HfArgumentParser([class_type])
    argv = sys.argv[1:] if argv is None else argv
    json_mode = bool(argv) and argv[0].endswith('.json')
    if json_mode:
        json_path = os.path.abspath(os.path.expanduser(argv[0]))
        args, = parser.parse_json_file(json_path)
        return args, argv[1:]
    args, remaining_args = parser.parse_args_into_dataclasses(
        argv, return_remaining_strings=True)
    return args, remaining_args


Expand Down Expand Up @@ -131,3 +139,11 @@ def read_multi_line() -> str:
res[-1] = text[:-2]
break
return ''.join(res)


def is_pai_training_job() -> bool:
    """Return True when running inside a PAI training job (env-var probe)."""
    return os.environ.get('PAI_TRAINING_JOB_ID') is not None


def get_pai_tensorboard_dir() -> Optional[str]:
    """Return the PAI-provided tensorboard directory, or None when unset."""
    if 'PAI_OUTPUT_TENSORBOARD' in os.environ:
        return os.environ['PAI_OUTPUT_TENSORBOARD']
    return None
7 changes: 7 additions & 0 deletions tests/llm/config/sft.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model_type": "qwen-1_8b-chat",
"dataset": "jd-sentiment-zh",
"output_dir": "output/pai_test",
"train_dataset_sample": 100,
"eval_steps": 5
}
24 changes: 24 additions & 0 deletions tests/llm/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,30 @@ def test_dpo(self):
load_dataset_config=True,
val_dataset_sample=2))

def test_pai_compat(self):
    """End-to-end PAI-compat check: train from a JSON config, then infer."""
    if __name__ != '__main__':
        # Only meaningful when run directly; skip under the CI test runner.
        return
    import json
    from swift.llm import sft_main, infer_main
    # Simulate a PAI environment for handle_pai_compat().
    os.environ['PAI_TRAINING_JOB_ID'] = '123456'
    os.environ['PAI_OUTPUT_TENSORBOARD'] = os.path.join(
        'output/pai_test', 'pai_tensorboard')
    config_dir = os.path.join(os.path.dirname(__file__), 'config')
    sft_config = os.path.join(config_dir, 'sft.json')
    infer_config = os.path.join(config_dir, 'infer.json')
    result = sft_main([sft_config])
    print()
    # Write the inference config pointing at the freshly trained checkpoint.
    with open(infer_config, 'w') as f:
        json.dump(
            {
                'ckpt_dir': result['best_model_checkpoint'],
                'val_dataset_sample': 2,
                'load_dataset_config': True,
            },
            f,
            ensure_ascii=False,
            indent=4)
    infer_main([infer_config])
    os.environ.pop('PAI_TRAINING_JOB_ID')


def data_collate_fn(batch: List[Dict[str, Any]],
tokenizer) -> Dict[str, Tensor]:
Expand Down