Skip to content

Commit 5fb2e23

Browse files
fpomsAndreas Kodewitz
authored andcommitted
Fix test loss (snorkel-team#1694)
1 parent 7e8a018 commit 5fb2e23

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

test/classification/test_loss.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class SoftCrossEntropyTest(unittest.TestCase):
1212
def test_sce_equals_ce(self):
13+
DECIMAL = 6
1314
# Does soft ce loss match classic ce loss when labels are one-hot?
1415
Y_golds = torch.LongTensor([0, 1, 2])
1516
Y_golds_probs = torch.Tensor(preds_to_probs(Y_golds.numpy(), num_classes=4))
@@ -19,15 +20,21 @@ def test_sce_equals_ce(self):
1920

2021
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="none")
2122
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="none")
22-
np.testing.assert_equal(ce_loss.numpy(), ces_loss.numpy())
23+
np.testing.assert_almost_equal(
24+
ce_loss.numpy(), ces_loss.numpy(), decimal=DECIMAL
25+
)
2326

2427
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="sum")
2528
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="sum")
26-
np.testing.assert_equal(ce_loss.numpy(), ces_loss.numpy())
29+
np.testing.assert_almost_equal(
30+
ce_loss.numpy(), ces_loss.numpy(), decimal=DECIMAL
31+
)
2732

2833
ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction="mean")
2934
ces_loss = cross_entropy_with_probs(Y_probs, Y_golds_probs, reduction="mean")
30-
np.testing.assert_almost_equal(ce_loss.numpy(), ces_loss.numpy())
35+
np.testing.assert_almost_equal(
36+
ce_loss.numpy(), ces_loss.numpy(), decimal=DECIMAL
37+
)
3138

3239
def test_perfect_predictions(self):
3340
# Does soft ce loss achieve approx. 0 loss with perfect predictions?

0 commit comments

Comments
 (0)