@@ -19,15 +19,15 @@ def test_sce_equals_ce(self):
19
19
20
20
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "none" )
21
21
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "none" )
22
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
22
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy ())
23
23
24
24
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "sum" )
25
25
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "sum" )
26
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
26
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy ())
27
27
28
28
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "mean" )
29
29
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "mean" )
30
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
30
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy ())
31
31
32
32
def test_perfect_predictions (self ):
33
33
# Does soft ce loss achieve approx. 0 loss with perfect predictions?
@@ -39,7 +39,7 @@ def test_perfect_predictions(self):
39
39
Y_probs [Y_probs == 0 ] = - 100
40
40
41
41
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs )
42
- np .testing .assert_equal (ces_loss .numpy (), 0 )
42
+ np .testing .assert_almost_equal (ces_loss .numpy (), 0 )
43
43
44
44
def test_lower_loss (self ):
45
45
# Is loss lower when it should be?
0 commit comments