10
10
11
11
class SoftCrossEntropyTest(unittest.TestCase):
    def test_sce_equals_ce(self):
        """Soft CE loss should match classic CE loss when targets are one-hot.

        Compares ``cross_entropy_with_probs`` (probabilistic targets) against
        ``F.cross_entropy`` (hard integer targets) for every reduction mode.
        """
        # Decimal places of agreement required between the two loss
        # implementations (float32 math, so exact equality is too strict —
        # hence assert_almost_equal rather than assert_equal).
        DECIMAL = 6

        # Hard labels and their one-hot probabilistic encoding.
        Y_golds = torch.LongTensor([0, 1, 2])
        Y_golds_probs = torch.Tensor(preds_to_probs(Y_golds.numpy(), num_classes=4))

        # NOTE(review): the construction of Y_probs falls inside an elided
        # diff hunk (@@ -19,15 +20,21 @@) and is reconstructed here as random
        # row-normalized predictions — confirm against the full file.
        Y_probs = torch.rand_like(Y_golds_probs)
        Y_probs = Y_probs / Y_probs.sum(dim=1, keepdim=True)

        # The two implementations must agree for every reduction mode; a
        # subTest per mode pinpoints which reduction diverges on failure.
        for reduction in ("none", "sum", "mean"):
            with self.subTest(reduction=reduction):
                ce_loss = F.cross_entropy(Y_probs, Y_golds, reduction=reduction)
                ces_loss = cross_entropy_with_probs(
                    Y_probs, Y_golds_probs, reduction=reduction
                )
                np.testing.assert_almost_equal(
                    ce_loss.numpy(), ces_loss.numpy(), decimal=DECIMAL
                )
31
38
32
39
def test_perfect_predictions (self ):
33
40
# Does soft ce loss achieve approx. 0 loss with perfect predictions?
0 commit comments