Skip to content

Commit 5d58467

Browse files
Tavish9zucchini-nlp
authored andcommitted
Convert _VALID_DICT_FIELDS to class attribute for shared dict parsing in subclasses (huggingface#36736)
* make _VALID_DICT_FIELDS as a class attribute * fix test case about TrainingArguments
1 parent 374dce1 commit 5d58467

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

src/transformers/training_args.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,6 @@ class OptimizerNames(ExplicitEnum):
188188
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
189189

190190

191-
# Sometimes users will pass in a `str` repr of a dict in the CLI
192-
# We need to track what fields those can be. Each time a new arg
193-
# has a dict type, it must be added to this list.
194-
# Important: These should be typed with Optional[Union[dict,str,...]]
195-
_VALID_DICT_FIELDS = [
196-
"accelerator_config",
197-
"fsdp_config",
198-
"deepspeed",
199-
"gradient_checkpointing_kwargs",
200-
"lr_scheduler_kwargs",
201-
]
202-
203-
204191
def _convert_str_dict(passed_value: dict):
205192
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
206193
for key, value in passed_value.items():
@@ -814,6 +801,18 @@ class TrainingArguments:
814801
https://github.com/huggingface/transformers/issues/34242
815802
"""
816803

804+
# Sometimes users will pass in a `str` repr of a dict in the CLI
805+
# We need to track what fields those can be. Each time a new arg
806+
# has a dict type, it must be added to this list.
807+
# Important: These should be typed with Optional[Union[dict,str,...]]
808+
_VALID_DICT_FIELDS = [
809+
"accelerator_config",
810+
"fsdp_config",
811+
"deepspeed",
812+
"gradient_checkpointing_kwargs",
813+
"lr_scheduler_kwargs",
814+
]
815+
817816
framework = "pt"
818817
output_dir: Optional[str] = field(
819818
default=None,
@@ -1561,7 +1560,7 @@ def __post_init__(self):
15611560
)
15621561

15631562
# Parse in args that could be `dict` sent in from the CLI as a string
1564-
for field in _VALID_DICT_FIELDS:
1563+
for field in self._VALID_DICT_FIELDS:
15651564
passed_value = getattr(self, field)
15661565
# We only want to do this if the str starts with a bracket to indicate a `dict`
15671566
# else its likely a filename if supported

tests/utils/test_hf_argparser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from transformers import HfArgumentParser, TrainingArguments
3030
from transformers.hf_argparser import make_choice_type_function, string_to_bool
3131
from transformers.testing_utils import require_torch
32-
from transformers.training_args import _VALID_DICT_FIELDS
3332

3433

3534
# Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -412,7 +411,8 @@ def test_parse_yaml(self):
412411
args = BasicExample(**args_dict_for_yaml)
413412
self.assertEqual(parsed_args, args)
414413

415-
def test_integration_training_args(self):
414+
def test_z_integration_training_args(self):
415+
# make sure that this test executes last in the test suite
416416
parser = HfArgumentParser(TrainingArguments)
417417
self.assertIsNotNone(parser)
418418

@@ -424,7 +424,7 @@ def test_valid_dict_annotation(self):
424424
If this fails, a type annotation change is
425425
needed on a new input
426426
"""
427-
base_list = _VALID_DICT_FIELDS.copy()
427+
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
428428
args = TrainingArguments
429429

430430
# First find any annotations that contain `dict`
@@ -468,7 +468,7 @@ def test_valid_dict_annotation(self):
468468
self.assertIn(
469469
field.name,
470470
base_list,
471-
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
471+
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
472472
)
473473

474474
@require_torch

0 commit comments

Comments
 (0)