Skip to content

Commit ff86643

Browse files
committed
separate class balance errors
1 parent 9d3784d commit ff86643

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

snorkel/labeling/model/label_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,22 +539,25 @@ def _set_class_balance(
539539
"""
540540
if class_balance is not None:
541541
self.p = np.array(class_balance)
542+
if len(self.p) != self.cardinality:
543+
raise ValueError(
544+
f"class_balance has {len(self.p)} entries. Does not match LabelModel cardinality {self.cardinality}."
545+
)
542546
elif Y_dev is not None:
543547
class_counts = Counter(Y_dev)
544548
sorted_counts = np.array([v for k, v in sorted(class_counts.items())])
545549
self.p = sorted_counts / sum(sorted_counts)
550+
if len(self.p) != self.cardinality:
551+
raise ValueError(
552+
f"Y_dev has {len(self.p)} class(es). Does not match LabelModel cardinality {self.cardinality}."
553+
)
546554
else:
547555
self.p = (1 / self.cardinality) * np.ones(self.cardinality)
548556

549557
if np.any(self.p == 0):
550558
raise ValueError(
551559
f"Class balance prior is 0 for class(es) {np.where(self.p)[0]}."
552560
)
553-
if len(self.p) != self.cardinality:
554-
raise ValueError(
555-
f"Y_dev has {len(self.p)} class(es). Does not match LabelModel cardinality {self.cardinality}."
556-
)
557-
558561
self.P = torch.diag(torch.from_numpy(self.p)).float()
559562

560563
def _set_constants(self, L: np.ndarray) -> None:

test/labeling/model/test_label_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,14 @@ def test_class_balance(self):
6969
with self.assertRaisesRegex(ValueError, "Class balance prior is 0"):
7070
label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)
7171

72+
class_balance = np.array([0.0])
73+
with self.assertRaisesRegex(ValueError, "class_balance has 1 entries."):
74+
label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)
75+
7276
Y_dev_one_class = np.array([0, 0, 0])
73-
with self.assertRaisesRegex(ValueError, "Does not match LabelModel cardinality"):
77+
with self.assertRaisesRegex(
78+
ValueError, "Does not match LabelModel cardinality"
79+
):
7480
label_model._set_class_balance(class_balance=None, Y_dev=Y_dev_one_class)
7581

7682
def test_generate_O(self):

0 commit comments

Comments
 (0)