@@ -269,15 +269,21 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
269269 )
270270
271271 # test predict
272- copy_test_ds = copy .deepcopy (self .multi_label_dev_ds )
273- test_output = auto_trainer .predict (test_dataset = copy_test_ds )
272+ dev_output = auto_trainer .predict (test_dataset = copy_dev_ds )
274273 self .assertEqual (
275274 eval_metrics1 [auto_trainer .metric_for_best_model ],
276- test_output .metrics [auto_trainer .metric_for_best_model .replace ("eval" , "test" )],
275+ dev_output .metrics [auto_trainer .metric_for_best_model .replace ("eval" , "test" )],
277276 )
277+ self .assertEqual (len (copy_dev_ds ), len (dev_output .label_ids ))
278+ self .assertEqual (len (copy_dev_ds ), len (dev_output .predictions ))
279+ self .assertEqual (len (auto_trainer .id2label ), len (dev_output .predictions [0 ]))
280+
281+ copy_test_ds = copy .deepcopy (self .test_ds )
282+ test_output = auto_trainer .predict (test_dataset = copy_test_ds )
283+ self .assertFalse (auto_trainer .metric_for_best_model .replace ("eval" , "test" ) in test_output .metrics )
284+ self .assertEqual (None , test_output .label_ids )
278285 self .assertEqual (len (copy_test_ds ), len (test_output .predictions ))
279286 self .assertEqual (len (auto_trainer .id2label ), len (test_output .predictions [0 ]))
280- self .assertEqual (len (copy_test_ds ), len (test_output .label_ids ))
281287
282288 # test taskflow
283289 taskflow = auto_trainer .to_taskflow ()
@@ -371,12 +377,21 @@ def test_default_model_candidate(self, language, hp_overrides):
371377 )
372378
373379 # test predict
374- copy_test_ds = copy .deepcopy (self .multi_class_dev_ds )
375- eval_metrics3 = auto_trainer .predict (test_dataset = copy_test_ds ).metrics
380+ dev_output = auto_trainer .predict (test_dataset = copy_dev_ds )
376381 self .assertEqual (
377382 eval_metrics1 [auto_trainer .metric_for_best_model ],
378- eval_metrics3 [auto_trainer .metric_for_best_model .replace ("eval" , "test" )],
383+ dev_output . metrics [auto_trainer .metric_for_best_model .replace ("eval" , "test" )],
379384 )
385+ self .assertEqual (len (copy_dev_ds ), len (dev_output .label_ids ))
386+ self .assertEqual (len (copy_dev_ds ), len (dev_output .predictions ))
387+ self .assertEqual (len (auto_trainer .id2label ), len (dev_output .predictions [0 ]))
388+
389+ copy_test_ds = copy .deepcopy (self .test_ds )
390+ test_output = auto_trainer .predict (test_dataset = copy_test_ds )
391+ self .assertFalse (auto_trainer .metric_for_best_model .replace ("eval" , "test" ) in test_output .metrics )
392+ self .assertEqual (None , test_output .label_ids )
393+ self .assertEqual (len (copy_test_ds ), len (test_output .predictions ))
394+ self .assertEqual (len (auto_trainer .id2label ), len (test_output .predictions [0 ]))
380395
381396 # test export
382397 temp_export_path = os .path .join (temp_dir_path , "test_export" )
0 commit comments