Skip to content

Commit d6e79d1

Browse files
authored
Catch class balance errors and test L matrix edge cases (#1449)
1 parent e9fb792 commit d6e79d1

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

snorkel/labeling/model/label_model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,25 @@ def _set_class_balance(
539539
"""
540540
if class_balance is not None:
541541
self.p = np.array(class_balance)
542+
if len(self.p) != self.cardinality:
543+
raise ValueError(
544+
f"class_balance has {len(self.p)} entries. Does not match LabelModel cardinality {self.cardinality}."
545+
)
542546
elif Y_dev is not None:
543547
class_counts = Counter(Y_dev)
544548
sorted_counts = np.array([v for k, v in sorted(class_counts.items())])
545549
self.p = sorted_counts / sum(sorted_counts)
550+
if len(self.p) != self.cardinality:
551+
raise ValueError(
552+
f"Y_dev has {len(self.p)} class(es). Does not match LabelModel cardinality {self.cardinality}."
553+
)
546554
else:
547555
self.p = (1 / self.cardinality) * np.ones(self.cardinality)
556+
557+
if np.any(self.p == 0):
558+
raise ValueError(
559+
f"Class balance prior is 0 for class(es) {np.where(self.p)[0]}."
560+
)
548561
self.P = torch.diag(torch.from_numpy(self.p)).float()
549562

550563
def _set_constants(self, L: np.ndarray) -> None:
@@ -742,8 +755,8 @@ def fit(
742755
f"L_train has cardinality {L_shift.max()}, cardinality={self.cardinality} passed in."
743756
)
744757

745-
self._set_class_balance(class_balance, Y_dev)
746758
self._set_constants(L_shift)
759+
self._set_class_balance(class_balance, Y_dev)
747760
self._create_tree()
748761
lf_analysis = LFAnalysis(L_train)
749762
self.coverage = lf_analysis.lf_coverages()

test/labeling/model/test_label_model.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,56 @@ def _set_up_model(self, L: np.ndarray, class_balance: List[float] = [0.5, 0.5]):
2929

3030
def test_L_form(self):
3131
label_model = LabelModel(cardinality=2, verbose=False)
32-
L = np.array([[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]])
32+
L = np.array([[-1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, 1, -1]])
3333
label_model._set_constants(L)
3434
self.assertEqual(label_model.n, 4)
3535
self.assertEqual(label_model.m, 3)
3636

37-
L = np.array([[0, 1, 2], [0, 1, 2], [1, 0, 2], [0, 1, 0]])
37+
L = np.array([[-1, 0, 1], [-1, 0, 2], [0, -1, 2], [-1, 0, -1]])
3838
with self.assertRaisesRegex(ValueError, "L_train has cardinality"):
3939
label_model.fit(L, n_epochs=1)
4040

41-
L = np.array([[0], [1], [-1]])
41+
L = np.array([[0, 1], [1, 1], [0, 1]])
4242
with self.assertRaisesRegex(ValueError, "L_train should have at least 3"):
4343
label_model.fit(L, n_epochs=1)
4444

45+
def test_mv_default(self):
46+
# less than 2 LFs have overlaps
47+
label_model = LabelModel(cardinality=2, verbose=False)
48+
L = np.array([[-1, -1, 1], [-1, 1, -1], [0, -1, -1]])
49+
label_model.fit(L, n_epochs=100)
50+
np.testing.assert_array_almost_equal(
51+
label_model.predict(L), np.array([1, 1, 0])
52+
)
53+
54+
# less than 2 LFs have conflicts
55+
L = np.array([[-1, -1, 1], [-1, 1, 1], [1, 1, 1]])
56+
label_model.fit(L, n_epochs=100)
57+
np.testing.assert_array_almost_equal(
58+
label_model.predict(L), np.array([1, 1, 1])
59+
)
60+
4561
def test_class_balance(self):
4662
label_model = LabelModel(cardinality=2, verbose=False)
4763
# Test class balance
4864
Y_dev = np.array([0, 0, 1, 1, 0, 0, 0, 0, 1, 1])
4965
label_model._set_class_balance(class_balance=None, Y_dev=Y_dev)
5066
np.testing.assert_array_almost_equal(label_model.p, np.array([0.6, 0.4]))
5167

68+
class_balance = np.array([0.0, 1.0])
69+
with self.assertRaisesRegex(ValueError, "Class balance prior is 0"):
70+
label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)
71+
72+
class_balance = np.array([0.0])
73+
with self.assertRaisesRegex(ValueError, "class_balance has 1 entries."):
74+
label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)
75+
76+
Y_dev_one_class = np.array([0, 0, 0])
77+
with self.assertRaisesRegex(
78+
ValueError, "Does not match LabelModel cardinality"
79+
):
80+
label_model._set_class_balance(class_balance=None, Y_dev=Y_dev_one_class)
81+
5282
def test_generate_O(self):
5383
L = np.array([[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 1]])
5484
label_model = self._set_up_model(L)

0 commit comments

Comments
 (0)