Skip to content

Commit 8e4526e

Browse files
authored
Label model symmetry breaking (#1451)
* - Refactor get_conditional_probs - Factor out two post-processing ops on mu in LabelModel.fit - Implement heuristic symmetry breaking on mu * Fixing style check errors * Passing basic tests * Passes tox * Changed to the standard heuristic / assumption * Changed to proper test of accuracies vs. cond prob * Refactor subfn for counting good LFs + add test * Address PR comments * Add unit test for symmetry breaking * Address PR comments * Fix naming bug
1 parent a089eeb commit 8e4526e

File tree

2 files changed

+197
-54
lines changed

2 files changed

+197
-54
lines changed

snorkel/labeling/model/label_model.py

Lines changed: 129 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from collections import Counter
3-
from itertools import chain
3+
from itertools import chain, permutations
44
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
55

66
import numpy as np
@@ -283,52 +283,57 @@ def _init_params(self) -> None:
283283
# Build the mask over O^{-1}
284284
self._build_mask()
285285

286-
def _get_conditional_probs(self, source: Optional[int] = None) -> np.ndarray:
287-
r"""Return the full conditional probabilities table.
286+
def _get_conditional_probs(self, mu: np.ndarray) -> np.ndarray:
287+
r"""Return the estimated conditional probabilities table given parameters mu.
288288
289-
In cond. prob. table, row i*(k+1) + ly is the conditional probabilities of source i
290-
emmiting label ly (including abstains 0), conditioned on different
291-
values of Y, i.e.:
289+
Given a parameter vector mu, return the estimated conditional probabilites
290+
table cprobs, where cprobs is an (m, k+1, k)-dim np.ndarray with:
292291
293-
c_probs[i*(k+1) + ly, y] = P(\lambda_i = ly | Y = y)
292+
cprobs[i, j, k] = P(\lf_i = j-1 | Y = k)
294293
295-
Note that this simply involves inferring the kth row by law of total
296-
probability and adding in to mu.
297-
298-
If ``source`` is not None, returns only the corresponding block.
294+
where m is the number of LFs, k is the cardinality, and cprobs includes the
295+
conditional abstain probabilities P(\lf_i = -1 | Y = y).
299296
300297
Parameters
301298
----------
302-
source
303-
Index of source to generate conditional probabilities for, by default None
299+
mu
300+
An [m * k, k] np.ndarray with entries in [0, 1]
304301
305302
Returns
306303
-------
307304
np.ndarray
308-
Conditional probabilities table if source is None, else corresponding block
305+
An [m, k + 1, k] np.ndarray conditional probabilities table.
309306
"""
310-
c_probs = np.zeros((self.m * (self.cardinality + 1), self.cardinality))
311-
mu = self.mu.detach().clone().numpy()
312-
307+
cprobs = np.zeros((self.m, self.cardinality + 1, self.cardinality))
313308
for i in range(self.m):
314309
# si = self.c_data[(i,)]['start_index']
315310
# ei = self.c_data[(i,)]['end_index']
316311
# mu_i = mu[si:ei, :]
317312
mu_i = mu[i * self.cardinality : (i + 1) * self.cardinality, :]
318-
c_probs[
319-
i * (self.cardinality + 1) + 1 : (i + 1) * (self.cardinality + 1), :
320-
] = mu_i
313+
cprobs[i, 1:, :] = mu_i
321314

322315
# The 0th row (corresponding to abstains) is the difference between
323-
# the sums of the other rows and one, by law of total prob
324-
c_probs[i * (self.cardinality + 1), :] = 1 - mu_i.sum(axis=0)
316+
# the sums of the other rows and one, by law of total probability
317+
cprobs[i, 0, :] = 1 - mu_i.sum(axis=0)
318+
return cprobs
325319

326-
if source is not None:
327-
return c_probs[
328-
source * (self.cardinality + 1) : (source + 1) * (self.cardinality + 1)
329-
]
330-
else:
331-
return c_probs
320+
def get_conditional_probs(self) -> np.ndarray:
321+
r"""Return the estimated conditional probabilities table.
322+
323+
Return the estimated conditional probabilites table cprobs, where cprobs is an
324+
(m, k+1, k)-dim np.ndarray with:
325+
326+
cprobs[i, j, k] = P(\lf_i = j-1 | Y = k)
327+
328+
where m is the number of LFs, k is the cardinality, and cprobs includes the
329+
conditional abstain probabilities P(\lf_i = -1 | Y = y).
330+
331+
Returns
332+
-------
333+
np.ndarray
334+
An [m, k + 1, k] np.ndarray conditional probabilities table.
335+
"""
336+
return self._get_conditional_probs(self.mu.detach().numpy())
332337

333338
def get_weights(self) -> np.ndarray:
334339
"""Return the vector of learned LF weights for combining LFs.
@@ -347,10 +352,9 @@ def get_weights(self) -> np.ndarray:
347352
array([0.99, 0.99, 0.99])
348353
"""
349354
accs = np.zeros(self.m)
355+
cprobs = self.get_conditional_probs()
350356
for i in range(self.m):
351-
cps = self._get_conditional_probs(source=i)[1:, :]
352-
accs[i] = np.diag(cps @ self.P.numpy()).sum()
353-
357+
accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.numpy()).sum()
354358
return np.clip(accs / self.coverage, 1e-6, 1.0)
355359

356360
def predict_proba(self, L: np.ndarray) -> np.ndarray:
@@ -379,7 +383,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
379383
L_shift = L + 1 # convert to {0, 1, ..., k}
380384
self._set_constants(L_shift)
381385
L_aug = self._get_augmented_label_matrix(L_shift)
382-
mu = self.mu.detach().clone().numpy()
386+
mu = self.mu.detach().numpy()
383387
jtm = np.ones(L_aug.shape[1])
384388

385389
# Note: We omit abstains, effectively assuming uniform distribution here
@@ -706,6 +710,96 @@ def _update_lr_scheduler(self, step: int) -> None:
706710
if min_lr and self.optimizer.param_groups[0]["lr"] < min_lr:
707711
self.optimizer.param_groups[0]["lr"] = min_lr
708712

713+
def _clamp_params(self) -> None:
714+
"""Clamp the values of the learned parameter vector.
715+
716+
Clamp the entries of self.mu to be in [mu_eps, 1 - mu_eps], where mu_eps is
717+
either set by the user, or defaults to 1 / 10 ** np.ceil(np.log10(self.n)).
718+
719+
Note that if mu_eps is set too high, e.g. in sparse settings where LFs
720+
mostly abstain, this will result in learning conditional probabilities all
721+
equal to mu_eps (and/or 1 - mu_eps)! See issue #1422.
722+
723+
Note: Use user-provided value of mu_eps in train_config, else default to
724+
mu_eps = 1 / 10 ** np.ceil(np.log10(self.n))
725+
this rounding is done to make it more obvious when the parameters have been
726+
clamped.
727+
"""
728+
if self.train_config.mu_eps is not None:
729+
mu_eps = self.train_config.mu_eps
730+
else:
731+
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
732+
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore
733+
734+
def _count_accurate_lfs(self, mu: np.ndarray) -> int:
735+
r"""Count the number of LFs that are estimated to be better than random.
736+
737+
Return the number of LFs are estimated to be more accurate than not when not
738+
abstaining, i.e., where
739+
740+
P(\lf = Y) > P(\lf != Y, \lf != -1).
741+
742+
Parameters
743+
----------
744+
mu
745+
An [m * k, k] np.ndarray with entries in [0, 1]
746+
747+
Returns
748+
-------
749+
int
750+
Number of LFs better than random
751+
"""
752+
P = self.P.numpy()
753+
cprobs = self._get_conditional_probs(mu)
754+
count = 0
755+
for i in range(self.m):
756+
probs = cprobs[i, 1:] @ P
757+
if 2 * np.diagonal(probs).sum() - probs.sum() > 0:
758+
count += 1
759+
return count
760+
761+
def _break_col_permutation_symmetry(self) -> None:
762+
r"""Heuristically choose amongst (possibly) several valid mu values.
763+
764+
If there are several values of mu that equivalently satisfy the optimization
765+
objective, as there often are due to column permutation symmetries, then pick
766+
the solution that trusts the user-written LFs most.
767+
768+
In more detail, suppose that mu satisfies (minimizes) the two loss objectives:
769+
1. O = mu @ P @ mu.T
770+
2. diag(O) = sum(mu @ P, axis=1)
771+
Then any column permutation matrix Z that commutes with P will also equivalently
772+
satisfy these objectives, and thus is an equally valid (symmetric) solution.
773+
Therefore, we select the solution where the most LFs are estimated to be more
774+
accurate than not when not abstaining, i.e., where for the majority of LFs,
775+
776+
P(\lf = Y) > P(\lf != Y, \lf != -1).
777+
778+
This is the standard assumption we have made in algorithmic and theoretical
779+
work to date. Note however that this is not the only possible heuristic /
780+
assumption that we could use, and in practice this may require further
781+
iteration here.
782+
"""
783+
mu = self.mu.detach().numpy()
784+
P = self.P.numpy()
785+
d, k = mu.shape
786+
787+
# Iterate through the possible perumation matrices and track heuristic scores
788+
Zs = []
789+
scores = []
790+
for idxs in permutations(range(k)):
791+
Z = np.eye(k)[:, idxs]
792+
Zs.append(Z)
793+
794+
# If Z and P commute, get heuristic score, else skip
795+
if np.allclose(Z @ P, P @ Z):
796+
scores.append(self._count_accurate_lfs(mu @ Z))
797+
else:
798+
scores.append(-1)
799+
800+
# Set mu according to highest-scoring permutation
801+
self.mu.data = torch.Tensor(mu @ Zs[np.argmax(scores)]) # type: ignore
802+
709803
def fit(
710804
self,
711805
L_train: np.ndarray,
@@ -816,18 +910,9 @@ def fit(
816910
# Update learning rate
817911
self._update_lr_scheduler(epoch)
818912

819-
# Clamp learned parameters
820-
# Note: If mu_eps is set too high, e.g. in sparse settings where LFs
821-
# mostly abstain, this will result in learning conditional probabilities all
822-
# equal to mu_eps (and/or 1 - mu_eps)!
823-
# Note: Use user-provided value, else default to 1 / n', where n' is n rounded
824-
# to the closest power of ten; this rounding is done to make it more obvious
825-
# when the parameters have been clamped.
826-
if self.train_config.mu_eps is not None:
827-
mu_eps = self.train_config.mu_eps
828-
else:
829-
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
830-
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore
913+
# Post-processing operations on mu
914+
self._clamp_params()
915+
self._break_col_permutation_symmetry()
831916

832917
# Return model to eval mode
833918
self.eval()

test/labeling/model/test_label_model.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import pytest
8+
import torch
89
import torch.nn as nn
910
import torch.optim as optim
1011

@@ -154,9 +155,9 @@ def test_augmented_L_construction(self):
154155
def test_conditional_probs(self):
155156
L = np.array([[0, 1, 0], [0, 1, 0]])
156157
label_model = self._set_up_model(L, class_balance=[0.6, 0.4])
157-
probs = label_model._get_conditional_probs()
158-
self.assertLessEqual(probs.max(), 1.0)
159-
self.assertGreaterEqual(probs.min(), 0.0)
158+
cprobs = label_model.get_conditional_probs()
159+
self.assertLessEqual(cprobs.max(), 1.0)
160+
self.assertGreaterEqual(cprobs.min(), 0.0)
160161

161162
def test_get_weight(self):
162163
# set up L matrix
@@ -382,7 +383,68 @@ def test_set_mu_eps(self):
382383
L = np.array([[1, 1, 1], [1, 1, 1]])
383384
label_model = LabelModel(verbose=False)
384385
label_model.fit(L, mu_eps=mu_eps)
385-
self.assertAlmostEqual(label_model._get_conditional_probs(0)[1, 0], mu_eps)
386+
self.assertAlmostEqual(label_model.get_conditional_probs()[0, 1, 0], mu_eps)
387+
388+
def test_count_accurate_lfs(self):
389+
mu = np.array(
390+
[
391+
# LF 0
392+
[0.75, 0.25],
393+
[0.25, 0.75],
394+
# LF 1
395+
[0.25, 0.75],
396+
[0.15, 0.25],
397+
# LF 2
398+
[0.75, 0.25],
399+
[0.25, 0.75],
400+
]
401+
)
402+
403+
# First test: Two "good" LFs
404+
label_model = LabelModel(verbose=False)
405+
label_model._set_class_balance(None, None)
406+
label_model.m = 3
407+
self.assertEqual(label_model._count_accurate_lfs(mu), 2)
408+
409+
# Second test: Now they should all be "good" due to class balance, since we're
410+
# counting accuracy (not conditional probabilities)
411+
label_model = LabelModel(verbose=False)
412+
label_model._set_class_balance([0.9, 0.1], None)
413+
label_model.m = 3
414+
self.assertEqual(label_model._count_accurate_lfs(mu), 3)
415+
416+
def test_symmetry_breaking(self):
417+
mu = np.array(
418+
[
419+
# LF 0
420+
[0.75, 0.25],
421+
[0.25, 0.75],
422+
# LF 1
423+
[0.25, 0.75],
424+
[0.15, 0.25],
425+
# LF 2
426+
[0.75, 0.25],
427+
[0.25, 0.75],
428+
]
429+
)
430+
mu = mu[:, [1, 0]]
431+
432+
# First test: Two "good" LFs
433+
label_model = LabelModel(verbose=False)
434+
label_model._set_class_balance(None, None)
435+
label_model.m = 3
436+
label_model.mu = nn.Parameter(torch.from_numpy(mu))
437+
label_model._break_col_permutation_symmetry()
438+
self.assertEqual(label_model.mu.data[0, 0], 0.75)
439+
440+
# Test with non-uniform class balance
441+
# It should not consider the "correct" permutation as does not commute now
442+
label_model = LabelModel(verbose=False)
443+
label_model._set_class_balance([0.9, 0.1], None)
444+
label_model.m = 3
445+
label_model.mu = nn.Parameter(torch.from_numpy(mu))
446+
label_model._break_col_permutation_symmetry()
447+
self.assertEqual(label_model.mu.data[0, 0], 0.25)
386448

387449

388450
@pytest.mark.complex
@@ -405,9 +467,7 @@ def test_label_model_basic(self) -> None:
405467
label_model.fit(L, n_epochs=200, lr=0.01, seed=123)
406468

407469
# Test estimated LF conditional probabilities
408-
P_lm = label_model._get_conditional_probs().reshape(
409-
(self.m, self.cardinality + 1, -1)
410-
)
470+
P_lm = label_model.get_conditional_probs()
411471
np.testing.assert_array_almost_equal(P, P_lm, decimal=2)
412472

413473
# Test predicted labels
@@ -431,9 +491,7 @@ def test_label_model_sparse(self) -> None:
431491
label_model.fit(L, n_epochs=1000, lr=0.01, seed=123)
432492

433493
# Test estimated LF conditional probabilities
434-
P_lm = label_model._get_conditional_probs().reshape(
435-
(self.m, self.cardinality + 1, -1)
436-
)
494+
P_lm = label_model.get_conditional_probs()
437495
np.testing.assert_array_almost_equal(P, P_lm, decimal=2)
438496

439497
# Test predicted labels *only on non-abstained data points*

0 commit comments

Comments
 (0)