Skip to content

Commit ed77718

Browse files
Issue #1602 - Add get_label_instances to Analysis (#1608)
Addresses #1602. Added a method to analysis/error_analysis that wraps get_label_buckets functionality. Given a bucket, a NumPy array x of your data, and corresponding y label(s), it will return to you x with only the instances corresponding to that bucket.
1 parent 86a21a2 commit ed77718

File tree

4 files changed

+100
-2
lines changed

4 files changed

+100
-2
lines changed

docs/packages/analysis.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ Generic model analysis utilities shared across Snorkel.
1111

1212
Scorer
1313
get_label_buckets
14+
get_label_instances
1415
metric_score

snorkel/analysis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Generic model analysis utilities shared across Snorkel."""
22

3-
from .error_analysis import get_label_buckets # noqa: F401
3+
from .error_analysis import get_label_buckets, get_label_instances # noqa: F401
44
from .metrics import metric_score # noqa: F401
55
from .scorer import Scorer # noqa: F401

snorkel/analysis/error_analysis.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from collections import defaultdict
23
from typing import DefaultDict, Dict, List, Tuple
34

@@ -55,3 +56,63 @@ def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
5556
for i, labels in enumerate(zip(*y_flat)):
5657
buckets[labels].append(i)
5758
return {k: np.array(v) for k, v in buckets.items()}
59+
60+
61+
def get_label_instances(
62+
bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
63+
) -> np.ndarray:
64+
"""Return instances in x with the specified combination of labels.
65+
66+
Parameters
67+
----------
68+
bucket
69+
A tuple of label values corresponding to which instances from x are returned
70+
x
71+
NumPy array of data instances to be returned
72+
*y
73+
A list of np.ndarray of (int) labels
74+
75+
Returns
76+
-------
77+
np.ndarray
78+
NumPy array of instances from x with the specified combination of labels
79+
80+
Example
81+
-------
82+
A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
83+
where ``x`` is a NumPy array of data instances that the labels correspond to,
84+
``Y_gold`` is a list of gold (i.e. ground truth) labels, and
85+
``Y_pred`` is a corresponding list of predicted labels.
86+
87+
>>> import pandas as pd
88+
>>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
89+
>>> Y_gold = np.array([1, 1, 1])
90+
>>> Y_pred = np.array([1, 0, 0])
91+
>>> bucket = (1, 0)
92+
93+
The returned NumPy array of data instances from ``x`` will correspond to
94+
the rows where the first list had a 1 and the second list had a 0.
95+
>>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
96+
array([['a second string', '2'],
97+
['a third string', '3']], dtype=object)
98+
99+
More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
100+
the returned data instances from ``x`` will correspond to the rows where
101+
y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
102+
must all be the same length.
103+
"""
104+
if len(y) != len(bucket):
105+
raise ValueError("Number of lists must match the amount of labels in bucket")
106+
if x.shape[0] != len(y[0]):
107+
# Note: the check for all y having the same number of elements occurs in get_label_buckets
108+
raise ValueError(
109+
"Number of rows in x does not match number of elements in at least one label list"
110+
)
111+
buckets = get_label_buckets(*y)
112+
try:
113+
indices = buckets[bucket]
114+
except KeyError:
115+
logging.warning("Bucket" + str(bucket) + " does not exist.")
116+
return np.array([])
117+
instances = x[indices]
118+
return instances

test/analysis/test_error_analysis.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from snorkel.analysis import get_label_buckets
5+
from snorkel.analysis import get_label_buckets, get_label_instances
66

77

88
class ErrorAnalysisTest(unittest.TestCase):
@@ -37,6 +37,42 @@ def test_get_label_buckets_bad_shape(self) -> None:
3737
with self.assertRaisesRegex(ValueError, "same number of elements"):
3838
get_label_buckets(np.array([0, 1, 1]), np.array([1, 1]))
3939

40+
def test_get_label_instances(self) -> None:
41+
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
42+
y1 = np.array([1, 0, 0, 0])
43+
y2 = np.array([1, 1, 1, 0])
44+
instances = get_label_instances((0, 1), x, y1, y2)
45+
expected_instances = np.array([[3, 4], [5, 6]])
46+
np.testing.assert_equal(instances, expected_instances)
47+
48+
x = np.array(["this", "is", "a", "test", "of", "multi"])
49+
y1 = np.array([[2], [1], [3], [1], [1], [3]])
50+
y2 = np.array([1, 2, 3, 1, 2, 3])
51+
y3 = np.array([[3], [2], [1], [1], [2], [3]])
52+
instances = get_label_instances((3, 3, 3), x, y1, y2, y3)
53+
expected_instances = np.array(["multi"])
54+
np.testing.assert_equal(instances, expected_instances)
55+
56+
def test_get_label_instances_exceptions(self) -> None:
57+
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
58+
y1 = np.array([1, 0, 0, 0])
59+
y2 = np.array([1, 1, 1, 0])
60+
instances = get_label_instances((2, 0), x, y1, y2)
61+
expected_instances = np.array([])
62+
np.testing.assert_equal(instances, expected_instances)
63+
64+
with self.assertRaisesRegex(
65+
ValueError, "Number of lists must match the amount of labels in bucket"
66+
):
67+
get_label_instances((1, 0), x, y1)
68+
69+
x = np.array([[1, 2], [3, 4], [5, 6]])
70+
with self.assertRaisesRegex(
71+
ValueError,
72+
"Number of rows in x does not match number of elements in at least one label list",
73+
):
74+
get_label_instances((1, 0), x, y1, y2)
75+
4076

4177
if __name__ == "__main__":
4278
unittest.main()

0 commit comments

Comments
 (0)