File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -78,7 +78,6 @@ def test_e2e():
78
78
max_docs = 12
79
79
80
80
fonduer .init_logging (
81
- log_dir = "log_folder" ,
82
81
format = "[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s" ,
83
82
level = logging .INFO ,
84
83
)
@@ -534,6 +533,21 @@ def test_e2e():
534
533
shuffle = True ,
535
534
)
536
535
536
+ valid_dataloader = EmmentalDataLoader (
537
+ task_to_label_dict = {ATTRIBUTE : "labels" },
538
+ dataset = FonduerDataset (
539
+ ATTRIBUTE ,
540
+ train_cands [0 ],
541
+ F_train [0 ],
542
+ emb_layer .word2id ,
543
+ np .argmax (train_marginals , axis = 1 ),
544
+ train_idxs ,
545
+ ),
546
+ split = "valid" ,
547
+ batch_size = 100 ,
548
+ shuffle = False ,
549
+ )
550
+
537
551
emmental .Meta .reset ()
538
552
emmental .init (fonduer .Meta .log_path )
539
553
emmental .Meta .update_config (config = config )
@@ -548,7 +562,7 @@ def test_e2e():
548
562
model .add_task (task )
549
563
550
564
emmental_learner = EmmentalLearner ()
551
- emmental_learner .learn (model , [train_dataloader ])
565
+ emmental_learner .learn (model , [train_dataloader , valid_dataloader ])
552
566
553
567
test_preds = model .predict (test_dataloader , return_preds = True )
554
568
positive = np .where (np .array (test_preds ["probs" ][ATTRIBUTE ])[:, TRUE ] > 0.7 )
You can’t perform that action at this time.
0 commit comments