Commit dd24000

Refactor BaseLabeler as parent class of label models (#1559)
* Move predict_proba to the BaseLabeler class as an abstractmethod.
* Move the predict, score, save, and load methods to the BaseLabeler class as shared methods in the parent class.
* Update RandomVoter, MajorityClassVoter, MajorityLabelVoter, and LabelModel to be subclasses of BaseLabeler.
* Update docstrings to reflect the abstract/parent class hierarchy.
* NOTE: We re-define the predict and score methods in the LabelModel(BaseLabeler) subclass in order to explicitly show docstring examples.
1 parent faf8612 commit dd24000
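
As a minimal sketch of what the new hierarchy enables (the UniformVoter class below is hypothetical, not part of this commit): a labeler only needs to implement predict_proba, and it inherits predict, score, save, and load from BaseLabeler.

import numpy as np

from snorkel.labeling.model.base_labeler import BaseLabeler


class UniformVoter(BaseLabeler):
    """Toy labeler that assigns equal probability to every class."""

    def predict_proba(self, L: np.ndarray) -> np.ndarray:
        n = L.shape[0]
        # Spread probability mass evenly over the k classes.
        return np.full((n, self.cardinality), 1.0 / self.cardinality)


voter = UniformVoter(cardinality=2)
L = np.array([[0, 1, -1], [1, 1, 0]])
preds = voter.predict(L, tie_break_policy="abstain")  # inherited from BaseLabeler
# Uniform probabilities tie on every row, so preds == array([-1, -1]).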

File tree

snorkel/labeling/model/base_labeler.py
snorkel/labeling/model/baselines.py
snorkel/labeling/model/label_model.py

3 files changed: +154 -76 lines

snorkel/labeling/model/base_labeler.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import logging
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from snorkel.analysis import Scorer
from snorkel.utils import probs_to_preds


class BaseLabeler(ABC):
    """Abstract baseline label voter class."""

    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
        self.cardinality = cardinality

    @abstractmethod
    def predict_proba(self, L: np.ndarray) -> np.ndarray:
        """Abstract method for predicting probabilistic labels given a label matrix.

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}

        Returns
        -------
        np.ndarray
            An [n,k] array of probabilistic labels
        """
        pass

    def predict(
        self,
        L: np.ndarray,
        return_probs: Optional[bool] = False,
        tie_break_policy: str = "abstain",
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return predicted labels, with ties broken according to policy.

        Policies to break ties include:
        "abstain": return an abstain vote (-1)
        "true-random": randomly choose among the tied options
        "random": randomly choose among tied options using deterministic hash

        NOTE: if tie_break_policy="true-random", repeated runs may have slightly
        different results due to differences in broken ties.

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}
        return_probs
            Whether to return probs along with preds
        tie_break_policy
            Policy to break ties when converting probabilistic labels to predictions

        Returns
        -------
        np.ndarray
            An [n,1] array of integer labels

        (np.ndarray, np.ndarray)
            An [n,1] array of integer labels and an [n,k] array of probabilistic labels
        """
        Y_probs = self.predict_proba(L)
        Y_p = probs_to_preds(Y_probs, tie_break_policy)
        if return_probs:
            return Y_p, Y_probs
        return Y_p

    def score(
        self,
        L: np.ndarray,
        Y: np.ndarray,
        metrics: Optional[List[str]] = ["accuracy"],
        tie_break_policy: str = "abstain",
    ) -> Dict[str, float]:
        """Calculate one or more scores from user-specified and/or user-defined metrics.

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}
        Y
            Gold labels associated with data points in L
        metrics
            A list of metric names
        tie_break_policy
            Policy to break ties when converting probabilistic labels to predictions

        Returns
        -------
        Dict[str, float]
            A dictionary mapping metric names to metric scores
        """
        if tie_break_policy == "abstain":  # pragma: no cover
            logging.warning(
                "Metrics calculated over data points with non-abstain labels only"
            )

        Y_pred, Y_prob = self.predict(
            L, return_probs=True, tie_break_policy=tie_break_policy
        )

        scorer = Scorer(metrics=metrics)
        results = scorer.score(Y, Y_pred, Y_prob)
        return results

    def save(self, destination: str) -> None:
        """Save label model.

        Parameters
        ----------
        destination
            Filename for saving model

        Example
        -------
        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
        """
        f = open(destination, "wb")
        pickle.dump(self.__dict__, f)
        f.close()

    def load(self, source: str) -> None:
        """Load existing label model.

        Parameters
        ----------
        source
            Filename to load model from

        Example
        -------
        Load parameters saved in ``saved_label_model``

        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
        """
        f = open(source, "rb")
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)
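
Because save and load simply pickle and restore the instance __dict__, every BaseLabeler subclass gets serialization for free. A usage sketch (file path hypothetical):

from snorkel.labeling.model.baselines import MajorityLabelVoter

mv = MajorityLabelVoter(cardinality=3)
mv.save("./mv_labeler.pkl")           # pickles mv.__dict__ to disk

mv_restored = MajorityLabelVoter()    # starts with default cardinality=2
mv_restored.load("./mv_labeler.pkl")  # restores all saved attributes
assert mv_restored.cardinality == 3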

snorkel/labeling/model/baselines.py

Lines changed: 4 additions & 20 deletions
@@ -2,26 +2,10 @@
 
 import numpy as np
 
-from snorkel.labeling.model.label_model import LabelModel
+from snorkel.labeling.model.base_labeler import BaseLabeler
 
 
-class BaselineVoter(LabelModel):
-    """Parent baseline label model class with method fit()."""
-
-    def fit(self, *args: Any, **kwargs: Any) -> None:
-        """Train majority class model.
-
-        Set class balance for majority class label model.
-
-        Parameters
-        ----------
-        balance
-            A [k] array of class probabilities
-        """
-        pass
-
-
-class RandomVoter(BaselineVoter):
+class RandomVoter(BaseLabeler):
     """Random vote label model.
 
     Example
@@ -57,7 +41,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
         return Y_p
 
 
-class MajorityClassVoter(LabelModel):
+class MajorityClassVoter(BaseLabeler):
     """Majority class label model."""
 
     def fit(  # type: ignore
@@ -110,7 +94,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
         return Y_p
 
 
-class MajorityLabelVoter(BaselineVoter):
+class MajorityLabelVoter(BaseLabeler):
     """Majority vote label model."""
 
     def predict_proba(self, L: np.ndarray) -> np.ndarray:
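
To see the inherited predict and its tie-break policies against these reparented voters, a small sketch (label matrix illustrative):

import numpy as np

from snorkel.labeling.model.baselines import MajorityLabelVoter

# Three data points, three labeling functions; -1 means abstain.
L = np.array([
    [0, 0, 1],   # majority votes 0
    [-1, 1, 1],  # abstains are ignored; majority votes 1
    [0, 1, -1],  # classes 0 and 1 tie
])
mv = MajorityLabelVoter(cardinality=2)
mv.predict_proba(L)                        # tied row yields [0.5, 0.5]
mv.predict(L, tie_break_policy="abstain")  # array([ 0,  1, -1])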

snorkel/labeling/model/label_model.py

Lines changed: 4 additions & 56 deletions
@@ -1,5 +1,4 @@
 import logging
-import pickle
 import random
 from collections import Counter, defaultdict
 from itertools import chain
@@ -11,12 +10,11 @@
 import torch.optim as optim
 from munkres import Munkres  # type: ignore
 
-from snorkel.analysis import Scorer
 from snorkel.labeling.analysis import LFAnalysis
+from snorkel.labeling.model.base_labeler import BaseLabeler
 from snorkel.labeling.model.graph_utils import get_clique_tree
 from snorkel.labeling.model.logger import Logger
 from snorkel.types import Config
-from snorkel.utils import probs_to_preds
 from snorkel.utils.config_utils import merge_config
 from snorkel.utils.lr_schedulers import LRSchedulerConfig
 from snorkel.utils.optimizers import OptimizerConfig
@@ -87,7 +85,7 @@ class _CliqueData(NamedTuple):
     max_cliques: Set[int]
 
 
-class LabelModel(nn.Module):
+class LabelModel(nn.Module, BaseLabeler):
     r"""A model for learning the LF accuracies and combining their output labels.
 
     This class learns a model of the labeling functions' conditional probabilities
@@ -454,11 +452,7 @@ def predict(
         >>> label_model.predict(L)
         array([0, 1, 0])
         """
-        Y_probs = self.predict_proba(L)
-        Y_p = probs_to_preds(Y_probs, tie_break_policy)
-        if return_probs:
-            return Y_p, Y_probs
-        return Y_p
+        return super(LabelModel, self).predict(L, return_probs, tie_break_policy)
 
     def score(
         self,
@@ -496,18 +490,7 @@ def score(
         >>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
         {'f1': 0.8}
         """
-        if tie_break_policy == "abstain":  # pragma: no cover
-            logging.warning(
-                "Metrics calculated over data points with non-abstain labels only"
-            )
-
-        Y_pred, Y_prob = self.predict(
-            L, return_probs=True, tie_break_policy=tie_break_policy
-        )
-
-        scorer = Scorer(metrics=metrics)
-        results = scorer.score(Y, Y_pred, Y_prob)
-        return results
+        return super(LabelModel, self).score(L, Y, metrics, tie_break_policy)
 
     # These loss functions get all their data directly from the LabelModel
     # (for better or worse). The unused *args make these compatible with the
@@ -928,38 +911,3 @@ def fit(
         # Print confusion matrix if applicable
         if self.config.verbose:  # pragma: no cover
             logging.info("Finished Training")
-
-    def save(self, destination: str) -> None:
-        """Save label model.
-
-        Parameters
-        ----------
-        destination
-            Filename for saving model
-
-        Example
-        -------
-        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(destination, "wb")
-        pickle.dump(self.__dict__, f)
-        f.close()
-
-    def load(self, source: str) -> None:
-        """Load existing label model.
-
-        Parameters
-        ----------
-        source
-            Filename to load model from
-
-        Example
-        -------
-        Load parameters saved in ``saved_label_model``
-
-        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(source, "rb")
-        tmp_dict = pickle.load(f)
-        f.close()
-        self.__dict__.update(tmp_dict)
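
The super(LabelModel, self) calls resolve to BaseLabeler because Python's method resolution order for LabelModel(nn.Module, BaseLabeler) searches nn.Module first and, finding no predict or score there, falls through to BaseLabeler. A standalone sketch with stand-in classes (not the real ones):

class Module:                  # stands in for nn.Module: defines no predict()
    pass


class Labeler:                 # stands in for BaseLabeler
    def predict(self):
        return "Labeler.predict"


class Model(Module, Labeler):  # stands in for LabelModel
    def predict(self):
        # Lookup skips Module (no predict) and lands on Labeler.
        return super(Model, self).predict()


print(Model.__mro__)      # (Model, Module, Labeler, object)
print(Model().predict())  # Labeler.predict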
