Refactor push_to_hub #1883

Merged · 13 commits · Sep 2, 2024
1 change: 0 additions & 1 deletion swift/llm/sft.py
@@ -424,7 +424,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
logger.info(f'images_dir: {images_dir}')
plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
if args.push_to_hub:
trainer._add_patterns_to_gitignore(['images/'])
trainer.push_to_hub()
run_info = {
'memory': trainer.perf['memory'],
10 changes: 8 additions & 2 deletions swift/llm/utils/argument.py
@@ -346,6 +346,11 @@ def handle_compatibility(self: Union['SftArguments', 'InferArguments']) -> None:
if self.lora_target_regex:
self.target_regex = self.lora_target_regex

if getattr(self, 'push_hub_strategy', None):
self.hub_strategy = self.push_hub_strategy
if self.hub_strategy in ('push_last', 'push_best'):
self.hub_strategy = 'every_save'

def handle_custom_dataset_info(self: Union['SftArguments', 'InferArguments']):
if self.custom_dataset_info is None:
return
@@ -790,7 +795,7 @@ class SftArguments(ArgumentsBase):
hub_token: Optional[str] = field(
default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
hub_private_repo: bool = False
push_hub_strategy: Literal['end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'] = 'push_best'
hub_strategy: Literal['end', 'every_save', 'checkpoint', 'all_checkpoints'] = 'every_save'

# other
test_oom_error: bool = field(
@@ -874,6 +879,7 @@ class SftArguments(ArgumentsBase):
custom_train_dataset_path: List[str] = field(default_factory=list)
custom_val_dataset_path: List[str] = field(default_factory=list)
device_map_config_path: Optional[str] = None
push_hub_strategy: Literal['end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'] = 'push_best'

def _prepare_target_modules(self, target_modules) -> Union[List[str], str]:
if isinstance(target_modules, str):
@@ -1191,7 +1197,7 @@ def _init_training_args(self) -> None:
adam_epsilon=self.adam_epsilon,
hub_model_id=self.hub_model_id,
hub_private_repo=self.hub_private_repo,
push_hub_strategy=self.push_hub_strategy,
hub_strategy=self.push_hub_strategy,
hub_token=self.hub_token,
push_to_hub=self.push_to_hub,
resume_from_checkpoint=self.resume_from_checkpoint,
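For reference, a minimal sketch (not part of this diff) of the backward-compatibility mapping that handle_compatibility now applies: the deprecated --push_hub_strategy values are folded into the new --hub_strategy argument, and the old 'push_best'/'push_last' modes both collapse to 'every_save', matching the hub strategies transformers understands. The helper name below is hypothetical.

def map_push_hub_strategy(push_hub_strategy: str) -> str:
    # Mirrors the logic added in handle_compatibility: keep the value as-is
    # unless it is one of the removed custom modes.
    hub_strategy = push_hub_strategy
    if hub_strategy in ('push_last', 'push_best'):
        # Best/last-checkpoint pushes are no longer special-cased; both now
        # mean "push on every save" in transformers terms.
        hub_strategy = 'every_save'
    return hub_strategy

assert map_push_hub_strategy('push_best') == 'every_save'
assert map_push_hub_strategy('all_checkpoints') == 'all_checkpoints'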
2 changes: 0 additions & 2 deletions swift/trainers/arguments.py
@@ -18,8 +18,6 @@ class SwiftArgumentsMixin:
# ckpt only save model
save_only_model: bool = False
train_sampler_random: bool = True
push_hub_strategy: str = field(
default='push_best', metadata={'choices': {'end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'}})
acc_strategy: str = field(default='token', metadata={'choices': ['token', 'sentence']})
loss_name: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
additional_saved_files: Optional[List[str]] = None
172 changes: 2 additions & 170 deletions swift/trainers/mixin.py
@@ -22,191 +22,23 @@
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import (ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME,
PREFIX_CHECKPOINT_DIR, SAFE_WEIGHTS_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME,
WEIGHTS_NAME, IntervalStrategy, Trainer, TrainerCallback, is_peft_available)
from transformers.trainer import PREFIX_CHECKPOINT_DIR, TRAINER_STATE_NAME, Trainer, TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments
from transformers.utils import is_sagemaker_mp_enabled, is_torch_npu_available

from swift.hub import Repository
from swift.hub.check_model import check_local_model_is_latest
from swift.torchacc_utils import (save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint, ta_load_optimizer_and_scheduler,
ta_save_optimizer_and_scheduler, ta_trim_graph)
from swift.tuners import SwiftModel
from swift.utils import check_json_format, create_ms_repo, get_logger, use_torchacc
from swift.utils import check_json_format, get_logger, use_torchacc
from swift.utils.constants import Invoke
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

logger = get_logger()


def _push_to_hub(self: Repository, commit_message: str = 'Commit files to Modelscope Hub', **kwargs):
blocking = kwargs.get('blocking', True)
self.push(commit_message)
if not blocking:
# Compatible with transformers
return None, None
else:
return None


class PushToMsHubMixin:
repo: Repository

def _add_patterns_to_file(self, file_name: str, patterns: List[str], commit_message: Optional[str] = None) -> None:
# Make sure we only do this on the main process
if not self.is_world_process_zero():
return
if isinstance(patterns, str):
patterns = [patterns]
if commit_message is None:
commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

# Get current file content
repo_dir = self.repo.model_dir
file_path = os.path.join(repo_dir, file_name)
if os.path.exists(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
current_content = f.read()
else:
current_content = ''
# Add the patterns to file
content = current_content
for pattern in patterns:
if pattern not in content:
if len(content) > 0 and not content.endswith('\n'):
content += '\n'
content += f'{pattern}\n'

# Write the file if it has changed
if content != current_content:
with open(file_path, 'w', encoding='utf-8') as f:
logger.debug(f'Writing {file_name} file. Content: {content}')
f.write(content)
self.repo.push(commit_message)

def _add_patterns_to_gitignore(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
self._add_patterns_to_file('.gitignore', patterns, commit_message)

def _add_patterns_to_gitattributes(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
new_patterns = []
suffix = 'filter=lfs diff=lfs merge=lfs -text'
for pattern in patterns:
if suffix not in pattern:
pattern = f'{pattern} {suffix}'
new_patterns.append(pattern)
file_name = '.gitattributes'
if commit_message is None:
commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
self._add_patterns_to_file(file_name, new_patterns, commit_message)

def init_hf_repo(self) -> None:
"""init ms repo. Compatible with transformers>=4.34"""
self.init_git_repo(at_init=True)

def init_git_repo(self, at_init: bool = False) -> None:
if not self.is_world_process_zero():
return
if (os.path.exists(self.args.output_dir) and os.listdir(self.args.output_dir) and self.args.overwrite_output_dir
and at_init):
# directory not empty.
shutil.rmtree(self.args.output_dir)
self.args.hub_model_id = create_ms_repo(self.args.hub_model_id, self.args.hub_token, self.args.hub_private_repo)
self.repo = Repository(self.args.output_dir, self.args.hub_model_id)
self._add_patterns_to_gitattributes(['*.safetensors', '*.bin', '*.pt'])
self.repo.push_to_hub = MethodType(_push_to_hub, self.repo)
self.repo.local_dir = self.repo.model_dir # hf compatibility

# By default, ignore the checkpoint folders
if self.args.push_hub_strategy != 'all_checkpoints':
self._add_patterns_to_gitignore(['checkpoint-*/', 'tmp-checkpoint-*/'])

# Add 'runs/' to .gitignore, ignore tensorboard files
self._add_patterns_to_gitignore(['runs/'])

# Add '*.sagemaker' to .gitignore if using SageMaker
if os.environ.get('SM_TRAINING_ENV'):
self._add_patterns_to_gitignore(['*.sagemaker-uploading', '*.sagemaker-uploaded'],
'Add `*.sagemaker` patterns to .gitignore')

self.push_in_progress = None

def push_to_hub(self, commit_message: str = 'End of training', **kwargs) -> None:
# user calls manually `push_to_hub` with `self.args.push_to_hub = False`
create_model_card = kwargs.pop('create_model_card', None)
if not hasattr(self, 'repo'):
self.init_git_repo()
self.save_model(_internal_call=True)

if not self.is_world_process_zero():
return

self.repo.push_to_hub(commit_message, **kwargs)
# push separately the model card to be independent from the rest of the model
readme_path = os.path.join(self.args.output_dir, 'README.md')
if create_model_card is None:
create_model_card = not os.path.exists(readme_path)
if create_model_card and self.args.should_save:
model_name = kwargs.pop('model_name', None)
if model_name is None and self.args.should_save:
if self.args.hub_model_id is not None:
model_name = self.args.hub_model_id.split('/')[-1]
else:
model_name = os.path.basename(self.args.output_dir)
self.create_model_card(model_name=model_name, **kwargs)
self.repo.push_to_hub('update model card README.md', **kwargs)

def _push_from_checkpoint(self, checkpoint_folder: str) -> None:
"""Compatible with transformers>=4.32"""
# Only push from one node.
if not self.is_world_process_zero() or self.args.push_hub_strategy == 'end':
return
output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
# Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

try:
if self.args.push_hub_strategy == 'checkpoint':
# Temporarily move the checkpoint just saved for the push
tmp_checkpoint = os.path.join(output_dir, 'last-checkpoint')
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
# subfolder.
if os.path.isdir(tmp_checkpoint):
shutil.rmtree(tmp_checkpoint)
shutil.move(checkpoint_folder, tmp_checkpoint)

if self.args.save_strategy == IntervalStrategy.STEPS:
commit_message = f'Training in progress, step {self.state.global_step}'
else:
commit_message = f'Training in progress, epoch {int(self.state.epoch)}'
if self.args.push_hub_strategy == 'push_best':
folder, checkpoint_name = os.path.split(checkpoint_folder)
checkpoint_name = checkpoint_name.replace('tmp-checkpoint-', 'checkpoint-')
last_model_checkpoint = os.path.join(folder, checkpoint_name)
if last_model_checkpoint == self.state.best_model_checkpoint:
self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
else:
self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
except Exception as e:
logger.error(f'Error when pushing to hub: {e}')
finally:
if self.args.push_hub_strategy == 'checkpoint':
# Move back the checkpoint to its place
shutil.move(tmp_checkpoint, checkpoint_folder)


class SwiftMixin:

def __init__(self,
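With PushToMsHubMixin and its Repository-based helpers removed, pushing is driven by the hub arguments that _init_training_args forwards above. A minimal sketch of that configuration, assuming the standard transformers TrainingArguments API; the repo id is hypothetical.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='output/my-run',
    push_to_hub=True,
    hub_model_id='my-org/my-model',   # hypothetical repo id
    hub_private_repo=False,
    hub_strategy='every_save',        # replaces push_hub_strategy='push_best'/'push_last'
    hub_token=None,                   # SDK/access token; may be omitted if already configured
)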