Skip to content

Commit e49ecb2

Browse files
committed
Fix to scoring function for unlabeled = skip
1 parent f48affd commit e49ecb2

File tree

1 file changed

+53
-28
lines changed

1 file changed

+53
-28
lines changed

snorkel/learning.py

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import scipy.sparse as sparse
33
from scipy.optimize import minimize
4-
from learning_utils import score, sparse_abs
4+
from learning_utils import score, sparse_abs, calibration_plots
55
from lstm import LSTMModel
66
from sklearn import linear_model
77
from .models import Parameter, ParameterSet
@@ -97,33 +97,34 @@ def predict(self, X, b=0.5):
9797
"""Return numpy array of elements in {-1,0,1} based on predicted marginal probabilities."""
9898
return np.array([1 if p > b else -1 if p < b else 0 for p in self.marginals(X)])
9999

100-
def score(self, X_test, L_test, gold_candidate_set, b=0.5, set_unlabeled_as_neg=True, display=True):
100+
def score(self, X_test, L_test, gold_candidate_set=None, b=0.5, set_unlabeled_as_neg=True, display=True):
101101
if L_test.shape[1] != 1:
102102
raise ValueError("L_test must have exactly one column.")
103-
predict = self.predict(X_test, b=b)
104-
train_marginals = self.marginals(self.X_train) if hasattr(self, 'X_train') and self.X_train is not None else None
105-
test_marginals = self.marginals(X_test)
106-
107-
test_candidates = set()
108-
test_labels = []
103+
predict = self.predict(X_test, b=b)
104+
train_marginals = self.marginals(self.X_train) if hasattr(self, 'X_train') and self.X_train is not None else None
105+
test_marginals = self.marginals(X_test)
106+
test_candidates = set()
107+
test_labels = []
108+
test_predictions = []
109+
110+
# Collect error buckets
109111
tp = set()
110112
fp = set()
111113
tn = set()
112114
fn = set()
113-
114115
for i in range(X_test.shape[0]):
115116
candidate = X_test.get_candidate(i)
116117
test_candidates.add(candidate)
117-
try:
118-
L_test_index = L_test.get_row_index(candidate)
119-
test_label = L_test[L_test_index, 0]
120-
121-
# Set unlabeled examples to -1 by default
122-
if test_label == 0 and set_unlabeled_as_neg:
123-
test_label = -1
124-
125-
# Bucket the candidates for error analysis
126-
test_labels.append(test_label)
118+
L_test_index = L_test.get_row_index(candidate)
119+
test_label = L_test[L_test_index, 0]
120+
121+
# Set unlabeled examples to -1 by default
122+
if test_label == 0 and set_unlabeled_as_neg:
123+
test_label = -1
124+
125+
# Bucket the candidates for error analysis
126+
test_labels.append(test_label)
127+
if test_label != 0:
127128
if test_marginals[i] > b:
128129
if test_label == 1:
129130
tp.add(candidate)
@@ -134,17 +135,22 @@ def score(self, X_test, L_test, gold_candidate_set, b=0.5, set_unlabeled_as_neg=
134135
tn.add(candidate)
135136
else:
136137
fn.add(candidate)
137-
except KeyError:
138-
test_labels.append(-1)
139-
if test_marginals[i] > b:
140-
fp.add(candidate)
141-
else:
142-
tn.add(candidate)
143138

144-
# Print diagnostics chart and return error analysis candidate sets
145139
if display:
146-
score(test_candidates, np.asarray(test_labels), np.asarray(predict), gold_candidate_set,
147-
train_marginals=train_marginals, test_marginals=test_marginals)
140+
141+
# Calculate scores unadjusted for TPs not in our candidate set
142+
print_scores(len(tp), len(fp), len(tn), len(fn), title="Scores (Un-adjusted)")
143+
144+
# If a gold candidate set is provided, also calculate recall-adjusted scores
145+
if gold_candidate_set is not None:
146+
gold_fn = [c for c in gold_candidate_set if c not in test_candidates]
147+
print "\n"
148+
print_scores(len(tp), len(fp), len(tn), len(fn)+len(gold_fn), title="Corpus Recall-adjusted Scores")
149+
150+
# If training and test marginals provided, also print calibration plots
151+
if train_marginals is not None and test_marginals is not None:
152+
print "\nCalibration plot:"
153+
calibration_plots(train_marginals, test_marginals, np.asarray(test_labels))
148154
return tp, fp, tn, fn
149155

150156
def save(self, session, param_set_name):
@@ -171,6 +177,25 @@ def load(self, session, param_set_name):
171177
self.w = np.array([res[0] for res in q.all()])
172178

173179

180+
def print_scores(ntp, nfp, ntn, nfn, title='Scores'):
    """Print a formatted summary chart of binary-classification metrics.

    Args:
        ntp: number of true positives.
        nfp: number of false positives.
        ntn: number of true negatives.
        nfn: number of false negatives.
        title: heading printed at the top of the chart.

    Prints positive/negative class accuracy, precision, recall, and F1,
    guarding every ratio against an empty (zero-count) denominator.
    Returns None; output goes to stdout only.
    """
    def _ratio(num, den):
        # Single guarded division used for every metric; avoids the
        # five hand-duplicated `x / float(y) if y > 0 else 0.0` forms.
        return num / float(den) if den > 0 else 0.0

    prec = _ratio(ntp, ntp + nfp)
    rec = _ratio(ntp, ntp + nfn)
    f1 = _ratio(2 * prec * rec, prec + rec)
    # Positive-class accuracy is recall by definition; negative-class
    # accuracy (specificity) is TN / (TN + FP).
    pos_acc = rec
    neg_acc = _ratio(ntn, ntn + nfp)

    # NOTE: single-argument print(...) prints identically under Python 2
    # (parenthesized expression) and Python 3 (function call).
    print("========================================")
    print(title)
    print("========================================")
    print("Pos. class accuracy: {:.3}".format(pos_acc))
    print("Neg. class accuracy: {:.3}".format(neg_acc))
    print("Precision            {:.3}".format(prec))
    print("Recall               {:.3}".format(rec))
    print("F1                   {:.3}".format(f1))
    print("----------------------------------------")
    print("TP: {} | FP: {} | TN: {} | FN: {}".format(ntp, nfp, ntn, nfn))
    print("========================================\n")
197+
198+
174199
class LogRegSKLearn(NoiseAwareModel):
175200
"""Logistic regression."""
176201
def __init__(self):

0 commit comments

Comments
 (0)