Skip to content

Commit f0da8a1

Browse files
committed
fix
1 parent ff1d3c4 commit f0da8a1

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

tests/experimental/autonlp/test_text_classification.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,21 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
269269
)
270270

271271
# test predict
272-
copy_test_ds = copy.deepcopy(self.multi_label_dev_ds)
273-
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
272+
dev_output = auto_trainer.predict(test_dataset=copy_dev_ds)
274273
self.assertEqual(
275274
eval_metrics1[auto_trainer.metric_for_best_model],
276-
test_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
275+
dev_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
277276
)
277+
self.assertEqual(len(copy_dev_ds), len(dev_output.label_ids))
278+
self.assertEqual(len(copy_dev_ds), len(dev_output.predictions))
279+
self.assertEqual(len(auto_trainer.id2label), len(dev_output.predictions[0]))
280+
281+
copy_test_ds = copy.deepcopy(self.test_ds)
282+
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
283+
self.assertFalse(auto_trainer.metric_for_best_model.replace("eval", "test") in test_output.metrics)
284+
self.assertEqual(None, test_output.label_ids)
278285
self.assertEqual(len(copy_test_ds), len(test_output.predictions))
279286
self.assertEqual(len(auto_trainer.id2label), len(test_output.predictions[0]))
280-
self.assertEqual(len(copy_test_ds), len(test_output.label_ids))
281287

282288
# test taskflow
283289
taskflow = auto_trainer.to_taskflow()
@@ -371,12 +377,21 @@ def test_default_model_candidate(self, language, hp_overrides):
371377
)
372378

373379
# test predict
374-
copy_test_ds = copy.deepcopy(self.multi_class_dev_ds)
375-
eval_metrics3 = auto_trainer.predict(test_dataset=copy_test_ds).metrics
380+
dev_output = auto_trainer.predict(test_dataset=copy_dev_ds)
376381
self.assertEqual(
377382
eval_metrics1[auto_trainer.metric_for_best_model],
378-
eval_metrics3[auto_trainer.metric_for_best_model.replace("eval", "test")],
383+
dev_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
379384
)
385+
self.assertEqual(len(copy_dev_ds), len(dev_output.label_ids))
386+
self.assertEqual(len(copy_dev_ds), len(dev_output.predictions))
387+
self.assertEqual(len(auto_trainer.id2label), len(dev_output.predictions[0]))
388+
389+
copy_test_ds = copy.deepcopy(self.test_ds)
390+
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
391+
self.assertFalse(auto_trainer.metric_for_best_model.replace("eval", "test") in test_output.metrics)
392+
self.assertEqual(None, test_output.label_ids)
393+
self.assertEqual(len(copy_test_ds), len(test_output.predictions))
394+
self.assertEqual(len(auto_trainer.id2label), len(test_output.predictions[0]))
380395

381396
# test export
382397
temp_export_path = os.path.join(temp_dir_path, "test_export")

0 commit comments

Comments
 (0)