173 changes: 129 additions & 44 deletions snorkel/labeling/model/label_model.py
@@ -1,6 +1,6 @@
import logging
from collections import Counter
from itertools import chain
from itertools import chain, permutations
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union

import numpy as np
@@ -283,52 +283,57 @@ def _init_params(self) -> None:
# Build the mask over O^{-1}
self._build_mask()

def _get_conditional_probs(self, source: Optional[int] = None) -> np.ndarray:
r"""Return the full conditional probabilities table.
def _get_conditional_probs(self, mu: np.ndarray) -> np.ndarray:
r"""Return the estimated conditional probabilities table given parameters mu.

In the cond. prob. table, row i*(k+1) + ly gives the conditional probabilities of source i
emitting label ly (including abstains 0), conditioned on different
values of Y, i.e.:
Given a parameter vector mu, return the estimated conditional probabilities
table cprobs, where cprobs is an (m, k+1, k)-dim np.ndarray with:

c_probs[i*(k+1) + ly, y] = P(\lambda_i = ly | Y = y)
cprobs[i, j, y] = P(\lf_i = j - 1 | Y = y)

Note that this simply involves inferring the abstain row (row 0) by the law of
total probability and adding it to the table built from mu.

If ``source`` is not None, returns only the corresponding block.
where m is the number of LFs, k is the cardinality, and cprobs includes the
conditional abstain probabilities P(\lf_i = -1 | Y = y).

Parameters
----------
source
Index of source to generate conditional probabilities for, by default None
mu
An [m * k, k] np.ndarray with entries in [0, 1]

Returns
-------
np.ndarray
Conditional probabilities table if source is None, else corresponding block
An [m, k + 1, k] np.ndarray conditional probabilities table.
"""
c_probs = np.zeros((self.m * (self.cardinality + 1), self.cardinality))
mu = self.mu.detach().clone().numpy()

cprobs = np.zeros((self.m, self.cardinality + 1, self.cardinality))
for i in range(self.m):
# si = self.c_data[(i,)]['start_index']
# ei = self.c_data[(i,)]['end_index']
# mu_i = mu[si:ei, :]
mu_i = mu[i * self.cardinality : (i + 1) * self.cardinality, :]
c_probs[
i * (self.cardinality + 1) + 1 : (i + 1) * (self.cardinality + 1), :
] = mu_i
cprobs[i, 1:, :] = mu_i

# The 0th row (corresponding to abstains) is one minus the sum of
# the other rows, by law of total prob
c_probs[i * (self.cardinality + 1), :] = 1 - mu_i.sum(axis=0)
# the other rows, by law of total probability
cprobs[i, 0, :] = 1 - mu_i.sum(axis=0)
return cprobs

if source is not None:
return c_probs[
source * (self.cardinality + 1) : (source + 1) * (self.cardinality + 1)
]
else:
return c_probs
def get_conditional_probs(self) -> np.ndarray:
r"""Return the estimated conditional probabilities table.

Return the estimated conditional probabilities table cprobs, where cprobs is an
(m, k+1, k)-dim np.ndarray with:

cprobs[i, j, y] = P(\lf_i = j - 1 | Y = y)

where m is the number of LFs, k is the cardinality, and cprobs includes the
conditional abstain probabilities P(\lf_i = -1 | Y = y).

Returns
-------
np.ndarray
An [m, k + 1, k] np.ndarray conditional probabilities table.
"""
return self._get_conditional_probs(self.mu.detach().numpy())
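A brief usage sketch of the new public accessor (not part of the diff; the fitted label_model below is hypothetical):

import numpy as np

# Assume label_model is a fitted LabelModel with m LFs and cardinality k.
cprobs = label_model.get_conditional_probs()   # shape (m, k + 1, k)

# Row j = 0 of each LF block holds the abstain probabilities P(lf_i = -1 | Y = y),
# filled in by the law of total probability, so each column sums to one.
assert np.allclose(cprobs.sum(axis=1), 1.0)

# P(lf_0 = 1 | Y = 0) sits at [LF index, label + 1, class index].
p = cprobs[0, 2, 0]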

def get_weights(self) -> np.ndarray:
"""Return the vector of learned LF weights for combining LFs.
@@ -347,10 +352,9 @@ def get_weights(self) -> np.ndarray:
array([0.99, 0.99, 0.99])
"""
accs = np.zeros(self.m)
cprobs = self.get_conditional_probs()
for i in range(self.m):
cps = self._get_conditional_probs(source=i)[1:, :]
accs[i] = np.diag(cps @ self.P.numpy()).sum()

accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.numpy()).sum()
return np.clip(accs / self.coverage, 1e-6, 1.0)
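A small numeric illustration of the accuracy computed above (hypothetical values, uniform class balance; not part of the diff):

import numpy as np

P = np.diag([0.5, 0.5])                # class-balance matrix, diag of P(Y = y)
cps = np.array([[0.8, 0.3],            # P(lf = 0 | Y = 0), P(lf = 0 | Y = 1)
                [0.1, 0.6]])           # P(lf = 1 | Y = 0), P(lf = 1 | Y = 1)
acc = np.diag(cps @ P).sum()           # P(lf = Y) = 0.8 * 0.5 + 0.6 * 0.5 = 0.7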

def predict_proba(self, L: np.ndarray) -> np.ndarray:
@@ -379,7 +383,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
L_shift = L + 1 # convert to {0, 1, ..., k}
self._set_constants(L_shift)
L_aug = self._get_augmented_label_matrix(L_shift)
mu = self.mu.detach().clone().numpy()
mu = self.mu.detach().numpy()
jtm = np.ones(L_aug.shape[1])

# Note: We omit abstains, effectively assuming uniform distribution here
@@ -693,6 +697,96 @@ def _update_lr_scheduler(self, step: int) -> None:
if min_lr and self.optimizer.param_groups[0]["lr"] < min_lr:
self.optimizer.param_groups[0]["lr"] = min_lr

def _clamp_params(self) -> None:
"""Clamp the values of the learned parameter vector.

Clamp the entries of self.mu to be in [mu_eps, 1 - mu_eps], where mu_eps is
either set by the user, or defaults to min(0.01, 1 / 10 ** np.ceil(np.log10(self.n))).

Note that if mu_eps is set too high, e.g. in sparse settings where LFs
mostly abstain, this will result in learning conditional probabilities all
equal to mu_eps (and/or 1 - mu_eps)! See issue #1422.

Note: Use the user-provided value of mu_eps in train_config if given; else default to
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n))).
This rounding is done to make it more obvious when the parameters have been
clamped.
"""
if self.train_config.mu_eps is not None:
mu_eps = self.train_config.mu_eps
else:
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore
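For intuition (illustrative only, not part of the diff), the default clamp width shrinks with the number of data points n:

import numpy as np

for n in (50, 1_000, 100_000):
    mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(n)))
    # n = 50      -> 0.01    (capped at 0.01)
    # n = 1_000   -> 0.001
    # n = 100_000 -> 0.00001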

def _count_accurate_lfs(self, mu: np.ndarray) -> int:
r"""Count the number of LFs that are estimated to be better than random.

Return the number of LFs that are estimated to be more accurate than not when not
abstaining, i.e., where

P(\lf = Y) > P(\lf != Y, \lf != -1).

Parameters
----------
mu
An [m * k, k] np.ndarray with entries in [0, 1]

Returns
-------
int
Number of LFs better than random
"""
P = self.P.numpy()
cprobs = self._get_conditional_probs(mu)
count = 0
for i in range(self.m):
probs = cprobs[i, 1:] @ P
if 2 * np.diagonal(probs).sum() - probs.sum() > 0:
count += 1
return count
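A sketch of the better-than-random check for a single LF, with hypothetical numbers (P is the diagonal class-balance matrix, so probs[j, y] = P(lf = j, Y = y)):

import numpy as np

P = np.diag([0.5, 0.5])
mu_i = np.array([[0.75, 0.25],         # P(lf = 0 | Y = 0), P(lf = 0 | Y = 1)
                 [0.25, 0.75]])        # P(lf = 1 | Y = 0), P(lf = 1 | Y = 1)
probs = mu_i @ P                       # joint probabilities over non-abstain votes
correct = np.diagonal(probs).sum()     # P(lf = Y)             = 0.75
incorrect = probs.sum() - correct      # P(lf != Y, lf != -1)  = 0.25
is_accurate = 2 * correct - probs.sum() > 0   # True: counts toward the total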

def _break_col_permutation_symmetry(self) -> None:
r"""Heuristically choose amongst (possibly) several valid mu values.

If there are several values of mu that equivalently satisfy the optimization
objective, as there often are due to column permutation symmetries, then pick
the solution that trusts the user-written LFs most.

In more detail, suppose that mu satisfies (minimizes) the two loss objectives:
1. O = mu @ P @ mu.T
2. diag(O) = sum(mu @ P, axis=1)
Then for any column permutation matrix Z that commutes with P, mu @ Z will also
equivalently satisfy these objectives, and is thus an equally valid (symmetric) solution.
Therefore, we select the solution where the most LFs are estimated to be more
accurate than not when not abstaining, i.e., where for the majority of LFs,

P(\lf = Y) > P(\lf != Y, \lf != -1).

This is the standard assumption we have made in algorithmic and theoretical
work to date. Note however that this is not the only possible heuristic /
assumption that we could use, and in practice this may require further
iteration here.
"""
mu = self.mu.detach().numpy()
P = self.P.numpy()
d, k = mu.shape

# Iterate through the possible permutation matrices and track heuristic scores
Zs = []
scores = []
for idxs in permutations(range(k)):
Z = np.eye(k)[:, idxs]
Zs.append(Z)

# If Z and P commute, get heuristic score, else skip
if np.allclose(Z @ P, P @ Z):
scores.append(self._count_accurate_lfs(mu @ Z))
else:
scores.append(-1)

# Set mu according to highest-scoring permutation
self.mu.data = torch.Tensor(mu @ Zs[np.argmax(scores)]) # type: ignore
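A standalone sketch of the commutation test above, with a hypothetical non-uniform class balance (only permutations that commute with P get a real score):

import numpy as np
from itertools import permutations

k = 2
P = np.diag([0.9, 0.1])                     # non-uniform class balance
for idxs in permutations(range(k)):
    Z = np.eye(k)[:, list(idxs)]            # column permutation matrix
    commutes = np.allclose(Z @ P, P @ Z)
    # (0, 1): identity, commutes -> scored via _count_accurate_lfs
    # (1, 0): swap, does not commute with this P -> score of -1, skipped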

def fit(
self,
L_train: np.ndarray,
@@ -803,18 +897,9 @@ def fit(
# Update learning rate
self._update_lr_scheduler(epoch)

# Clamp learned parameters
# Note: If mu_eps is set too high, e.g. in sparse settings where LFs
# mostly abstain, this will result in learning conditional probabilities all
# equal to mu_eps (and/or 1 - mu_eps)!
# Note: Use user-provided value, else default to 1 / n', where n' is n rounded
# to the closest power of ten; this rounding is done to make it more obvious
# when the parameters have been clamped.
if self.train_config.mu_eps is not None:
mu_eps = self.train_config.mu_eps
else:
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore
# Post-processing operations on mu
self._clamp_params()
self._break_col_permutation_symmetry()

# Return model to eval mode
self.eval()
78 changes: 68 additions & 10 deletions test/labeling/model/test_label_model.py
@@ -5,6 +5,7 @@

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.optim as optim

@@ -124,9 +125,9 @@ def test_augmented_L_construction(self):
def test_conditional_probs(self):
L = np.array([[0, 1, 0], [0, 1, 0]])
label_model = self._set_up_model(L, class_balance=[0.6, 0.4])
probs = label_model._get_conditional_probs()
self.assertLessEqual(probs.max(), 1.0)
self.assertGreaterEqual(probs.min(), 0.0)
cprobs = label_model.get_conditional_probs()
self.assertLessEqual(cprobs.max(), 1.0)
self.assertGreaterEqual(cprobs.min(), 0.0)

def test_get_weight(self):
# set up L matrix
@@ -352,7 +353,68 @@ def test_set_mu_eps(self):
L = np.array([[1, 1, 1], [1, 1, 1]])
label_model = LabelModel(verbose=False)
label_model.fit(L, mu_eps=mu_eps)
self.assertAlmostEqual(label_model._get_conditional_probs(0)[1, 0], mu_eps)
self.assertAlmostEqual(label_model.get_conditional_probs()[0, 1, 0], mu_eps)

def test_count_accurate_lfs(self):
mu = np.array(
[
# LF 0
[0.75, 0.25],
[0.25, 0.75],
# LF 1
[0.25, 0.75],
[0.15, 0.25],
# LF 2
[0.75, 0.25],
[0.25, 0.75],
]
)

# First test: Two "good" LFs
label_model = LabelModel(verbose=False)
label_model._set_class_balance(None, None)
label_model.m = 3
self.assertEqual(label_model._count_accurate_lfs(mu), 2)

# Second test: Now they should all be "good" due to class balance, since we're
# counting accuracy (not conditional probabilities)
label_model = LabelModel(verbose=False)
label_model._set_class_balance([0.9, 0.1], None)
label_model.m = 3
self.assertEqual(label_model._count_accurate_lfs(mu), 3)

def test_symmetry_breaking(self):
mu = np.array(
[
# LF 0
[0.75, 0.25],
[0.25, 0.75],
# LF 1
[0.25, 0.75],
[0.15, 0.25],
# LF 2
[0.75, 0.25],
[0.25, 0.75],
]
)
mu = mu[:, [1, 0]]

# First test: Two "good" LFs
label_model = LabelModel(verbose=False)
label_model._set_class_balance(None, None)
label_model.m = 3
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.75)

# Test with non-uniform class balance
# It should not consider the "correct" permutation, as it does not commute now
label_model = LabelModel(verbose=False)
label_model._set_class_balance([0.9, 0.1], None)
label_model.m = 3
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.25)


@pytest.mark.complex
@@ -375,9 +437,7 @@ def test_label_model_basic(self) -> None:
label_model.fit(L, n_epochs=200, lr=0.01, seed=123)

# Test estimated LF conditional probabilities
P_lm = label_model._get_conditional_probs().reshape(
(self.m, self.cardinality + 1, -1)
)
P_lm = label_model.get_conditional_probs()
np.testing.assert_array_almost_equal(P, P_lm, decimal=2)

# Test predicted labels
@@ -401,9 +461,7 @@ def test_label_model_sparse(self) -> None:
label_model.fit(L, n_epochs=1000, lr=0.01, seed=123)

# Test estimated LF conditional probabilities
P_lm = label_model._get_conditional_probs().reshape(
(self.m, self.cardinality + 1, -1)
)
P_lm = label_model.get_conditional_probs()
np.testing.assert_array_almost_equal(P, P_lm, decimal=2)

# Test predicted labels *only on non-abstained data points*