Skip to content

Commit f9a3df3

Browse files
author
ryan.smith
committed
Fix failing tests
1 parent d21fb1f commit f9a3df3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

test/classification/test_classifier_convergence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_convergence(self):
5151
model = MultitaskClassifier(tasks=[task1, task2])
5252

5353
# Train
54-
trainer = Trainer(lr=0.001, n_epochs=10, progress_bar=False)
54+
trainer = Trainer(lr=0.002, n_epochs=50, progress_bar=False)
5555
trainer.fit(model, dataloaders)
5656
scores = model.score(dataloaders)
5757

test/classification/test_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def test_sce_equals_ce(self):
1919

2020
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="none")
2121
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="none")
22-
np.testing.assert_equal(ce_loss.numpy(), ces_loss.numpy())
22+
np.testing.assert_almost_equal(ce_loss.numpy(), ces_loss.numpy())
2323

2424
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="sum")
2525
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="sum")
26-
np.testing.assert_equal(ce_loss.numpy(), ces_loss.numpy())
26+
np.testing.assert_almost_equal(ce_loss.numpy(), ces_loss.numpy())
2727

2828
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="mean")
2929
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="mean")
30-
np.testing.assert_equal(ce_loss.numpy(), ces_loss.numpy())
30+
np.testing.assert_almost_equal(ce_loss.numpy(), ces_loss.numpy())
3131

3232
def test_perfect_predictions(self):
3333
# Does soft ce loss achieve approx. 0 loss with perfect predictions?

0 commit comments

Comments
 (0)