File tree Expand file tree Collapse file tree 2 files changed +15
-6
lines changed Expand file tree Collapse file tree 2 files changed +15
-6
lines changed Original file line number Diff line number Diff line change @@ -539,22 +539,25 @@ def _set_class_balance(
539
539
"""
540
540
if class_balance is not None :
541
541
self .p = np .array (class_balance )
542
+ if len (self .p ) != self .cardinality :
543
+ raise ValueError (
544
+ f"class_balance has { len (self .p )} entries. Does not match LabelModel cardinality { self .cardinality } ."
545
+ )
542
546
elif Y_dev is not None :
543
547
class_counts = Counter (Y_dev )
544
548
sorted_counts = np .array ([v for k , v in sorted (class_counts .items ())])
545
549
self .p = sorted_counts / sum (sorted_counts )
550
+ if len (self .p ) != self .cardinality :
551
+ raise ValueError (
552
+ f"Y_dev has { len (self .p )} class(es). Does not match LabelModel cardinality { self .cardinality } ."
553
+ )
546
554
else :
547
555
self .p = (1 / self .cardinality ) * np .ones (self .cardinality )
548
556
549
557
if np .any (self .p == 0 ):
550
558
raise ValueError (
551
559
f"Class balance prior is 0 for class(es) { np .where (self .p )[0 ]} ."
552
560
)
553
- if len (self .p ) != self .cardinality :
554
- raise ValueError (
555
- f"Y_dev has { len (self .p )} class(es). Does not match LabelModel cardinality { self .cardinality } ."
556
- )
557
-
558
561
self .P = torch .diag (torch .from_numpy (self .p )).float ()
559
562
560
563
def _set_constants (self , L : np .ndarray ) -> None :
Original file line number Diff line number Diff line change @@ -69,8 +69,14 @@ def test_class_balance(self):
69
69
with self .assertRaisesRegex (ValueError , "Class balance prior is 0" ):
70
70
label_model ._set_class_balance (class_balance = class_balance , Y_dev = Y_dev )
71
71
72
+ class_balance = np .array ([0.0 ])
73
+ with self .assertRaisesRegex (ValueError , "class_balance has 1 entries." ):
74
+ label_model ._set_class_balance (class_balance = class_balance , Y_dev = Y_dev )
75
+
72
76
Y_dev_one_class = np .array ([0 , 0 , 0 ])
73
- with self .assertRaisesRegex (ValueError , "Does not match LabelModel cardinality" ):
77
+ with self .assertRaisesRegex (
78
+ ValueError , "Does not match LabelModel cardinality"
79
+ ):
74
80
label_model ._set_class_balance (class_balance = None , Y_dev = Y_dev_one_class )
75
81
76
82
def test_generate_O (self ):
You can’t perform that action at this time.
0 commit comments