Skip to content

Commit 0d9f957

Browse files
authored
Make preds and probs optional in Scorer.score() (#1441)
* Make preds and probs optional in Scorer.score() * Update docstrings
1 parent 2e7a116 commit 0d9f957

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

snorkel/analysis/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def metric_score(
3030
preds
3131
An array of (int) predictions
3232
probs
33-
An [n_datapoints, n_classes] array of probabilistic predictions
33+
An [n_datapoints, n_classes] array of probabilistic (float) predictions
3434
metric
3535
The name of the metric to calculate
3636
filter_dict

snorkel/analysis/scorer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,25 @@ def __init__(
6363
self.abstain_label = abstain_label
6464

6565
def score(
66-
self, golds: np.ndarray, preds: np.ndarray, probs: np.ndarray
66+
self,
67+
golds: np.ndarray,
68+
preds: Optional[np.ndarray] = None,
69+
probs: Optional[np.ndarray] = None,
6770
) -> Dict[str, float]:
68-
"""Calculate one or more scores from user-specified and/or user-defined metrics.
71+
"""Calculate scores for one or more user-specified metrics.
6972
7073
Parameters
7174
----------
7275
golds
73-
Gold (aka ground truth) labels (integers)
76+
An array of gold (int) labels to base scores on
7477
preds
75-
Predictions (integers)
76-
probs:
77-
Probabilities (floats)
78+
An [n_datapoints,] or [n_datapoints, 1] array of (int) predictions to score
79+
probs
80+
An [n_datapoints, n_classes] array of probabilistic (float) predictions
81+
82+
Because most metrics require either `preds` or `probs`, but not both, these
83+
values are optional; it is up to the metric function that will be called to
84+
raise an exception if a field it requires is not passed to the `score()` method.
7885
7986
Returns
8087
-------

test/analysis/test_scorer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def test_no_labels(self) -> None:
5151
with self.assertRaisesRegex(ValueError, "Cannot score"):
5252
scorer.score([], [], [])
5353

54+
def test_no_probs(self) -> None:
55+
scorer = Scorer()
56+
golds, preds, probs = self._get_labels()
57+
self.assertEqual(scorer.score(golds, preds), scorer.score(golds, preds, probs))
58+
5459
def test_abstain_labels(self) -> None:
5560
# We abstain on the last example by convention (label=-1)
5661
golds = np.array([1, 0, 1, 0, -1])

0 commit comments

Comments
 (0)