2727
2828from . import __version__
2929from .dynamic_module_utils import custom_object_save
30- from .utils import CONFIG_NAME , PushToHubMixin , cached_file , copy_func , is_torch_available , logging
30+ from .utils import (
31+ CONFIG_NAME ,
32+ PushToHubMixin ,
33+ cached_file ,
34+ copy_func ,
35+ extract_commit_hash ,
36+ is_torch_available ,
37+ logging ,
38+ )
3139
3240
3341logger = logging .get_logger (__name__ )
@@ -343,6 +351,8 @@ def __init__(self, **kwargs):
343351
344352 # Name or path to the pretrained checkpoint
345353 self ._name_or_path = str (kwargs .pop ("name_or_path" , "" ))
354+ # Commit hash passed in through the `_commit_hash` kwarg (`None` when not provided)
355+ self ._commit_hash = kwargs .pop ("_commit_hash" , None )
346356
347357 # Drop the transformers version info
348358 self .transformers_version = kwargs .pop ("transformers_version" , None )
@@ -539,6 +549,8 @@ def get_config_dict(
539549 original_kwargs = copy .deepcopy (kwargs )
540550 # Get config dict associated with the base config file
541551 config_dict , kwargs = cls ._get_config_dict (pretrained_model_name_or_path , ** kwargs )
552+ if "_commit_hash" in config_dict :
553+ original_kwargs ["_commit_hash" ] = config_dict ["_commit_hash" ]
542554
543555 # That config file may point us toward another config file to use.
544556 if "configuration_files" in config_dict :
@@ -564,6 +576,7 @@ def _get_config_dict(
564576 subfolder = kwargs .pop ("subfolder" , "" )
565577 from_pipeline = kwargs .pop ("_from_pipeline" , None )
566578 from_auto_class = kwargs .pop ("_from_auto" , False )
579+ commit_hash = kwargs .pop ("_commit_hash" , None )
567580
568581 if trust_remote_code is True :
569582 logger .warning (
@@ -599,7 +612,9 @@ def _get_config_dict(
599612 user_agent = user_agent ,
600613 revision = revision ,
601614 subfolder = subfolder ,
615+ _commit_hash = commit_hash ,
602616 )
617+ commit_hash = extract_commit_hash (resolved_config_file , commit_hash )
603618 except EnvironmentError :
604619 # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
605620 # the original exception.
@@ -616,6 +631,7 @@ def _get_config_dict(
616631 try :
617632 # Load config dict
618633 config_dict = cls ._dict_from_json_file (resolved_config_file )
634+ config_dict ["_commit_hash" ] = commit_hash
619635 except (json .JSONDecodeError , UnicodeDecodeError ):
620636 raise EnvironmentError (
621637 f"It looks like the config file at '{ resolved_config_file } ' is not a valid JSON file."
@@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
648664 # We remove them so they don't appear in `return_unused_kwargs`.
649665 kwargs .pop ("_from_auto" , None )
650666 kwargs .pop ("_from_pipeline" , None )
667+ # The commit hash might have been updated in the `config_dict`; we don't want the kwargs to erase that update.
668+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict :
669+ kwargs ["_commit_hash" ] = config_dict ["_commit_hash" ]
651670
652671 config = cls (** config_dict )
653672
@@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
751770 output ["model_type" ] = self .__class__ .model_type
752771 if "_auto_class" in output :
753772 del output ["_auto_class" ]
773+ if "_commit_hash" in output :
774+ del output ["_commit_hash" ]
754775
755776 # Transformers version when serializing the model
756777 output ["transformers_version" ] = __version__
0 commit comments