-
Notifications
You must be signed in to change notification settings - Fork 828
Refactor push_to_hub #1883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor push_to_hub #1883
Changes from 7 commits
215b659
e643be4
bfd8b8c
79bb977
2e03737
b567098
8d0c05b
0633cc1
cd56717
f977020
49d59ab
5f63eb8
f1661a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import tempfile | ||
from concurrent.futures import Future | ||
from functools import partial | ||
from pathlib import Path | ||
from typing import List, Optional, Union | ||
|
||
from huggingface_hub import RepoUrl | ||
from huggingface_hub.hf_api import CommitInfo, future_compatible | ||
from modelscope import HubApi, push_to_hub | ||
from modelscope.hub.api import ModelScopeConfig | ||
from modelscope.hub.constants import ModelVisibility | ||
from modelscope.hub.repository import Repository | ||
from modelscope.hub.utils.utils import get_cache_dir | ||
from requests.exceptions import HTTPError | ||
from transformers.utils import logging, strtobool | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class PushToMsHubMixin:
    """Redirect Hugging Face style hub pushes to the ModelScope hub.

    Unless the ``USE_HF`` environment variable is truthy, evaluating this class
    body replaces ``create_repo`` and ``upload_folder`` on both
    ``huggingface_hub`` and ``transformers.trainer`` with the ModelScope-backed
    implementations defined below.
    """

    # Truthy when USE_HF is set; the monkey patch further down is skipped then.
    _use_hf_hub = strtobool(os.environ.get('USE_HF', 'False'))
    # Parent directory for the temporary clone made by create_repo().
    _cache_dir = get_cache_dir()
    # Token given to the most recent create_repo() call; upload_folder() falls back to it.
    _token = None

    @staticmethod
    def create_repo(repo_id: str, *, token: Union[str, bool, None] = None, private: bool = False, **kwargs) -> RepoUrl:
        """Create the ModelScope repo and seed its housekeeping files.

        Args:
            repo_id: Repo name, with or without a ``user/`` prefix.
            token: Token forwarded to ``_create_ms_repo`` and remembered in ``_token``.
            private: Create the repo as private when true.
            **kwargs: Accepted for signature compatibility; not read here.

        Returns:
            A ``RepoUrl`` built from the resolved repo id.
        """
        hub_model_id = PushToMsHubMixin._create_ms_repo(repo_id, token, private)
        PushToMsHubMixin._token = token
        with tempfile.TemporaryDirectory(dir=PushToMsHubMixin._cache_dir) as temp_cache_dir:
            repo = Repository(temp_cache_dir, hub_model_id)
            PushToMsHubMixin._add_patterns_to_gitattributes(repo, ['*.safetensors', '*.bin', '*.pt'])
            # Add 'runs/' to .gitignore, ignore tensorboard files
            PushToMsHubMixin._add_patterns_to_gitignore(repo, ['runs/', 'images/'])
            PushToMsHubMixin._add_patterns_to_file(
                repo,
                'configuration.json', ['{"framework": "pytorch", "task": "text-generation", "allow_remote": true}'],
                ignore_push_error=True)
            # Add '*.sagemaker' to .gitignore if using SageMaker
            # NOTE(review): the SageMaker branch is carried over on purpose from the
            # earlier, HF-derived code.
            if os.environ.get('SM_TRAINING_ENV'):
                PushToMsHubMixin._add_patterns_to_gitignore(repo, ['*.sagemaker-uploading', '*.sagemaker-uploaded'],
                                                            'Add `*.sagemaker` patterns to .gitignore')
        return RepoUrl(url=hub_model_id, )

    @staticmethod
    @future_compatible
    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        path_in_repo: Optional[str] = None,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = 'master',
        ignore_patterns: Optional[Union[List[str], str]] = None,
        run_as_future: bool = False,
        **kwargs,
    ) -> Union[CommitInfo, str, Future[CommitInfo], Future[str]]:
        """Upload a local folder to a ModelScope repo via ``push_to_hub``.

        Args:
            self: Object bound as first argument; the class-level patch binds
                ``huggingface_hub.hf_api.api``. Not read in this body.
            repo_id: Target repo id.
            folder_path: Local folder to upload; ``str`` or ``Path``.
            path_in_repo: When set, ``folder_path`` is cut at its last occurrence
                of this value and the value is passed to ``push_to_hub`` as ``tag``.
            commit_message: Commit message; defaults to 'Upload folder using api'.
            commit_description: Appended to the commit message on a new line.
            token: Token for the push; falls back to the one stored by ``create_repo``.
            revision: Revision passed to ``push_to_hub``.
            ignore_patterns: One pattern or a list of patterns to skip.
            run_as_future: Not read here; presumably consumed by the
                ``future_compatible`` wrapper - confirm.
            **kwargs: Accepted for signature compatibility; not read here.

        Returns:
            A ``CommitInfo`` pointing at the repo's file listing on modelscope.cn.
        """
        commit_message = commit_message or 'Upload folder using api'
        if commit_description:
            commit_message = commit_message + '\n' + commit_description
        # Path inputs are allowed by the signature, but the rfind/slicing below is str-only.
        folder_path = os.fspath(folder_path)
        configuration_path = os.path.join(folder_path, 'configuration.json')
        if not os.path.exists(configuration_path):
            with open(configuration_path, 'w') as f:
                f.write('{"framework": "pytorch", "task": "text-generation", "allow_remote": true}')
        # A lone pattern string must not be iterated character by character.
        if isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        if ignore_patterns:
            # The '_*' pattern is dropped before the list is handed to ModelScope.
            ignore_patterns = [p for p in ignore_patterns if p != '_*']
        if path_in_repo:
            idx = folder_path.rfind(path_in_repo)
            if idx >= 0:
                folder_path = folder_path[:idx]
            # NOTE(review): nesting of this reset was reconstructed from a flattened
            # diff; it is applied whenever path_in_repo is set - confirm.
            ignore_patterns = []
        push_to_hub(
            repo_id,
            folder_path,
            token or PushToMsHubMixin._token,
            commit_message=commit_message,
            ignore_file_pattern=ignore_patterns,
            revision=revision,
            tag=path_in_repo)
        return CommitInfo(
            commit_url=f'https://www.modelscope.cn/models/{repo_id}/files',
            commit_message=commit_message,
            commit_description=commit_description,
            oid=None,
        )

    # NOTE(review): the condition is not inverted. When HF is *not* requested, the
    # underlying HF create/upload entry points are replaced by the ModelScope ones.
    if not _use_hf_hub:
        import huggingface_hub
        from huggingface_hub.hf_api import api
        from transformers import trainer

        # Inside the class body these names are still staticmethod objects, which
        # are only callable from Python 3.10 on; patch in the wrapped functions.
        huggingface_hub.create_repo = create_repo.__func__
        huggingface_hub.upload_folder = partial(upload_folder.__func__, api)
        trainer.create_repo = create_repo.__func__
        trainer.upload_folder = partial(upload_folder.__func__, api)

    @staticmethod
    def _create_ms_repo(hub_model_id: str, hub_token: Optional[str] = None, hub_private_repo: bool = False) -> str:
        """Log in to ModelScope and make sure the model repo exists.

        Args:
            hub_model_id: Repo name; the logged-in user name is prepended when it
                contains no '/'.
            hub_token: Token; falls back to the ``MODELSCOPE_API_TOKEN`` env var.
            hub_private_repo: Create the repo as private when true.

        Returns:
            The repo id, including the user prefix.
        """
        assert hub_model_id is not None, 'Please enter a valid hub_model_id'

        api = HubApi()
        if hub_token is None:
            hub_token = os.environ.get('MODELSCOPE_API_TOKEN')
        if hub_token is not None:
            api.login(hub_token)
        else:
            # NOTE(review): reviewers asked for a hard error here because pushing
            # needs a token. Kept as a warning in case an earlier login is still
            # usable - confirm before tightening.
            logger.warning('No ModelScope token was given and MODELSCOPE_API_TOKEN is not set; '
                           'creating or pushing to the repo may fail.')
        visibility = ModelVisibility.PRIVATE if hub_private_repo else ModelVisibility.PUBLIC

        if '/' not in hub_model_id:
            user_name = ModelScopeConfig.get_user_info()[0]
            assert isinstance(user_name, str)
            hub_model_id = f'{user_name}/{hub_model_id}'
            # NOTE(review): suggested wording from review:
            # "'/' not in hub_model_id, pushing to personal repo {hub_model_id}".
            logger.info(f"'/' not in hub_model_id, setting hub_model_id: {hub_model_id}")

        try:
            api.create_model(hub_model_id, visibility)
        except HTTPError as e:
            # Treated as "the remote repository has been created", but any HTTP
            # failure lands here, so leave a trace instead of passing silently.
            logger.info(f'create_model for {hub_model_id} raised an HTTPError, '
                        f'assuming the repo already exists: {e}')
        return hub_model_id

    @staticmethod
    def _add_patterns_to_file(repo: Repository,
                              file_name: str,
                              patterns: List[str],
                              commit_message: Optional[str] = None,
                              ignore_push_error=False) -> None:
        """Append missing patterns to a file in the repo clone and push.

        Args:
            repo: Repository whose ``model_dir`` holds the file.
            file_name: File name relative to ``repo.model_dir``.
            patterns: Lines to add; a single string is accepted as well. Nothing
                is done when this is empty.
            commit_message: Commit message; defaults to one naming the first pattern.
            ignore_push_error: Log push failures instead of raising them.

        Raises:
            Exception: Whatever ``repo.push`` raised, unless ``ignore_push_error``.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        if not patterns:
            # Nothing to add; this also avoids an IndexError on patterns[0] below.
            return
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

        # Get current file content
        repo_dir = repo.model_dir
        file_path = os.path.join(repo_dir, file_name)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                current_content = f.read()
        else:
            current_content = ''
        # Add the patterns to file
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if len(content) > 0 and not content.endswith('\n'):
                    content += '\n'
                content += f'{pattern}\n'

        # Write the file if it has changed
        if content != current_content:
            with open(file_path, 'w', encoding='utf-8') as f:
                logger.debug(f'Writing {file_name} file. Content: {content}')
                f.write(content)
        # NOTE(review): placement of the push was reconstructed from a flattened
        # diff; it is attempted whether or not the file changed - confirm.
        try:
            repo.push(commit_message)
        except Exception as e:
            if not ignore_push_error:
                raise
            logger.debug(f'Ignoring push error for {file_name}: {e}')

    @staticmethod
    def _add_patterns_to_gitignore(repo: Repository, patterns: List[str], commit_message: Optional[str] = None) -> None:
        """Append patterns to ``.gitignore``; push errors are ignored."""
        PushToMsHubMixin._add_patterns_to_file(repo, '.gitignore', patterns, commit_message, ignore_push_error=True)

    @staticmethod
    def _add_patterns_to_gitattributes(repo: Repository,
                                       patterns: List[str],
                                       commit_message: Optional[str] = None) -> None:
        """Append LFS tracking rules to ``.gitattributes``; push errors are ignored.

        Args:
            repo: Repository whose clone is edited.
            patterns: File patterns; the LFS attribute suffix is added to each one
                that does not already contain it. Nothing is done when empty.
            commit_message: Commit message; defaults to one naming the first pattern.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        if not patterns:
            # Nothing to add; this also avoids an IndexError on patterns[0] below.
            return
        suffix = 'filter=lfs diff=lfs merge=lfs -text'
        new_patterns = [pattern if suffix in pattern else f'{pattern} {suffix}' for pattern in patterns]
        file_name = '.gitattributes'
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
        PushToMsHubMixin._add_patterns_to_file(repo, file_name, new_patterns, commit_message, ignore_push_error=True)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
处理了ms和hf,是不是应该叫 PushToHubMixin 更通用一点?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有处理hf,hf直接跳过这里走原来的逻辑