Skip to content

Commit 3391df9

Browse files
senwulukehsiao
authored andcommitted
test(test_e2e): ensure that 0-index labels are in the valid set (#389)
1 parent 2b953c5 commit 3391df9

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

tests/e2e/test_e2e.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def test_e2e():
7878
max_docs = 12
7979

8080
fonduer.init_logging(
81-
log_dir="log_folder",
8281
format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
8382
level=logging.INFO,
8483
)
@@ -534,6 +533,21 @@ def test_e2e():
534533
shuffle=True,
535534
)
536535

536+
valid_dataloader = EmmentalDataLoader(
537+
task_to_label_dict={ATTRIBUTE: "labels"},
538+
dataset=FonduerDataset(
539+
ATTRIBUTE,
540+
train_cands[0],
541+
F_train[0],
542+
emb_layer.word2id,
543+
np.argmax(train_marginals, axis=1),
544+
train_idxs,
545+
),
546+
split="valid",
547+
batch_size=100,
548+
shuffle=False,
549+
)
550+
537551
emmental.Meta.reset()
538552
emmental.init(fonduer.Meta.log_path)
539553
emmental.Meta.update_config(config=config)
@@ -548,7 +562,7 @@ def test_e2e():
548562
model.add_task(task)
549563

550564
emmental_learner = EmmentalLearner()
551-
emmental_learner.learn(model, [train_dataloader])
565+
emmental_learner.learn(model, [train_dataloader, valid_dataloader])
552566

553567
test_preds = model.predict(test_dataloader, return_preds=True)
554568
positive = np.where(np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)

0 commit comments

Comments
 (0)