2 changes: 1 addition & 1 deletion snorkel/analysis/metrics.py
@@ -30,7 +30,7 @@ def metric_score(
     preds
         An array of (int) predictions
     probs
-        An [n_datapoints, n_classes] array of probabilistic predictions
+        An [n_datapoints, n_classes] array of probabilistic (float) predictions
     metric
         The name of the metric to calculate
     filter_dict
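To make the documented array shapes concrete, here is a minimal usage sketch (not part of the PR) of calling `metric_score` with arrays shaped as the docstring describes; the keyword names follow the docstring above and are assumed to match the function signature.

```python
import numpy as np

from snorkel.analysis import metric_score

golds = np.array([1, 0, 1, 1])  # gold (int) labels
preds = np.array([1, 0, 0, 1])  # (int) predictions
probs = np.array(               # [n_datapoints, n_classes] (float) predictions
    [[0.2, 0.8], [0.9, 0.1], [0.6, 0.4], [0.3, 0.7]]
)

# Accuracy only needs golds and preds; probs is defined here just to show its shape.
acc = metric_score(golds=golds, preds=preds, metric="accuracy")
print(acc)  # 0.75
```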
19 changes: 13 additions & 6 deletions snorkel/analysis/scorer.py
@@ -63,18 +63,25 @@ def __init__(
         self.abstain_label = abstain_label
 
     def score(
-        self, golds: np.ndarray, preds: np.ndarray, probs: np.ndarray
+        self,
+        golds: np.ndarray,
+        preds: Optional[np.ndarray] = None,
+        probs: Optional[np.ndarray] = None,
     ) -> Dict[str, float]:
-        """Calculate one or more scores from user-specified and/or user-defined metrics.
+        """Calculate scores for one or more user-specified metrics.
 
         Parameters
         ----------
         golds
-            Gold (aka ground truth) labels (integers)
+            An array of gold (int) labels to base scores on
         preds
-            Predictions (integers)
-        probs:
-            Probabilities (floats)
+            An [n_datapoints,] or [n_datapoints, 1] array of (int) predictions to score
+        probs
+            An [n_datapoints, n_classes] array of probabilistic (float) predictions
+
+        Because most metrics require either `preds` or `probs`, but not both, these
+        values are optional; it is up to the metric function that will be called to
+        raise an exception if a field it requires is not passed to the `score()` method.
 
         Returns
         -------
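A hedged usage sketch (not part of the PR) of the call pattern this change enables: with `preds` and `probs` now optional keyword arguments, a metric that only needs hard predictions can be scored without supplying `probs` at all.

```python
import numpy as np

from snorkel.analysis import Scorer

golds = np.array([1, 0, 1, 0])
preds = np.array([1, 0, 0, 0])

# Assumed configuration for illustration: score plain accuracy only.
scorer = Scorer(metrics=["accuracy"])

# No probs passed; accuracy consumes only golds and preds.
print(scorer.score(golds, preds=preds))  # e.g. {'accuracy': 0.75}

# A probability-based metric would instead be given probs; per the docstring,
# the metric function itself raises if a field it needs is missing.
```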
5 changes: 5 additions & 0 deletions test/analysis/test_scorer.py
@@ -51,6 +51,11 @@ def test_no_labels(self) -> None:
         with self.assertRaisesRegex(ValueError, "Cannot score"):
             scorer.score([], [], [])
 
+    def test_no_probs(self) -> None:
+        scorer = Scorer()
+        golds, preds, probs = self._get_labels()
+        self.assertEqual(scorer.score(golds, preds), scorer.score(golds, preds, probs))
+
     def test_abstain_labels(self) -> None:
         # We abstain on the last example by convention (label=-1)
         golds = np.array([1, 0, 1, 0, -1])