Skip to content

Commit 056bf91

Browse files
authored
Faster symmetry breaking (#1502)
* Faster symmetry breaking * Address comments
1 parent 36a5456 commit 056bf91

File tree

5 files changed

+69
-78
lines changed

5 files changed

+69
-78
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ tqdm>=4.33.0,<5.0.0
1919
# Internal models
2020
scikit-learn>=0.20.2,<0.22.0
2121
torch>=1.1.0,<1.2.0
22+
munkres==1.1.2
2223

2324
# LF dependency learning
2425
networkx>=2.2,<2.4

scripts/check_requirements.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ def parse_setup() -> Tuple[PackagesType, PackagesType, Set[str], Set[str]]:
103103
def main() -> int:
104104
exit_code = 0
105105

106-
requirements_essential, requirements_other, requirements_duplicate = (
107-
parse_requirements()
108-
)
106+
(
107+
requirements_essential,
108+
requirements_other,
109+
requirements_duplicate,
110+
) = parse_requirements()
109111
requirements_all = dict(requirements_essential, **requirements_other)
110112
setup_essential, setup_test, essential_duplicates, test_duplicates = parse_setup()
111113

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
packages=find_packages(),
3636
include_package_data=True,
3737
install_requires=[
38+
"munkres==1.1.2",
3839
"numpy>=1.16.0,<2.0.0",
3940
"scipy>=1.2.0,<2.0.0",
4041
"pandas>=0.25.0,<0.26.0",

snorkel/labeling/model/label_model.py

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import logging
22
import pickle
33
import random
4-
from collections import Counter
5-
from itertools import chain, permutations
6-
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
4+
from collections import Counter, defaultdict
5+
from itertools import chain
6+
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Tuple, Union
77

88
import numpy as np
99
import torch
1010
import torch.nn as nn
1111
import torch.optim as optim
12+
from munkres import Munkres # type: ignore
1213

1314
from snorkel.analysis import Scorer
1415
from snorkel.labeling.analysis import LFAnalysis
@@ -755,33 +756,6 @@ def _clamp_params(self) -> None:
755756
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
756757
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore
757758

758-
def _count_accurate_lfs(self, mu: np.ndarray) -> int:
759-
r"""Count the number of LFs that are estimated to be better than random.
760-
761-
Return the number of LFs that are estimated to be more accurate than not when not
762-
abstaining, i.e., where
763-
764-
P(\lf = Y) > P(\lf != Y, \lf != -1).
765-
766-
Parameters
767-
----------
768-
mu
769-
An [m * k, k] np.ndarray with entries in [0, 1]
770-
771-
Returns
772-
-------
773-
int
774-
Number of LFs better than random
775-
"""
776-
P = self.P.cpu().detach().numpy()
777-
cprobs = self._get_conditional_probs(mu)
778-
count = 0
779-
for i in range(self.m):
780-
probs = cprobs[i, 1:] @ P
781-
if 2 * np.diagonal(probs).sum() - probs.sum() > 0:
782-
count += 1
783-
return count
784-
785759
def _break_col_permutation_symmetry(self) -> None:
786760
r"""Heuristically choose amongst (possibly) several valid mu values.
787761
@@ -794,38 +768,41 @@ def _break_col_permutation_symmetry(self) -> None:
794768
2. diag(O) = sum(mu @ P, axis=1)
795769
Then any column permutation matrix Z that commutes with P will also equivalently
796770
satisfy these objectives, and thus is an equally valid (symmetric) solution.
797-
Therefore, we select the solution where the most LFs are estimated to be more
798-
accurate than not when not abstaining, i.e., where for the majority of LFs,
799-
800-
P(\lf = Y) > P(\lf != Y, \lf != -1).
771+
Therefore, we select the solution that maximizes the summed probability of the
772+
LFs being accurate when not abstaining.
801773
802-
This is the standard assumption we have made in algorithmic and theoretical
803-
work to date. Note however that this is not the only possible heuristic /
804-
assumption that we could use, and in practice this may require further
805-
iteration here.
774+
\sum_lf \sum_{y=1}^{cardinality} P(\lf = y, Y = y)
806775
"""
807776
mu = self.mu.cpu().detach().numpy()
808777
P = self.P.cpu().detach().numpy()
809778
d, k = mu.shape
810-
811-
# Iterate through the possible permutation matrices and track heuristic scores
812-
Zs = []
813-
scores = []
814-
for idxs in permutations(range(k)):
815-
Z = np.eye(k)[:, idxs]
816-
Zs.append(Z)
817-
818-
# If Z and P commute, get heuristic score, else skip
819-
if np.allclose(Z @ P, P @ Z):
820-
scores.append(self._count_accurate_lfs(mu @ Z))
821-
else:
822-
scores.append(-1)
823-
824-
# Set mu according to highest-scoring permutation
779+
# We want to maximize the sum of diagonals of matrices for each LF. So
780+
# we start by computing the sum of conditional probabilities here.
781+
probs_sum = sum([mu[i : i + k] for i in range(0, self.m * k, k)]) @ P
782+
783+
munkres_solver = Munkres()
784+
Z = np.zeros([k, k])
785+
786+
# Compute groups of indices with equal prior in P.
787+
groups: DefaultDict[float, List[int]] = defaultdict(list)
788+
for i, f in enumerate(P.diagonal()):
789+
groups[np.around(f, 3)].append(i)
790+
for group in groups.values():
791+
if len(group) == 1:
792+
Z[group[0], group[0]] = 1.0 # Identity permutation
793+
continue
794+
# Compute submatrix corresponding to the group.
795+
probs_proj = probs_sum[[[g] for g in group], group]
796+
# Use the Munkres algorithm to find the optimal permutation.
797+
# We use minus because we want to maximize diagonal sum, not minimize,
798+
# and transpose because we want to permute columns, not rows.
799+
permutation_pairs = munkres_solver.compute(-probs_proj.T)
800+
for i, j in permutation_pairs:
801+
Z[group[i], group[j]] = 1.0
802+
803+
# Set mu according to permutation
825804
self.mu = nn.Parameter(
826-
torch.Tensor(mu @ Zs[np.argmax(scores)]).to( # type: ignore
827-
self.config.device
828-
)
805+
torch.Tensor(mu @ Z).to(self.config.device) # type: ignore
829806
)
830807

831808
def fit(

test/labeling/model/test_label_model.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def test_set_mu_eps(self):
417417
label_model.fit(L, mu_eps=mu_eps)
418418
self.assertAlmostEqual(label_model.get_conditional_probs()[0, 1, 0], mu_eps)
419419

420-
def test_count_accurate_lfs(self):
420+
def test_symmetry_breaking(self):
421421
mu = np.array(
422422
[
423423
# LF 0
@@ -431,52 +431,62 @@ def test_count_accurate_lfs(self):
431431
[0.25, 0.75],
432432
]
433433
)
434+
mu = mu[:, [1, 0]]
434435

435436
# First test: Two "good" LFs
436437
label_model = LabelModel(verbose=False)
437438
label_model._set_class_balance(None, None)
438439
label_model.m = 3
439-
self.assertEqual(label_model._count_accurate_lfs(mu), 2)
440+
label_model.mu = nn.Parameter(torch.from_numpy(mu))
441+
label_model._break_col_permutation_symmetry()
442+
self.assertEqual(label_model.mu.data[0, 0], 0.75)
440443

441-
# Second test: Now they should all be "good" due to class balance, since we're
442-
# counting accuracy (not conditional probabilities)
444+
# Test with non-uniform class balance
445+
# It should not consider the "correct" permutation as it does not commute now
443446
label_model = LabelModel(verbose=False)
444447
label_model._set_class_balance([0.9, 0.1], None)
445448
label_model.m = 3
446-
self.assertEqual(label_model._count_accurate_lfs(mu), 3)
449+
label_model.mu = nn.Parameter(torch.from_numpy(mu))
450+
label_model._break_col_permutation_symmetry()
451+
self.assertEqual(label_model.mu.data[0, 0], 0.25)
447452

448-
def test_symmetry_breaking(self):
453+
def test_symmetry_breaking_multiclass(self):
449454
mu = np.array(
450455
[
451456
# LF 0
452-
[0.75, 0.25],
453-
[0.25, 0.75],
457+
[0.75, 0.15, 0.1],
458+
[0.20, 0.75, 0.3],
459+
[0.05, 0.10, 0.6],
454460
# LF 1
455-
[0.25, 0.75],
456-
[0.15, 0.25],
461+
[0.25, 0.55, 0.3],
462+
[0.15, 0.45, 0.4],
463+
[0.20, 0.00, 0.3],
457464
# LF 2
458-
[0.75, 0.25],
459-
[0.25, 0.75],
465+
[0.5, 0.15, 0.2],
466+
[0.3, 0.65, 0.2],
467+
[0.2, 0.20, 0.6],
460468
]
461469
)
462-
mu = mu[:, [1, 0]]
470+
mu = mu[:, [1, 2, 0]]
463471

464472
# First test: Two "good" LFs
465-
label_model = LabelModel(verbose=False)
473+
label_model = LabelModel(cardinality=3, verbose=False)
466474
label_model._set_class_balance(None, None)
467475
label_model.m = 3
468476
label_model.mu = nn.Parameter(torch.from_numpy(mu))
469477
label_model._break_col_permutation_symmetry()
470478
self.assertEqual(label_model.mu.data[0, 0], 0.75)
479+
self.assertEqual(label_model.mu.data[1, 1], 0.75)
471480

472481
# Test with non-uniform class balance
473-
# It should not consider the "correct" permutation as does not commute now
474-
label_model = LabelModel(verbose=False)
475-
label_model._set_class_balance([0.9, 0.1], None)
482+
# It should not consider the "correct" permutation as it does not commute
483+
label_model = LabelModel(cardinality=3, verbose=False)
484+
label_model._set_class_balance([0.7, 0.2, 0.1], None)
476485
label_model.m = 3
477486
label_model.mu = nn.Parameter(torch.from_numpy(mu))
478487
label_model._break_col_permutation_symmetry()
479-
self.assertEqual(label_model.mu.data[0, 0], 0.25)
488+
self.assertEqual(label_model.mu.data[0, 0], 0.15)
489+
self.assertEqual(label_model.mu.data[1, 1], 0.3)
480490

481491

482492
@pytest.mark.complex
@@ -528,7 +538,7 @@ def test_label_model_sparse(self) -> None:
528538

529539
# Test predicted labels *only on non-abstained data points*
530540
Y_pred = label_model.predict(L, tie_break_policy="abstain")
531-
idx, = np.where(Y_pred != -1)
541+
(idx,) = np.where(Y_pred != -1)
532542
acc = np.where(Y_pred[idx] == Y[idx], 1, 0).sum() / len(idx)
533543
self.assertGreaterEqual(acc, 0.65)
534544

0 commit comments

Comments
 (0)