146 changes: 146 additions & 0 deletions snorkel/labeling/model/base_labeler.py
@@ -0,0 +1,146 @@
import logging
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from snorkel.analysis import Scorer
from snorkel.utils import probs_to_preds


class BaseLabeler(ABC):
"""Abstract baseline label voter class."""

def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
self.cardinality = cardinality

@abstractmethod
def predict_proba(self, L: np.ndarray) -> np.ndarray:
"""Abstract method for predicting probabilistic labels given a label matrix.

Parameters
----------
L
An [n,m] matrix with values in {-1,0,1,...,k-1}

Returns
-------
np.ndarray
An [n,k] array of probabilistic labels
"""
pass

def predict(
self,
L: np.ndarray,
return_probs: Optional[bool] = False,
tie_break_policy: str = "abstain",
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Return predicted labels, with ties broken according to policy.

Policies to break ties include:
"abstain": return an abstain vote (-1)
"true-random": randomly choose among the tied options
"random": randomly choose among tied option using deterministic hash

NOTE: if tie_break_policy="true-random", repeated runs may give slightly different
results because ties are broken differently each time


Parameters
----------
L
An [n,m] matrix with values in {-1,0,1,...,k-1}
return_probs
Whether to return probs along with preds
tie_break_policy
Policy to break ties when converting probabilistic labels to predictions

Returns
-------
np.ndarray
An [n,1] array of integer labels

(np.ndarray, np.ndarray)
An [n,1] array of integer labels and an [n,k] array of probabilistic labels
"""
Y_probs = self.predict_proba(L)
Y_p = probs_to_preds(Y_probs, tie_break_policy)
if return_probs:
return Y_p, Y_probs
return Y_p

def score(
self,
L: np.ndarray,
Y: np.ndarray,
metrics: Optional[List[str]] = ["accuracy"],
tie_break_policy: str = "abstain",
) -> Dict[str, float]:
"""Calculate one or more scores from user-specified and/or user-defined metrics.

Parameters
----------
L
An [n,m] matrix with values in {-1,0,1,...,k-1}
Y
Gold labels associated with data points in L
metrics
A list of metric names
tie_break_policy
Policy to break ties when converting probabilistic labels to predictions


Returns
-------
Dict[str, float]
A dictionary mapping metric names to metric scores
"""
if tie_break_policy == "abstain": # pragma: no cover
logging.warning(
"Metrics calculated over data points with non-abstain labels only"
)

Y_pred, Y_prob = self.predict(
L, return_probs=True, tie_break_policy=tie_break_policy
)

scorer = Scorer(metrics=metrics)
results = scorer.score(Y, Y_pred, Y_prob)
return results

def save(self, destination: str) -> None:
"""Save label model.

Parameters
----------
destination
Filename for saving model

Example
-------
>>> label_model.save('./saved_label_model.pkl') # doctest: +SKIP
"""
with open(destination, "wb") as f:
    pickle.dump(self.__dict__, f)

def load(self, source: str) -> None:
"""Load existing label model.

Parameters
----------
source
Filename to load model from

Example
-------
Load parameters saved in ``saved_label_model``

>>> label_model.load('./saved_label_model.pkl') # doctest: +SKIP
"""
with open(source, "rb") as f:
    tmp_dict = pickle.load(f)
self.__dict__.update(tmp_dict)
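
To make the new shared interface concrete, here is a minimal usage sketch exercising BaseLabeler through one of its concrete subclasses (MajorityLabelVoter from baselines.py below). The label matrix L, gold labels Y, and the save path are toy values invented for illustration: the subclass supplies predict_proba, while predict, score, save, and load all come from BaseLabeler.

import numpy as np

from snorkel.labeling.model.baselines import MajorityLabelVoter

# Three data points labeled by three LFs; -1 means the LF abstained.
L = np.array([[0, 0, -1], [1, -1, 1], [0, 1, -1]])
Y = np.array([0, 1, 1])

voter = MajorityLabelVoter(cardinality=2)

# Implemented by the subclass: an [n, k] array of probabilistic labels.
probs = voter.predict_proba(L)

# Inherited from BaseLabeler: hard labels, metrics, and persistence.
preds, probs = voter.predict(L, return_probs=True, tie_break_policy="abstain")
metrics = voter.score(L, Y, metrics=["accuracy"])
voter.save("./majority_voter.pkl")  # arbitrary example path
voter.load("./majority_voter.pkl")
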
24 changes: 4 additions & 20 deletions snorkel/labeling/model/baselines.py
@@ -2,26 +2,10 @@

import numpy as np

from snorkel.labeling.model.label_model import LabelModel
from snorkel.labeling.model.base_labeler import BaseLabeler


class BaselineVoter(LabelModel):
"""Parent baseline label model class with method fit()."""

def fit(self, *args: Any, **kwargs: Any) -> None:
"""Train majority class model.

Set class balance for majority class label model.

Parameters
----------
balance
A [k] array of class probabilities
"""
pass


class RandomVoter(BaselineVoter):
class RandomVoter(BaseLabeler):
"""Random vote label model.

Example
@@ -57,7 +41,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
return Y_p


class MajorityClassVoter(LabelModel):
class MajorityClassVoter(BaseLabeler):
"""Majority class label model."""

def fit( # type: ignore
@@ -110,7 +94,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
return Y_p


class MajorityLabelVoter(BaselineVoter):
class MajorityLabelVoter(BaseLabeler):
"""Majority vote label model."""

def predict_proba(self, L: np.ndarray) -> np.ndarray:
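
As a quick sketch of what the refactor means for callers of the baselines (nothing changes apart from the class hierarchy): RandomVoter and MajorityLabelVoter need no fitting, MajorityClassVoter keeps its own fit(balance), and all three now get predict, score, save, and load from BaseLabeler instead of inheriting the full LabelModel. The label matrix and class balance below are made-up values for illustration.

import numpy as np

from snorkel.labeling.model.baselines import MajorityClassVoter, RandomVoter

L = np.array([[0, -1, 1], [1, 1, -1]])

# MajorityClassVoter still defines its own fit(); the balance is a toy class prior.
mc_voter = MajorityClassVoter(cardinality=2)
mc_voter.fit(balance=np.array([0.6, 0.4]))
mc_probs = mc_voter.predict_proba(L)  # every row puts all mass on the majority class

# RandomVoter needs no fit() at all.
rand_probs = RandomVoter().predict_proba(L)
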
60 changes: 4 additions & 56 deletions snorkel/labeling/model/label_model.py
@@ -1,5 +1,4 @@
import logging
import pickle
import random
from collections import Counter, defaultdict
from itertools import chain
@@ -11,12 +10,11 @@
import torch.optim as optim
from munkres import Munkres # type: ignore

from snorkel.analysis import Scorer
from snorkel.labeling.analysis import LFAnalysis
from snorkel.labeling.model.base_labeler import BaseLabeler
from snorkel.labeling.model.graph_utils import get_clique_tree
from snorkel.labeling.model.logger import Logger
from snorkel.types import Config
from snorkel.utils import probs_to_preds
from snorkel.utils.config_utils import merge_config
from snorkel.utils.lr_schedulers import LRSchedulerConfig
from snorkel.utils.optimizers import OptimizerConfig
@@ -87,7 +85,7 @@ class _CliqueData(NamedTuple):
max_cliques: Set[int]


class LabelModel(nn.Module):
class LabelModel(nn.Module, BaseLabeler):
r"""A model for learning the LF accuracies and combining their output labels.

This class learns a model of the labeling functions' conditional probabilities
@@ -454,11 +452,7 @@ def predict(
>>> label_model.predict(L)
array([0, 1, 0])
"""
Y_probs = self.predict_proba(L)
Y_p = probs_to_preds(Y_probs, tie_break_policy)
if return_probs:
return Y_p, Y_probs
return Y_p
return super(LabelModel, self).predict(L, return_probs, tie_break_policy)

def score(
self,
@@ -496,18 +490,7 @@ def score(
>>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
{'f1': 0.8}
"""
if tie_break_policy == "abstain": # pragma: no cover
logging.warning(
"Metrics calculated over data points with non-abstain labels only"
)

Y_pred, Y_prob = self.predict(
L, return_probs=True, tie_break_policy=tie_break_policy
)

scorer = Scorer(metrics=metrics)
results = scorer.score(Y, Y_pred, Y_prob)
return results
return super(LabelModel, self).score(L, Y, metrics, tie_break_policy)

# These loss functions get all their data directly from the LabelModel
# (for better or worse). The unused *args make these compatible with the
@@ -928,38 +911,3 @@ def fit(
# Print confusion matrix if applicable
if self.config.verbose: # pragma: no cover
logging.info("Finished Training")

def save(self, destination: str) -> None:
"""Save label model.

Parameters
----------
destination
Filename for saving model

Example
-------
>>> label_model.save('./saved_label_model.pkl') # doctest: +SKIP
"""
f = open(destination, "wb")
pickle.dump(self.__dict__, f)
f.close()

def load(self, source: str) -> None:
"""Load existing label model.

Parameters
----------
source
Filename to load model from

Example
-------
Load parameters saved in ``saved_label_model``

>>> label_model.load('./saved_label_model.pkl') # doctest: +SKIP
"""
f = open(source, "rb")
tmp_dict = pickle.load(f)
f.close()
self.__dict__.update(tmp_dict)
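
Finally, a minimal sketch of what the multiple inheritance buys on the LabelModel side: the public API is unchanged, but predict and score now run through BaseLabeler via super(), and save/load are simply inherited rather than duplicated. L, Y, the training settings, and the save path below are placeholder toy values.

import numpy as np

from snorkel.labeling.model.label_model import LabelModel

L = np.array([[0, 0, -1], [1, -1, 1], [0, 1, 0]])
Y = np.array([0, 1, 0])

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L, n_epochs=100, seed=123)

preds = label_model.predict(L, tie_break_policy="abstain")  # delegates to BaseLabeler.predict
results = label_model.score(L, Y, metrics=["accuracy"])     # delegates to BaseLabeler.score
label_model.save("./saved_label_model.pkl")                 # inherited from BaseLabeler
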