1
1
import logging
2
2
import pickle
3
3
import random
4
- from collections import Counter
5
- from itertools import chain , permutations
6
- from typing import Any , Dict , List , NamedTuple , Optional , Set , Tuple , Union
4
+ from collections import Counter , defaultdict
5
+ from itertools import chain
6
+ from typing import Any , DefaultDict , Dict , List , NamedTuple , Optional , Set , Tuple , Union
7
7
8
8
import numpy as np
9
9
import torch
10
10
import torch .nn as nn
11
11
import torch .optim as optim
12
+ from munkres import Munkres # type: ignore
12
13
13
14
from snorkel .analysis import Scorer
14
15
from snorkel .labeling .analysis import LFAnalysis
@@ -755,33 +756,6 @@ def _clamp_params(self) -> None:
755
756
mu_eps = min (0.01 , 1 / 10 ** np .ceil (np .log10 (self .n )))
756
757
self .mu .data = self .mu .clamp (mu_eps , 1 - mu_eps ) # type: ignore
757
758
758
- def _count_accurate_lfs (self , mu : np .ndarray ) -> int :
759
- r"""Count the number of LFs that are estimated to be better than random.
760
-
761
- Return the number of LFs are estimated to be more accurate than not when not
762
- abstaining, i.e., where
763
-
764
- P(\lf = Y) > P(\lf != Y, \lf != -1).
765
-
766
- Parameters
767
- ----------
768
- mu
769
- An [m * k, k] np.ndarray with entries in [0, 1]
770
-
771
- Returns
772
- -------
773
- int
774
- Number of LFs better than random
775
- """
776
- P = self .P .cpu ().detach ().numpy ()
777
- cprobs = self ._get_conditional_probs (mu )
778
- count = 0
779
- for i in range (self .m ):
780
- probs = cprobs [i , 1 :] @ P
781
- if 2 * np .diagonal (probs ).sum () - probs .sum () > 0 :
782
- count += 1
783
- return count
784
-
785
759
def _break_col_permutation_symmetry (self ) -> None :
786
760
r"""Heuristically choose amongst (possibly) several valid mu values.
787
761
@@ -794,38 +768,41 @@ def _break_col_permutation_symmetry(self) -> None:
794
768
2. diag(O) = sum(mu @ P, axis=1)
795
769
Then any column permutation matrix Z that commutes with P will also equivalently
796
770
satisfy these objectives, and thus is an equally valid (symmetric) solution.
797
- Therefore, we select the solution where the most LFs are estimated to be more
798
- accurate than not when not abstaining, i.e., where for the majority of LFs,
799
-
800
- P(\lf = Y) > P(\lf != Y, \lf != -1).
771
+ Therefore, we select the solution that maximizes the summed probability of the
772
+ LFs being accurate when not abstaining.
801
773
802
- This is the standard assumption we have made in algorithmic and theoretical
803
- work to date. Note however that this is not the only possible heuristic /
804
- assumption that we could use, and in practice this may require further
805
- iteration here.
774
+ \sum_lf \sum_{y=1}^{cardinality} P(\lf = y, Y = y)
806
775
"""
807
776
mu = self .mu .cpu ().detach ().numpy ()
808
777
P = self .P .cpu ().detach ().numpy ()
809
778
d , k = mu .shape
810
-
811
- # Iterate through the possible perumation matrices and track heuristic scores
812
- Zs = []
813
- scores = []
814
- for idxs in permutations (range (k )):
815
- Z = np .eye (k )[:, idxs ]
816
- Zs .append (Z )
817
-
818
- # If Z and P commute, get heuristic score, else skip
819
- if np .allclose (Z @ P , P @ Z ):
820
- scores .append (self ._count_accurate_lfs (mu @ Z ))
821
- else :
822
- scores .append (- 1 )
823
-
824
- # Set mu according to highest-scoring permutation
779
+ # We want to maximize the sum of diagonals of matrices for each LF. So
780
+ # we start by computing the sum of conditional probabilities here.
781
+ probs_sum = sum ([mu [i : i + k ] for i in range (0 , self .m * k , k )]) @ P
782
+
783
+ munkres_solver = Munkres ()
784
+ Z = np .zeros ([k , k ])
785
+
786
+ # Compute groups of indicess with equal prior in P.
787
+ groups : DefaultDict [float , List [int ]] = defaultdict (list )
788
+ for i , f in enumerate (P .diagonal ()):
789
+ groups [np .around (f , 3 )].append (i )
790
+ for group in groups .values ():
791
+ if len (group ) == 1 :
792
+ Z [group [0 ], group [0 ]] = 1.0 # Identity permutation
793
+ continue
794
+ # Compute submatrix corresponding to the group.
795
+ probs_proj = probs_sum [[[g ] for g in group ], group ]
796
+ # Use the Munkres algorithm to find the optimal permutation.
797
+ # We use minus because we want to maximize diagonal sum, not minimize,
798
+ # and transpose because we want to permute columns, not rows.
799
+ permutation_pairs = munkres_solver .compute (- probs_proj .T )
800
+ for i , j in permutation_pairs :
801
+ Z [group [i ], group [j ]] = 1.0
802
+
803
+ # Set mu according to permutation
825
804
self .mu = nn .Parameter (
826
- torch .Tensor (mu @ Zs [np .argmax (scores )]).to ( # type: ignore
827
- self .config .device
828
- )
805
+ torch .Tensor (mu @ Z ).to (self .config .device ) # type: ignore
829
806
)
830
807
831
808
def fit (
0 commit comments