@@ -89,6 +89,10 @@ def __init__(
                 f"'{problem_type}' is not a supported problem_type. Please select among ['multi_label', 'multi_class']"
             )
 
+    @property
+    def supported_languages(self) -> List[str]:
+        return ["Chinese", "English"]
+
     @property
     def _default_training_argument(self) -> TrainingArguments:
         return TrainingArguments(
@@ -129,19 +133,42 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "ernie-3.0-nano-zh",  # 4-layer, 312-hidden, 12-heads, 18M parameters.
             ],
         )
+        english_models = hp.choice(
+            "models",
+            [
+                # add deberta-v3 when we have it
+                "roberta-large",  # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
+                "roberta-base",  # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
+                "distilroberta-base",  # 6-layer, 768-hidden, 12-heads, 66M parameters. Case-sensitive
+                "ernie-2.0-base-en",  # 12-layer, 768-hidden, 12-heads, 103M parameters. Trained on lower-cased English text.
+                "ernie-2.0-large-en",  # 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
+                "distilbert-base-uncased",  # 6-layer, 768-hidden, 12-heads, 66M parameters
+            ],
+        )
         return [
             # fast learning: high LR, small early stop patience
             {
                 "preset": "finetune",
                 "language": "Chinese",
                 "trainer_type": "Trainer",
-                "EarlyStoppingCallback.early_stopping_patience": 2,
+                "EarlyStoppingCallback.early_stopping_patience": 5,
                 "TrainingArguments.per_device_train_batch_size": train_batch_size,
                 "TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
                 "TrainingArguments.num_train_epochs": 100,
                 "TrainingArguments.model_name_or_path": chinese_models,
                 "TrainingArguments.learning_rate": 3e-5,
             },
+            {
+                "preset": "finetune",
+                "language": "English",
+                "trainer_type": "Trainer",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "TrainingArguments.per_device_train_batch_size": train_batch_size,
+                "TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "TrainingArguments.num_train_epochs": 100,
+                "TrainingArguments.model_name_or_path": english_models,
+                "TrainingArguments.learning_rate": 3e-5,
+            },
             # slow learning: small LR, large early stop patience
             {
                 "preset": "finetune",
@@ -154,6 +181,17 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "TrainingArguments.model_name_or_path": chinese_models,
                 "TrainingArguments.learning_rate": 5e-6,
             },
+            {
+                "preset": "finetune",
+                "language": "English",
+                "trainer_type": "Trainer",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "TrainingArguments.per_device_train_batch_size": train_batch_size,
+                "TrainingArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "TrainingArguments.num_train_epochs": 100,
+                "TrainingArguments.model_name_or_path": english_models,
+                "TrainingArguments.learning_rate": 5e-6,
+            },
             # Note: prompt tuning candidates not included for now due to lack of inference capability
         ]
 
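As a side note on how these candidate entries are meant to be consumed: the `hp.choice(label, options)` signature above matches hyperopt's search-space API, and the dotted keys ("TrainingArguments.learning_rate", "EarlyStoppingCallback.early_stopping_patience") name the object each value is destined for. Below is a minimal sketch of that flow, assuming `hp` is hyperopt's `hp` (the import is outside this hunk); `pick_candidate` and `group_by_prefix` are hypothetical helpers for illustration, not AutoTrainer APIs.

# Minimal sketch, not part of this commit. Assumes `hp` is hyperopt's `hp`;
# pick_candidate/group_by_prefix are hypothetical helpers, not AutoTrainer APIs.
from typing import Any, Dict, List

from hyperopt import hp
from hyperopt.pyll import stochastic

english_models = hp.choice(
    "models",
    ["roberta-base", "distilroberta-base", "distilbert-base-uncased"],
)
candidates: List[Dict[str, Any]] = [
    {
        "preset": "finetune",
        "language": "English",
        "EarlyStoppingCallback.early_stopping_patience": 5,
        "TrainingArguments.model_name_or_path": english_models,
        "TrainingArguments.learning_rate": 3e-5,
    },
]

def pick_candidate(candidates: List[Dict[str, Any]], language: str) -> Dict[str, Any]:
    """Keep candidates for one language, then resolve every hp.* node to a value."""
    matching = [c for c in candidates if c["language"] == language]
    if not matching:
        raise ValueError(f"no candidates for language '{language}'")
    # stochastic.sample draws one concrete config: english_models collapses
    # to a single model name string.
    return stochastic.sample(matching[0])

def group_by_prefix(config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Split dotted keys into per-object kwargs, e.g.
    'TrainingArguments.learning_rate' -> grouped['TrainingArguments']['learning_rate']."""
    grouped: Dict[str, Dict[str, Any]] = {}
    for key, value in config.items():
        if "." in key:
            prefix, field = key.split(".", 1)
            grouped.setdefault(prefix, {})[field] = value
    return grouped

config = pick_candidate(candidates, "English")
print(group_by_prefix(config))
# e.g. {'EarlyStoppingCallback': {'early_stopping_patience': 5},
#       'TrainingArguments': {'model_name_or_path': 'roberta-base',
#                             'learning_rate': 3e-05}}

This also shows why each language gets its own candidate dicts rather than one dict with a language-keyed model choice: every entry stays an independently sampleable configuration.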