 import re
 import warnings
 from collections.abc import Mapping
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import h5py
@@ -58,14 +59,14 @@
     RepositoryNotFoundError,
     RevisionNotFoundError,
     cached_path,
-    copy_func,
     find_labels,
     has_file,
     hf_bucket_url,
     is_offline_mode,
     is_remote_url,
     logging,
     requires_backends,
+    working_or_temp_dir,
 )


@@ -1919,6 +1920,7 @@ def save_pretrained(
         version=1,
         push_to_hub=False,
         max_shard_size: Union[int, str] = "10GB",
+        create_pr: bool = False,
         **kwargs
     ):
         """
@@ -1935,16 +1937,9 @@ def save_pretrained(
                 TensorFlow Serving as detailed in the official documentation
                 https://www.tensorflow.org/tfx/serving/serving_basic
             push_to_hub (`bool`, *optional*, defaults to `False`):
-                Whether or not to push your model to the Hugging Face model hub after saving it.
-
-                <Tip warning={true}>
-
-                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
-                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
-                folder. Pass along `temp_dir=True` to use a temporary directory instead.
-
-                </Tip>
-
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
             max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
@@ -1956,18 +1951,23 @@ def save_pretrained(

                 </Tip>

+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
+
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        os.makedirs(save_directory, exist_ok=True)
+
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            repo = self._create_or_get_repo(save_directory, **kwargs)
-
-            os.makedirs(save_directory, exist_ok=True)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id, token = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)

         if saved_model:
             saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
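With this change, pushing from `save_pretrained` no longer requires `save_directory` to be a local clone of the target repo; the repository is named via a `repo_id` keyword instead. A minimal usage sketch, assuming an authenticated Hub login (`transformers-cli login`); the repository name is a placeholder:

```python
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")

# Local save only, no Hub interaction.
model.save_pretrained("my-finetuned-bert")

# Save and push in one call. `repo_id` is picked up from **kwargs (defaulting
# to the directory name); `create_pr` asks for a pull request instead of a
# direct commit, per the new docstring entry above.
model.save_pretrained(
    "my-finetuned-bert",
    push_to_hub=True,
    repo_id="my-finetuned-bert",
    create_pr=True,
)
```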
@@ -2030,8 +2030,9 @@ def save_pretrained(
                     param_dset[:] = layer.numpy()

         if push_to_hub:
-            url = self._push_to_hub(repo, commit_message=commit_message)
-            logger.info(f"Model pushed to the hub in this commit: {url}")
+            self._upload_modified_files(
+                save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
+            )

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
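The removed `_push_to_hub(repo, ...)` call pushed through a local clone of the repo (see the Tip deleted from the docstring above); the replacement `_upload_modified_files` path instead snapshots file timestamps before saving and uploads only the files that changed. A rough, standard-library-only sketch of that timestamp-diff idea (the helper names below merely mimic the private `_get_files_timestamps` / `_upload_modified_files` pair and are not the actual implementation):

```python
import os
import tempfile
from pathlib import Path
from typing import Dict, List


def get_files_timestamps(working_dir: str) -> Dict[str, float]:
    # Snapshot the last-modified time of every entry currently in the directory.
    return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}


def list_modified_files(working_dir: str, snapshot: Dict[str, float]) -> List[str]:
    # Anything new, or whose mtime moved past the snapshot, is a candidate for upload.
    return [
        f
        for f in os.listdir(working_dir)
        if f not in snapshot or os.path.getmtime(os.path.join(working_dir, f)) > snapshot[f]
    ]


with tempfile.TemporaryDirectory() as work_dir:
    (Path(work_dir) / "README.md").write_text("unchanged file")
    snapshot = get_files_timestamps(work_dir)

    # Saving the model would write or refresh checkpoint files here; simulate one.
    (Path(work_dir) / "tf_model.h5").write_bytes(b"")

    print(list_modified_files(work_dir, snapshot))  # ['tf_model.h5']
```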
@@ -2475,12 +2476,95 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         return model

+    def push_to_hub(
+        self,
+        repo_id: str,
+        use_temp_dir: Optional[bool] = None,
+        commit_message: Optional[str] = None,
+        private: Optional[bool] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        max_shard_size: Optional[Union[int, str]] = "10GB",
+        **model_card_kwargs
+    ) -> str:
+        """
+        Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.

-# To update the docstring, we need to copy the method, otherwise we change the original docstring.
-TFPreTrainedModel.push_to_hub = copy_func(TFPreTrainedModel.push_to_hub)
-TFPreTrainedModel.push_to_hub.__doc__ = TFPreTrainedModel.push_to_hub.__doc__.format(
-    object="model", object_class="TFAutoModel", object_files="model checkpoint"
-)
+        Parameters:
+            repo_id (`str`):
+                The name of the repository you want to push your model to. It should contain your organization name
+                when pushing to a given organization.
+            use_temp_dir (`bool`, *optional*):
+                Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
+                Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing. Will default to `"Upload model"`.
+            private (`bool`, *optional*):
+                Whether or not the repository created should be private (requires a paying subscription).
+            use_auth_token (`bool` or `str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
+                `repo_url` is not specified.
+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
+                will then be each of size lower than this size. If expressed as a string, needs to be digits followed
+                by a unit (like `"5MB"`).
+            model_card_kwargs:
+                Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
+
+        Examples:
+
+        ```python
+        from transformers import TFAutoModel
+
+        model = TFAutoModel.from_pretrained("bert-base-cased")
+
+        # Push the model to your namespace with the name "my-finetuned-bert".
+        model.push_to_hub("my-finetuned-bert")
+
+        # Push the model to an organization with the name "my-finetuned-bert".
+        model.push_to_hub("huggingface/my-finetuned-bert")
+        ```
+        """
+        if "repo_path_or_name" in model_card_kwargs:
+            warnings.warn(
+                "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
+                "`repo_id` instead."
+            )
+            repo_id = model_card_kwargs.pop("repo_path_or_name")
+        # Deprecation warning will be sent after for repo_url and organization
+        repo_url = model_card_kwargs.pop("repo_url", None)
+        organization = model_card_kwargs.pop("organization", None)
+
+        if os.path.isdir(repo_id):
+            working_dir = repo_id
+            repo_id = repo_id.split(os.path.sep)[-1]
+        else:
+            working_dir = repo_id.split("/")[-1]
+
+        repo_id, token = self._create_repo(
+            repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
+        )
+
+        if use_temp_dir is None:
+            use_temp_dir = not os.path.isdir(working_dir)
+
+        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
+            files_timestamps = self._get_files_timestamps(work_dir)
+
+            # Save all files.
+            self.save_pretrained(work_dir, max_shard_size=max_shard_size)
+            if hasattr(self, "history") and hasattr(self, "create_model_card"):
+                # This is a Keras model and we might be able to fish out its History and make a model card out of it
+                base_model_card_args = {
+                    "output_dir": work_dir,
+                    "model_name": Path(repo_id).name,
+                }
+                base_model_card_args.update(model_card_kwargs)
+                self.create_model_card(**base_model_card_args)
+
+            self._upload_modified_files(
+                work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token
+            )


 class TFConv1D(tf.keras.layers.Layer):
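For illustration, a short usage sketch of the `push_to_hub` method added above, again assuming an authenticated login; repository names are placeholders:

```python
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")

# No local directory named "my-finetuned-bert" exists, so the files are saved
# to a temporary working directory before being uploaded.
model.push_to_hub("my-finetuned-bert", commit_message="Add TF weights", max_shard_size="2GB")

# The old calling convention is deprecated but still accepted:
#   model.push_to_hub(repo_path_or_name="my-finetuned-bert", organization="my-org")
# It maps onto the single `repo_id` argument, e.g.:
model.push_to_hub("my-org/my-finetuned-bert")
```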