
Refactor push_to_hub #1883

Merged · 13 commits · Sep 2, 2024
Changes from 7 commits
1 change: 0 additions & 1 deletion swift/llm/sft.py
@@ -424,7 +424,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
     logger.info(f'images_dir: {images_dir}')
     plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
     if args.push_to_hub:
-        trainer._add_patterns_to_gitignore(['images/'])
         trainer.push_to_hub()
     run_info = {
         'memory': trainer.perf['memory'],
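Note on this deletion: the refactored `create_repo` in swift/trainers/push_to_ms.py (added below) writes 'images/' into .gitignore when the repo is created, so the manual `_add_patterns_to_gitignore(['images/'])` call before `push_to_hub()` is no longer needed.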
2 changes: 1 addition & 1 deletion swift/llm/utils/argument.py
@@ -1191,7 +1191,7 @@ def _init_training_args(self) -> None:
             adam_epsilon=self.adam_epsilon,
             hub_model_id=self.hub_model_id,
             hub_private_repo=self.hub_private_repo,
-            push_hub_strategy=self.push_hub_strategy,
+            hub_strategy=self.push_hub_strategy,
             hub_token=self.hub_token,
             push_to_hub=self.push_to_hub,
             resume_from_checkpoint=self.resume_from_checkpoint,
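For reference, a minimal sketch of where the renamed keyword lands: `hub_strategy` is the stock transformers `TrainingArguments` field (accepted values are 'end', 'every_save', 'checkpoint', 'all_checkpoints'); the output dir and repo id below are hypothetical.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='output',           # hypothetical
    push_to_hub=True,
    hub_model_id='user/my-model',  # hypothetical
    hub_strategy='every_save',     # receives swift's push_hub_strategy value
)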
172 changes: 2 additions & 170 deletions swift/trainers/mixin.py
@@ -22,191 +22,23 @@
 from transformers.data.data_collator import DataCollator
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import unwrap_model
-from transformers.trainer import (ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME,
-                                  PREFIX_CHECKPOINT_DIR, SAFE_WEIGHTS_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME,
-                                  WEIGHTS_NAME, IntervalStrategy, Trainer, TrainerCallback, is_peft_available)
+from transformers.trainer import PREFIX_CHECKPOINT_DIR, TRAINER_STATE_NAME, Trainer, TrainerCallback
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import TrainingArguments
 from transformers.utils import is_sagemaker_mp_enabled, is_torch_npu_available

-from swift.hub import Repository
 from swift.hub.check_model import check_local_model_is_latest
 from swift.torchacc_utils import (save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint, ta_load_optimizer_and_scheduler,
                                   ta_save_optimizer_and_scheduler, ta_trim_graph)
 from swift.tuners import SwiftModel
-from swift.utils import check_json_format, create_ms_repo, get_logger, use_torchacc
+from swift.utils import check_json_format, get_logger, use_torchacc
 from swift.utils.constants import Invoke
 from .optimizers.galore import create_optimizer_and_scheduler
 from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

 logger = get_logger()


-def _push_to_hub(self: Repository, commit_message: str = 'Commit files to Modelscope Hub', **kwargs):
-    blocking = kwargs.get('blocking', True)
-    self.push(commit_message)
-    if not blocking:
-        # Compatible with transformers
-        return None, None
-    else:
-        return None
-
-
-class PushToMsHubMixin:
-    repo: Repository
-
-    def _add_patterns_to_file(self, file_name: str, patterns: List[str], commit_message: Optional[str] = None) -> None:
-        # Make sure we only do this on the main process
-        if not self.is_world_process_zero():
-            return
-        if isinstance(patterns, str):
-            patterns = [patterns]
-        if commit_message is None:
-            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
-
-        # Get current file content
-        repo_dir = self.repo.model_dir
-        file_path = os.path.join(repo_dir, file_name)
-        if os.path.exists(file_path):
-            with open(file_path, 'r', encoding='utf-8') as f:
-                current_content = f.read()
-        else:
-            current_content = ''
-        # Add the patterns to file
-        content = current_content
-        for pattern in patterns:
-            if pattern not in content:
-                if len(content) > 0 and not content.endswith('\n'):
-                    content += '\n'
-                content += f'{pattern}\n'
-
-        # Write the file if it has changed
-        if content != current_content:
-            with open(file_path, 'w', encoding='utf-8') as f:
-                logger.debug(f'Writing {file_name} file. Content: {content}')
-                f.write(content)
-            self.repo.push(commit_message)
-
-    def _add_patterns_to_gitignore(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
-        self._add_patterns_to_file('.gitignore', patterns, commit_message)
-
-    def _add_patterns_to_gitattributes(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
-        new_patterns = []
-        suffix = 'filter=lfs diff=lfs merge=lfs -text'
-        for pattern in patterns:
-            if suffix not in pattern:
-                pattern = f'{pattern} {suffix}'
-            new_patterns.append(pattern)
-        file_name = '.gitattributes'
-        if commit_message is None:
-            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
-        self._add_patterns_to_file(file_name, new_patterns, commit_message)
-
-    def init_hf_repo(self) -> None:
-        """init ms repo. Compatible with transformers>=4.34"""
-        self.init_git_repo(at_init=True)
-
-    def init_git_repo(self, at_init: bool = False) -> None:
-        if not self.is_world_process_zero():
-            return
-        if (os.path.exists(self.args.output_dir) and os.listdir(self.args.output_dir) and self.args.overwrite_output_dir
-                and at_init):
-            # directory not empty.
-            shutil.rmtree(self.args.output_dir)
-        self.args.hub_model_id = create_ms_repo(self.args.hub_model_id, self.args.hub_token, self.args.hub_private_repo)
-        self.repo = Repository(self.args.output_dir, self.args.hub_model_id)
-        self._add_patterns_to_gitattributes(['*.safetensors', '*.bin', '*.pt'])
-        self.repo.push_to_hub = MethodType(_push_to_hub, self.repo)
-        self.repo.local_dir = self.repo.model_dir  # hf compatibility
-
-        # By default, ignore the checkpoint folders
-        if self.args.push_hub_strategy != 'all_checkpoints':
-            self._add_patterns_to_gitignore(['checkpoint-*/', 'tmp-checkpoint-*/'])
-
-        # Add 'runs/' to .gitignore, ignore tensorboard files
-        self._add_patterns_to_gitignore(['runs/'])
-
-        # Add '*.sagemaker' to .gitignore if using SageMaker
-        if os.environ.get('SM_TRAINING_ENV'):
-            self._add_patterns_to_gitignore(['*.sagemaker-uploading', '*.sagemaker-uploaded'],
-                                            'Add `*.sagemaker` patterns to .gitignore')
-
-        self.push_in_progress = None
-
-    def push_to_hub(self, commit_message: str = 'End of training', **kwargs) -> None:
-        # user calls manually `push_to_hub` with `self.args.push_to_hub = False`
-        create_model_card = kwargs.pop('create_model_card', None)
-        if not hasattr(self, 'repo'):
-            self.init_git_repo()
-        self.save_model(_internal_call=True)
-
-        if not self.is_world_process_zero():
-            return
-
-        self.repo.push_to_hub(commit_message, **kwargs)
-        # push separately the model card to be independent from the rest of the model
-        readme_path = os.path.join(self.args.output_dir, 'README.md')
-        if create_model_card is None:
-            create_model_card = not os.path.exists(readme_path)
-        if create_model_card and self.args.should_save:
-            model_name = kwargs.pop('model_name', None)
-            if model_name is None and self.args.should_save:
-                if self.args.hub_model_id is not None:
-                    model_name = self.args.hub_model_id.split('/')[-1]
-                else:
-                    model_name = os.path.basename(self.args.output_dir)
-            self.create_model_card(model_name=model_name, **kwargs)
-            self.repo.push_to_hub('update model card README.md', **kwargs)
-
-    def _push_from_checkpoint(self, checkpoint_folder: str) -> None:
-        """Compatible with transformers>=4.32"""
-        # Only push from one node.
-        if not self.is_world_process_zero() or self.args.push_hub_strategy == 'end':
-            return
-        output_dir = self.args.output_dir
-        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
-        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
-        if is_peft_available():
-            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
-        for modeling_file in modeling_files:
-            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
-                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
-        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
-        if self.tokenizer is not None:
-            self.tokenizer.save_pretrained(output_dir)
-        # Same for the training arguments
-        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
-
-        try:
-            if self.args.push_hub_strategy == 'checkpoint':
-                # Temporarily move the checkpoint just saved for the push
-                tmp_checkpoint = os.path.join(output_dir, 'last-checkpoint')
-                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
-                # subfolder.
-                if os.path.isdir(tmp_checkpoint):
-                    shutil.rmtree(tmp_checkpoint)
-                shutil.move(checkpoint_folder, tmp_checkpoint)
-
-            if self.args.save_strategy == IntervalStrategy.STEPS:
-                commit_message = f'Training in progress, step {self.state.global_step}'
-            else:
-                commit_message = f'Training in progress, epoch {int(self.state.epoch)}'
-            if self.args.push_hub_strategy == 'push_best':
-                folder, checkpoint_name = os.path.split(checkpoint_folder)
-                checkpoint_name = checkpoint_name.replace('tmp-checkpoint-', 'checkpoint-')
-                last_model_checkpoint = os.path.join(folder, checkpoint_name)
-                if last_model_checkpoint == self.state.best_model_checkpoint:
-                    self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
-            else:
-                self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
-        except Exception as e:
-            logger.error(f'Error when pushing to hub: {e}')
-        finally:
-            if self.args.push_hub_strategy == 'checkpoint':
-                # Move back the checkpoint to its place
-                shutil.move(tmp_checkpoint, checkpoint_folder)
-
-
 class SwiftMixin:

     def __init__(self,
179 changes: 179 additions & 0 deletions swift/trainers/push_to_ms.py
@@ -0,0 +1,179 @@
import os
import tempfile
from concurrent.futures import Future
from functools import partial
from pathlib import Path
from typing import List, Optional, Union

from huggingface_hub import RepoUrl
from huggingface_hub.hf_api import CommitInfo, future_compatible
from modelscope import HubApi, push_to_hub
from modelscope.hub.api import ModelScopeConfig
from modelscope.hub.constants import ModelVisibility
from modelscope.hub.repository import Repository
from modelscope.hub.utils.utils import get_cache_dir
from requests.exceptions import HTTPError
from transformers.utils import logging, strtobool

logger = logging.get_logger(__name__)


class PushToMsHubMixin:
Contributor: Since this handles both ms and hf, wouldn't PushToHubMixin be a more general name?

Collaborator (Author): It doesn't handle hf; with hf this path is skipped entirely and the original logic is used.
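A minimal sketch of the gate described in that reply. `USE_HF` is read once via `strtobool` when this class is defined, so it must be set before swift is imported; the launcher below and the exact import trigger are hypothetical.

import os

os.environ['USE_HF'] = '1'  # '1': keep the stock Hugging Face hub logic
                            # unset or '0' (default): patch in the ModelScope backend

import swift  # hypothetical entrypoint; the monkey-patch below runs (or not) at import time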


    _use_hf_hub = strtobool(os.environ.get('USE_HF', 'False'))
    _cache_dir = get_cache_dir()
    _token = None

    @staticmethod
    def create_repo(repo_id: str, *, token: Union[str, bool, None] = None, private: bool = False, **kwargs) -> RepoUrl:
        hub_model_id = PushToMsHubMixin._create_ms_repo(repo_id, token, private)
        PushToMsHubMixin._token = token
        with tempfile.TemporaryDirectory(dir=PushToMsHubMixin._cache_dir) as temp_cache_dir:
            repo = Repository(temp_cache_dir, hub_model_id)
            PushToMsHubMixin._add_patterns_to_gitattributes(repo, ['*.safetensors', '*.bin', '*.pt'])
            # Add 'runs/' to .gitignore, ignore tensorboard files
            PushToMsHubMixin._add_patterns_to_gitignore(repo, ['runs/', 'images/'])
            PushToMsHubMixin._add_patterns_to_file(
                repo,
                'configuration.json', ['{"framework": "pytorch", "task": "text-generation", "allow_remote": true}'],
                ignore_push_error=True)
            # Add '*.sagemaker' to .gitignore if using SageMaker
Contributor: Was this code adapted from hf? Why special-case the SageMaker environment here?

Collaborator (Author): Yes, it comes from the previous code. Keeping the SageMaker handling is probably better, since it affects both the ms and hf frameworks.

            if os.environ.get('SM_TRAINING_ENV'):
                PushToMsHubMixin._add_patterns_to_gitignore(repo, ['*.sagemaker-uploading', '*.sagemaker-uploaded'],
                                                            'Add `*.sagemaker` patterns to .gitignore')
        return RepoUrl(url=hub_model_id, )

    @staticmethod
    @future_compatible
    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        path_in_repo: Optional[str] = None,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = 'master',
        ignore_patterns: Optional[Union[List[str], str]] = None,
        run_as_future: bool = False,
        **kwargs,
    ) -> Union[CommitInfo, str, Future[CommitInfo], Future[str]]:
        commit_message = commit_message or 'Upload folder using api'
        if commit_description:
            commit_message = commit_message + '\n' + commit_description
        if not os.path.exists(os.path.join(folder_path, 'configuration.json')):
            with open(os.path.join(folder_path, 'configuration.json'), 'w') as f:
                f.write('{"framework": "pytorch", "task": "text-generation", "allow_remote": true}')
        if ignore_patterns:
            ignore_patterns = [p for p in ignore_patterns if p != '_*']
        if path_in_repo:
            idx = folder_path.rfind(path_in_repo)
            if idx >= 0:
                folder_path = folder_path[:idx]
            ignore_patterns = []
        push_to_hub(
            repo_id,
            folder_path,
            token or PushToMsHubMixin._token,
            commit_message=commit_message,
            ignore_file_pattern=ignore_patterns,
            revision=revision,
            tag=path_in_repo)
        return CommitInfo(
            commit_url=f'https://www.modelscope.cn/models/{repo_id}/files',
            commit_message=commit_message,
            commit_description=commit_description,
            oid=None,
        )

    if not _use_hf_hub:
Contributor: Why is hf-specific logic handled when *not* using hf? Is the condition reversed?

Collaborator (Author): It's not reversed; this directly hacks the underlying create/upload logic.

        import huggingface_hub
        from huggingface_hub.hf_api import api
        from transformers import trainer
        huggingface_hub.create_repo = create_repo
        huggingface_hub.upload_folder = partial(upload_folder, api)
        trainer.create_repo = create_repo
        trainer.upload_folder = partial(upload_folder, api)
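Why both modules get patched: transformers' trainer.py binds these helpers into its own namespace at import time (`from huggingface_hub import create_repo, upload_folder`), so rebinding them on `huggingface_hub` alone would not change the names the Trainer already holds. A runnable toy sketch of that pitfall (all names hypothetical):

import types

# simulate `from huggingface_hub import create_repo` inside another module
lib = types.SimpleNamespace(create_repo=lambda repo_id: f'hf:{repo_id}')
consumer = types.SimpleNamespace(create_repo=lib.create_repo)  # imported-by-name copy

lib.create_repo = lambda repo_id: f'ms:{repo_id}'  # patching the library alone...
print(consumer.create_repo('user/m'))              # -> hf:user/m (unaffected)

consumer.create_repo = lib.create_repo             # ...so the consumer is patched too
print(consumer.create_repo('user/m'))              # -> ms:user/m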

    @staticmethod
    def _create_ms_repo(hub_model_id: str, hub_token: Optional[str] = None, hub_private_repo: bool = False) -> str:
        assert hub_model_id is not None, 'Please enter a valid hub_model_id'

        api = HubApi()
        if hub_token is None:
            hub_token = os.environ.get('MODELSCOPE_API_TOKEN')
        if hub_token is not None:
            api.login(hub_token)
        visibility = ModelVisibility.PRIVATE if hub_private_repo else ModelVisibility.PUBLIC
Contributor: Shouldn't the else branch raise an error? Without a token you can't push.

Collaborator (Author): Fixed


        if '/' not in hub_model_id:
            user_name = ModelScopeConfig.get_user_info()[0]
            assert isinstance(user_name, str)
            hub_model_id = f'{user_name}/{hub_model_id}'
            logger.info(f"'/' not in hub_model_id, setting hub_model_id: {hub_model_id}")
Contributor: Suggested message: "'/' not in hub_model_id, pushing to personal repo {hub_model_id}"

Collaborator (Author): Fixed

        try:
            api.create_model(hub_model_id, visibility)
        except HTTPError:
            # The remote repository has been created
            pass
        return hub_model_id

    @staticmethod
    def _add_patterns_to_file(repo: Repository,
                              file_name: str,
                              patterns: List[str],
                              commit_message: Optional[str] = None,
                              ignore_push_error=False) -> None:
        if isinstance(patterns, str):
            patterns = [patterns]
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

        # Get current file content
        repo_dir = repo.model_dir
        file_path = os.path.join(repo_dir, file_name)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                current_content = f.read()
        else:
            current_content = ''
        # Add the patterns to file
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if len(content) > 0 and not content.endswith('\n'):
                    content += '\n'
                content += f'{pattern}\n'

        # Write the file if it has changed
        if content != current_content:
            with open(file_path, 'w', encoding='utf-8') as f:
                logger.debug(f'Writing {file_name} file. Content: {content}')
                f.write(content)
            try:
                repo.push(commit_message)
            except Exception as e:
                if ignore_push_error:
                    pass
                else:
                    raise e

    @staticmethod
    def _add_patterns_to_gitignore(repo: Repository, patterns: List[str], commit_message: Optional[str] = None) -> None:
        PushToMsHubMixin._add_patterns_to_file(repo, '.gitignore', patterns, commit_message, ignore_push_error=True)

    @staticmethod
    def _add_patterns_to_gitattributes(repo: Repository,
                                       patterns: List[str],
                                       commit_message: Optional[str] = None) -> None:
        new_patterns = []
        suffix = 'filter=lfs diff=lfs merge=lfs -text'
        for pattern in patterns:
            if suffix not in pattern:
                pattern = f'{pattern} {suffix}'
            new_patterns.append(pattern)
        file_name = '.gitattributes'
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
        PushToMsHubMixin._add_patterns_to_file(repo, file_name, new_patterns, commit_message, ignore_push_error=True)
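For reference, the files these helpers produce when a repo is first created, derived directly from the patterns passed in `create_repo` above:

# .gitattributes -- routes weight files through git-lfs
*.safetensors filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text

# .gitignore -- keeps tensorboard runs and loss plots out of the repo
runs/
images/

# configuration.json -- ModelScope model metadata
{"framework": "pytorch", "task": "text-generation", "allow_remote": true}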