Skip to content

Commit 3354848

Browse files
authored
[AutoNLP] Add english models for text classification (#4704)
* add english models * add tests
1 parent b5b4ce4 commit 3354848

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

paddlenlp/experimental/autonlp/auto_trainer_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,21 @@ def __init__(
6767
self.train_dataset = train_dataset
6868
self.eval_dataset = eval_dataset
6969
self.greater_is_better = greater_is_better
70+
if language not in self.supported_languages:
71+
raise ValueError(
72+
f"'{language}' is not supported. Please choose among the following: {self.supported_languages}"
73+
)
74+
7075
self.language = language
7176
self.output_dir = output_dir
7277

78+
@property
79+
@abstractmethod
80+
def supported_languages(self) -> List[str]:
81+
"""
82+
Override to store the supported languages for each auto trainer class
83+
"""
84+
7385
@property
7486
@abstractmethod
7587
def _default_training_argument(self) -> TrainingArguments:

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def __init__(
8989
f"'{problem_type}' is not a supported problem_type. Please select among ['multi_label', 'multi_class']"
9090
)
9191

92+
@property
93+
def supported_languages(self) -> List[str]:
94+
return ["Chinese", "English"]
95+
9296
@property
9397
def _default_training_argument(self) -> TrainingArguments:
9498
return TrainingArguments(
@@ -129,19 +133,42 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
129133
"ernie-3.0-nano-zh", # 4-layer, 312-hidden, 12-heads, 18M parameters.
130134
],
131135
)
136+
english_models = hp.choice(
137+
"models",
138+
[
139+
# add deberta-v3 when we have it
140+
"roberta-large", # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
141+
"roberta-base", # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
142+
"distilroberta-base", # 6-layer, 768-hidden, 12-heads, 66M parameters. Case-sensitive
143+
"ernie-2.0-base-en", # 12-layer, 768-hidden, 12-heads, 103M parameters. Trained on lower-cased English text.
144+
"ernie-2.0-large-en", # 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
145+
"distilbert-base-uncased", # 6-layer, 768-hidden, 12-heads, 66M parameters
146+
],
147+
)
132148
return [
133149
# fast learning: high LR, small early stop patience
134150
{
135151
"preset": "finetune",
136152
"language": "Chinese",
137153
"trainer_type": "Trainer",
138-
"EarlyStoppingCallback.early_stopping_patience": 2,
154+
"EarlyStoppingCallback.early_stopping_patience": 5,
139155
"TrainingArguments.per_device_train_batch_size": train_batch_size,
140156
"TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
141157
"TrainingArguments.num_train_epochs": 100,
142158
"TrainingArguments.model_name_or_path": chinese_models,
143159
"TrainingArguments.learning_rate": 3e-5,
144160
},
161+
{
162+
"preset": "finetune",
163+
"language": "English",
164+
"trainer_type": "Trainer",
165+
"EarlyStoppingCallback.early_stopping_patience": 5,
166+
"TrainingArguments.per_device_train_batch_size": train_batch_size,
167+
"TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
168+
"TrainingArguments.num_train_epochs": 100,
169+
"TrainingArguments.model_name_or_path": english_models,
170+
"TrainingArguments.learning_rate": 3e-5,
171+
},
145172
# slow learning: small LR, large early stop patience
146173
{
147174
"preset": "finetune",
@@ -154,6 +181,17 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
154181
"TrainingArguments.model_name_or_path": chinese_models,
155182
"TrainingArguments.learning_rate": 5e-6,
156183
},
184+
{
185+
"preset": "finetune",
186+
"language": "English",
187+
"trainer_type": "Trainer",
188+
"EarlyStoppingCallback.early_stopping_patience": 5,
189+
"TrainingArguments.per_device_train_batch_size": train_batch_size,
190+
"TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
191+
"TrainingArguments.num_train_epochs": 100,
192+
"TrainingArguments.model_name_or_path": english_models,
193+
"TrainingArguments.learning_rate": 5e-6,
194+
},
157195
# Note: prompt tuning candidates not included for now due to lack of inference capability
158196
]
159197

tests/experimental/autonlp/test_text_classification.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,37 @@ def test_untrained_auto_trainer(self):
268268
# test export
269269
auto_trainer.export(temp_dir)
270270

271+
def test_unsupported_languages(self):
272+
with TemporaryDirectory() as temp_dir:
273+
train_ds = copy.deepcopy(self.multi_class_train_ds)
274+
dev_ds = copy.deepcopy(self.multi_class_dev_ds)
275+
with self.assertRaises(ValueError):
276+
AutoTrainerForTextClassification(
277+
train_dataset=train_ds,
278+
eval_dataset=dev_ds,
279+
label_column="label_desc",
280+
text_column="sentence",
281+
language="Spanish", # spanish is unsupported for now
282+
output_dir=temp_dir,
283+
)
284+
285+
def test_model_language_filter(self):
286+
with TemporaryDirectory() as temp_dir:
287+
train_ds = copy.deepcopy(self.multi_class_train_ds)
288+
dev_ds = copy.deepcopy(self.multi_class_dev_ds)
289+
auto_trainer = AutoTrainerForTextClassification(
290+
train_dataset=train_ds,
291+
eval_dataset=dev_ds,
292+
label_column="label_desc",
293+
text_column="sentence",
294+
language="Chinese",
295+
output_dir=temp_dir,
296+
)
297+
for language in auto_trainer.supported_languages:
298+
model_candidates = auto_trainer._filter_model_candidates(language=language)
299+
for candidate in model_candidates:
300+
self.assertEqual(candidate["language"], language)
301+
271302

272303
if __name__ == "__main__":
273304
unittest.main()

0 commit comments

Comments
 (0)