@@ -19,15 +19,15 @@ def test_sce_equals_ce(self):
19
19
20
20
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "none" )
21
21
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "none" )
22
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
22
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy (), decimal = 6 )
23
23
24
24
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "sum" )
25
25
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "sum" )
26
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
26
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy (), decimal = 6 )
27
27
28
28
ce_loss = F .cross_entropy (Y_probs , Y_golds , reduction = "mean" )
29
29
ces_loss = cross_entropy_with_probs (Y_probs , Y_golds_probs , reduction = "mean" )
30
- np .testing .assert_equal (ce_loss .numpy (), ces_loss .numpy ())
30
+ np .testing .assert_almost_equal (ce_loss .numpy (), ces_loss .numpy (), decimal = 6 )
31
31
32
32
def test_perfect_predictions (self ):
33
33
# Does soft ce loss achieve approx. 0 loss with perfect predictions?
0 commit comments