Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,27 @@ def _post_init(self, original_init, *args, **kwargs):
init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
self.config = init_dict

def __getattr__(self, name):
"""
called when the attribute name is missed in the model

Args:
name: the name of attribute

Returns: the value of attribute

"""
try:
return super(PretrainedModel, self).__getattr__(name)
except AttributeError:
result = getattr(self.config, name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以说num_classes -> num_labels这种attribute map就交给底层的PretrainedConfig来支持了是吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的代码逻辑是:如果 model 获取不到,就从 config 获取,并且抛出warning。


logger.warning(
f"Do not access config from `model.{name}` which will be deprecated after v2.6.0, "
f"Instead, do `model.config.{name}`"
)
return result

@property
def base_model(self):
"""
Expand Down
26 changes: 26 additions & 0 deletions tests/transformers/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class ModelTesterMixin:
test_resize_position_embeddings = False
test_mismatched_shapes = True
test_missing_keys = True
test_model_compatibility_keys = False
use_test_inputs_embeds = False
use_test_model_name_list = True
is_encoder_decoder = False
Expand Down Expand Up @@ -525,6 +526,31 @@ def random_choice_pretrained_config_field(self) -> Optional[str]:
fields = [key for key, value in config.to_dict() if value]
return random.choice(fields)

def test_for_missed_attribute(self):
if not self.test_model_compatibility_keys:
self.skipTest(f"Do not test model_compatibility_keys on {self.base_model_class}")
return

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.constructed_from_pretrained_config():
continue

model = self._make_model_instance(config, model_class)

all_maps: dict = copy.deepcopy(model_class.config_class.attribute_map)
all_maps.update(model_class.config_class.standard_config_map)

for old_attribute, new_attribute in all_maps.items():
old_value = getattr(model, old_attribute)
new_value = getattr(model, new_attribute)

# eg: dropout can be an instance of nn.Dropout, so we should check it attribute
if type(new_value) != type(old_value):
continue

self.assertEqual(old_value, new_value)


class ModelTesterPretrainedMixin:
base_model_class: PretrainedModel = None
Expand Down