@@ -29,26 +29,56 @@ def _set_up_model(self, L: np.ndarray, class_balance: List[float] = [0.5, 0.5]):
29
29
30
30
def test_L_form (self ):
31
31
label_model = LabelModel (cardinality = 2 , verbose = False )
32
- L = np .array ([[0 , 1 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ], [0 , 1 , 0 ]])
32
+ L = np .array ([[- 1 , 1 , - 1 ], [- 1 , 1 , - 1 ], [1 , - 1 , - 1 ], [- 1 , 1 , - 1 ]])
33
33
label_model ._set_constants (L )
34
34
self .assertEqual (label_model .n , 4 )
35
35
self .assertEqual (label_model .m , 3 )
36
36
37
- L = np .array ([[0 , 1 , 2 ], [0 , 1 , 2 ], [1 , 0 , 2 ], [0 , 1 , 0 ]])
37
+ L = np .array ([[- 1 , 0 , 1 ], [- 1 , 0 , 2 ], [0 , - 1 , 2 ], [- 1 , 0 , - 1 ]])
38
38
with self .assertRaisesRegex (ValueError , "L_train has cardinality" ):
39
39
label_model .fit (L , n_epochs = 1 )
40
40
41
- L = np .array ([[0 ], [1 ], [- 1 ]])
41
+ L = np .array ([[0 , 1 ], [1 , 1 ], [0 , 1 ]])
42
42
with self .assertRaisesRegex (ValueError , "L_train should have at least 3" ):
43
43
label_model .fit (L , n_epochs = 1 )
44
44
45
+ def test_mv_default (self ):
46
+ # less than 2 LFs have overlaps
47
+ label_model = LabelModel (cardinality = 2 , verbose = False )
48
+ L = np .array ([[- 1 , - 1 , 1 ], [- 1 , 1 , - 1 ], [0 , - 1 , - 1 ]])
49
+ label_model .fit (L , n_epochs = 100 )
50
+ np .testing .assert_array_almost_equal (
51
+ label_model .predict (L ), np .array ([1 , 1 , 0 ])
52
+ )
53
+
54
+ # less than 2 LFs have conflicts
55
+ L = np .array ([[- 1 , - 1 , 1 ], [- 1 , 1 , 1 ], [1 , 1 , 1 ]])
56
+ label_model .fit (L , n_epochs = 100 )
57
+ np .testing .assert_array_almost_equal (
58
+ label_model .predict (L ), np .array ([1 , 1 , 1 ])
59
+ )
60
+
45
61
def test_class_balance (self ):
46
62
label_model = LabelModel (cardinality = 2 , verbose = False )
47
63
# Test class balance
48
64
Y_dev = np .array ([0 , 0 , 1 , 1 , 0 , 0 , 0 , 0 , 1 , 1 ])
49
65
label_model ._set_class_balance (class_balance = None , Y_dev = Y_dev )
50
66
np .testing .assert_array_almost_equal (label_model .p , np .array ([0.6 , 0.4 ]))
51
67
68
+ class_balance = np .array ([0.0 , 1.0 ])
69
+ with self .assertRaisesRegex (ValueError , "Class balance prior is 0" ):
70
+ label_model ._set_class_balance (class_balance = class_balance , Y_dev = Y_dev )
71
+
72
+ class_balance = np .array ([0.0 ])
73
+ with self .assertRaisesRegex (ValueError , "class_balance has 1 entries." ):
74
+ label_model ._set_class_balance (class_balance = class_balance , Y_dev = Y_dev )
75
+
76
+ Y_dev_one_class = np .array ([0 , 0 , 0 ])
77
+ with self .assertRaisesRegex (
78
+ ValueError , "Does not match LabelModel cardinality"
79
+ ):
80
+ label_model ._set_class_balance (class_balance = None , Y_dev = Y_dev_one_class )
81
+
52
82
def test_generate_O (self ):
53
83
L = np .array ([[0 , 1 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ], [0 , 1 , 1 ]])
54
84
label_model = self ._set_up_model (L )
0 commit comments