Skip to content

Commit e40361a

Browse files
committed
Add fault tolerance to appliers
1 parent adf94c6 commit e40361a

File tree

7 files changed

+161
-23
lines changed

7 files changed

+161
-23
lines changed

snorkel/labeling/apply/core.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from itertools import chain
2-
from typing import List, Tuple, Union
2+
from typing import DefaultDict, Dict, List, NamedTuple, Tuple, Union
33

44
import numpy as np
55
from tqdm import tqdm
@@ -11,6 +11,27 @@
1111
RowData = List[Tuple[int, int, int]]
1212

1313

14+
class ApplierMetadata(NamedTuple):
    """Metadata about Applier call.

    Returned next to the label matrix by applier ``apply`` methods that are
    called with ``return_meta=True``.

    Attributes
    ----------
    faults
        Number of failed executions recorded for each LF, keyed by LF name
    """

    # Only LFs that actually failed have an entry; it is empty when nothing failed.
    faults: Dict[str, int]
18+
19+
20+
class _FunctionCaller:
    """Call labeling functions, optionally absorbing and counting failures.

    Parameters
    ----------
    fault_tolerant
        If ``True``, an exception raised by an LF is caught, counted under
        the LF's name, and reported as the label ``-1``. If ``False``, the
        exception propagates to the caller unchanged.

    Attributes
    ----------
    fault_tolerant
        The flag passed at construction
    fault_counts
        Number of caught failures so far, keyed by ``f.name``
    """

    def __init__(self, fault_tolerant: bool):
        # Build the counter with the concrete ``collections`` class; the
        # ``typing.DefaultDict`` alias is meant for annotations only and is
        # deprecated as of Python 3.9.
        from collections import defaultdict

        self.fault_tolerant = fault_tolerant
        self.fault_counts: DefaultDict[str, int] = defaultdict(int)

    def __call__(self, f: "LabelingFunction", x: "DataPoint") -> int:
        """Apply ``f`` to ``x`` and return the resulting label.

        Parameters
        ----------
        f
            Labeling function to execute; its ``name`` keys the fault counter
        x
            Data point to label

        Returns
        -------
        int
            Value returned by ``f(x)``, or ``-1`` if ``f`` raised while
            fault tolerance is enabled

        Raises
        ------
        Exception
            Whatever ``f`` raises, when fault tolerance is disabled
        """
        if not self.fault_tolerant:
            return f(x)
        try:
            return f(x)
        except Exception:
            # Deliberately broad: any LF error counts as a fault. Exceptions
            # deriving only from ``BaseException`` (e.g. ``KeyboardInterrupt``)
            # are not caught and still propagate.
            self.fault_counts[f.name] += 1
            return -1
33+
34+
1435
class BaseLFApplier:
1536
"""Base class for LF applier objects.
1637
@@ -60,7 +81,7 @@ def __repr__(self) -> str:
6081

6182

6283
def apply_lfs_to_data_point(
63-
x: DataPoint, index: int, lfs: List[LabelingFunction]
84+
x: DataPoint, index: int, lfs: List[LabelingFunction], f_caller: _FunctionCaller
6485
) -> RowData:
6586
"""Label a single data point with a set of LFs.
6687
@@ -72,6 +93,8 @@ def apply_lfs_to_data_point(
7293
Index of the data point
7394
lfs
7495
Set of LFs to label ``x`` with
96+
f_caller
97+
A ``_FunctionCaller`` to record failed LF executions
7598
7699
Returns
77100
-------
@@ -80,7 +103,7 @@ def apply_lfs_to_data_point(
80103
"""
81104
labels = []
82105
for j, lf in enumerate(lfs):
83-
y = lf(x)
106+
y = f_caller(lf, x)
84107
if y >= 0:
85108
labels.append((index, j, y))
86109
return labels
@@ -114,8 +137,12 @@ class LFApplier(BaseLFApplier):
114137
"""
115138

116139
def apply(
117-
self, data_points: Union[DataPoints, np.ndarray], progress_bar: bool = True
118-
) -> np.ndarray:
140+
self,
141+
data_points: Union[DataPoints, np.ndarray],
142+
progress_bar: bool = True,
143+
fault_tolerant: bool = False,
144+
return_meta: bool = False,
145+
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
119146
"""Label list of data points or a NumPy array with LFs.
120147
121148
Parameters
@@ -124,13 +151,23 @@ def apply(
124151
List of data points or NumPy array to be labeled by LFs
125152
progress_bar
126153
Display a progress bar?
154+
fault_tolerant
155+
Output ``-1`` if LF execution fails?
156+
return_meta
157+
Return metadata from apply call?
127158
128159
Returns
129160
-------
130161
np.ndarray
131162
Matrix of labels emitted by LFs
163+
ApplierMetadata
164+
Metadata, such as fault counts, for the apply call
132165
"""
133166
labels = []
167+
f_caller = _FunctionCaller(fault_tolerant)
134168
for i, x in tqdm(enumerate(data_points), disable=(not progress_bar)):
135-
labels.append(apply_lfs_to_data_point(x, i, self._lfs))
136-
return self._numpy_from_row_data(labels)
169+
labels.append(apply_lfs_to_data_point(x, i, self._lfs, f_caller))
170+
L = self._numpy_from_row_data(labels)
171+
if return_meta:
172+
return L, ApplierMetadata(f_caller.fault_counts)
173+
return L

snorkel/labeling/apply/dask.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dask import dataframe as dd
77
from dask.distributed import Client
88

9-
from .core import BaseLFApplier
9+
from .core import BaseLFApplier, _FunctionCaller
1010
from .pandas import apply_lfs_to_data_point, rows_to_triplets
1111

1212
Scheduler = Union[str, Client]
@@ -20,7 +20,9 @@ class DaskLFApplier(BaseLFApplier):
2020
For more information, see https://docs.dask.org/en/stable/dataframe.html
2121
"""
2222

23-
def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
23+
def apply(
24+
self, df: dd, scheduler: Scheduler = "processes", fault_tolerant: bool = False
25+
) -> np.ndarray:
2426
"""Label Dask DataFrame of data points with LFs.
2527
2628
Parameters
@@ -31,13 +33,16 @@ def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
3133
A Dask scheduling configuration: either a string option or
3234
a ``Client``. For more information, see
3335
https://docs.dask.org/en/stable/scheduling.html#
36+
fault_tolerant
37+
Output ``-1`` if LF execution fails?
3438
3539
Returns
3640
-------
3741
np.ndarray
3842
Matrix of labels emitted by LFs
3943
"""
40-
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
44+
f_caller = _FunctionCaller(fault_tolerant)
45+
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
4146
map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
4247
labels = map_fn.compute(scheduler=scheduler)
4348
labels_with_index = rows_to_triplets(labels)
@@ -52,7 +57,11 @@ class PandasParallelLFApplier(DaskLFApplier):
5257
"""
5358

5459
def apply( # type: ignore
55-
self, df: pd.DataFrame, n_parallel: int = 2, scheduler: Scheduler = "processes"
60+
self,
61+
df: pd.DataFrame,
62+
n_parallel: int = 2,
63+
scheduler: Scheduler = "processes",
64+
fault_tolerant: bool = False,
5665
) -> np.ndarray:
5766
"""Label Pandas DataFrame of data points with LFs in parallel using Dask.
5867
@@ -69,6 +78,8 @@ def apply( # type: ignore
6978
A Dask scheduling configuration: either a string option or
7079
a ``Client``. For more information, see
7180
https://docs.dask.org/en/stable/scheduling.html#
81+
fault_tolerant
82+
Output ``-1`` if LF execution fails?
7283
7384
Returns
7485
-------
@@ -81,4 +92,4 @@ def apply( # type: ignore
8192
"For single process Pandas, use PandasLFApplier."
8293
)
8394
df = dd.from_pandas(df, npartitions=n_parallel)
84-
return super().apply(df, scheduler=scheduler)
95+
return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)

snorkel/labeling/apply/pandas.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import List, Tuple
2+
from typing import List, Tuple, Union
33

44
import numpy as np
55
import pandas as pd
@@ -8,12 +8,14 @@
88
from snorkel.labeling.lf import LabelingFunction
99
from snorkel.types import DataPoint
1010

11-
from .core import BaseLFApplier, RowData
11+
from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller
1212

1313
PandasRowData = List[Tuple[int, int]]
1414

1515

16-
def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> PandasRowData:
16+
def apply_lfs_to_data_point(
17+
x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
18+
) -> PandasRowData:
1719
"""Label a single data point with a set of LFs.
1820
1921
Parameters
@@ -22,6 +24,8 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
2224
Data point to label
2325
lfs
2426
Set of LFs to label ``x`` with
27+
f_caller
28+
A ``_FunctionCaller`` to record failed LF executions
2529
2630
Returns
2731
-------
@@ -30,7 +34,7 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
3034
"""
3135
labels = []
3236
for j, lf in enumerate(lfs):
33-
y = lf(x)
37+
y = f_caller(lf, x)
3438
if y >= 0:
3539
labels.append((j, y))
3640
return labels
@@ -68,7 +72,13 @@ class PandasLFApplier(BaseLFApplier):
6872
array([[0], [1]])
6973
"""
7074

71-
def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
75+
def apply(
76+
self,
77+
df: pd.DataFrame,
78+
progress_bar: bool = True,
79+
fault_tolerant: bool = False,
80+
return_meta: bool = False,
81+
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
7282
"""Label Pandas DataFrame of data points with LFs.
7383
7484
Parameters
@@ -77,17 +87,27 @@ def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
7787
Pandas DataFrame containing data points to be labeled by LFs
7888
progress_bar
7989
Display a progress bar?
90+
fault_tolerant
91+
Output ``-1`` if LF execution fails?
92+
return_meta
93+
Return metadata from apply call?
8094
8195
Returns
8296
-------
8397
np.ndarray
8498
Matrix of labels emitted by LFs
99+
ApplierMetadata
100+
Metadata, such as fault counts, for the apply call
85101
"""
86-
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
102+
f_caller = _FunctionCaller(fault_tolerant)
103+
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
87104
call_fn = df.apply
88105
if progress_bar:
89106
tqdm.pandas()
90107
call_fn = df.progress_apply
91108
labels = call_fn(apply_fn, axis=1)
92109
labels_with_index = rows_to_triplets(labels)
93-
return self._numpy_from_row_data(labels_with_index)
110+
L = self._numpy_from_row_data(labels_with_index)
111+
if return_meta:
112+
return L, ApplierMetadata(f_caller.fault_counts)
113+
return L

snorkel/labeling/apply/spark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from snorkel.types import DataPoint
77

8-
from .core import BaseLFApplier, RowData, apply_lfs_to_data_point
8+
from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point
99

1010

1111
class SparkLFApplier(BaseLFApplier):
@@ -18,22 +18,25 @@ class SparkLFApplier(BaseLFApplier):
1818
``test/labeling/apply/lf_applier_spark_test_script.py``.
1919
"""
2020

21-
def apply(self, data_points: RDD) -> np.ndarray:
21+
def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
2222
"""Label PySpark RDD of data points with LFs.
2323
2424
Parameters
2525
----------
2626
data_points
2727
PySpark RDD containing data points to be labeled by LFs
28+
fault_tolerant
29+
Output ``-1`` if LF execution fails?
2830
2931
Returns
3032
-------
3133
np.ndarray
3234
Matrix of labels emitted by LFs
3335
"""
36+
f_caller = _FunctionCaller(fault_tolerant)
3437

3538
def map_fn(args: Tuple[DataPoint, int]) -> RowData:
36-
return apply_lfs_to_data_point(*args, lfs=self._lfs)
39+
return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)
3740

3841
labels = data_points.zipWithIndex().map(map_fn).collect()
3942
return self._numpy_from_row_data(labels)

snorkel/slicing/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ def slice_dataframe(
2727
S = PandasSFApplier([slicing_function]).apply(df)
2828

2929
# Index into the SF labels by name
30-
df_idx = np.where(S[slicing_function.name])[0]
30+
df_idx = np.where(S[slicing_function.name])[0] # type: ignore
3131
return df.iloc[df_idx]

test/labeling/apply/test_lf_applier.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dask import dataframe as dd
99

1010
from snorkel.labeling import LFApplier, PandasLFApplier, labeling_function
11+
from snorkel.labeling.apply.core import ApplierMetadata
1112
from snorkel.labeling.apply.dask import DaskLFApplier, PandasParallelLFApplier
1213
from snorkel.preprocess import preprocessor
1314
from snorkel.preprocess.nlp import SpacyPreprocessor
@@ -59,8 +60,14 @@ def g_np(x: DataPoint, db: List[int]) -> int:
5960
return 0 if x[1] in db else -1
6061

6162

63+
@labeling_function()
def f_bad(x: DataPoint) -> int:
    """LF that reads ``mum``, an attribute the test data points do not define."""
    # ``mum`` (not ``num``) is intentional: the fault tests build their data
    # points with a ``num`` field only and expect this LF to raise
    # ``AttributeError`` on each of them.
    if x.mum > 42:
        return 0
    return -1
66+
67+
6268
DATA = [3, 43, 12, 9, 3]
6369
L_EXPECTED = np.array([[-1, 0], [0, -1], [-1, -1], [-1, 0], [-1, 0]])
70+
L_EXPECTED_BAD = np.array([[-1, -1], [0, -1], [-1, -1], [-1, -1], [-1, -1]])
6471
L_PREPROCESS_EXPECTED = np.array([[-1, -1], [0, 0], [-1, 0], [-1, 0], [-1, -1]])
6572

6673
TEXT_DATA = ["Jane", "Jane plays soccer.", "Jane plays soccer."]
@@ -75,6 +82,22 @@ def test_lf_applier(self) -> None:
7582
np.testing.assert_equal(L, L_EXPECTED)
7683
L = applier.apply(data_points, progress_bar=True)
7784
np.testing.assert_equal(L, L_EXPECTED)
85+
L, meta = applier.apply(data_points, return_meta=True)
86+
np.testing.assert_equal(L, L_EXPECTED)
87+
self.assertEqual(meta, ApplierMetadata(dict()))
88+
89+
def test_lf_applier_fault(self) -> None:
    """Check ``LFApplier`` with a failing LF, with and without fault tolerance."""
    points = [SimpleNamespace(num=value) for value in DATA]
    lf_applier = LFApplier([f, f_bad])

    # Default mode: the error raised by ``f_bad`` propagates.
    with self.assertRaises(AttributeError):
        lf_applier.apply(points, progress_bar=False)

    # Fault-tolerant mode: the failing LF's column is all -1.
    label_matrix = lf_applier.apply(points, progress_bar=False, fault_tolerant=True)
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)

    # With metadata: one recorded fault for ``f_bad`` per element of DATA.
    label_matrix, metadata = lf_applier.apply(
        points, progress_bar=False, fault_tolerant=True, return_meta=True
    )
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)
    self.assertEqual(metadata, ApplierMetadata({"f_bad": 5}))
78101

79102
def test_lf_applier_preprocessor(self) -> None:
80103
data_points = [SimpleNamespace(num=num) for num in DATA]
@@ -121,6 +144,22 @@ def test_lf_applier_pandas(self) -> None:
121144
np.testing.assert_equal(L, L_EXPECTED)
122145
L = applier.apply(df, progress_bar=True)
123146
np.testing.assert_equal(L, L_EXPECTED)
147+
L, meta = applier.apply(df, return_meta=True)
148+
np.testing.assert_equal(L, L_EXPECTED)
149+
self.assertEqual(meta, ApplierMetadata(dict()))
150+
151+
def test_lf_applier_pandas_fault(self) -> None:
    """Check ``PandasLFApplier`` with a failing LF, with and without fault tolerance."""
    frame = pd.DataFrame({"num": DATA})
    pandas_applier = PandasLFApplier([f, f_bad])

    # Default mode: the error raised by ``f_bad`` propagates.
    with self.assertRaises(AttributeError):
        pandas_applier.apply(frame, progress_bar=False)

    # Fault-tolerant mode: the failing LF's column is all -1.
    label_matrix = pandas_applier.apply(
        frame, progress_bar=False, fault_tolerant=True
    )
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)

    # With metadata: one recorded fault for ``f_bad`` per element of DATA.
    label_matrix, metadata = pandas_applier.apply(
        frame, progress_bar=False, fault_tolerant=True, return_meta=True
    )
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)
    self.assertEqual(metadata, ApplierMetadata({"f_bad": 5}))
124163

125164
def test_lf_applier_pandas_preprocessor(self) -> None:
126165
df = pd.DataFrame(dict(num=DATA))
@@ -189,6 +228,15 @@ def test_lf_applier_dask(self) -> None:
189228
L = applier.apply(df)
190229
np.testing.assert_equal(L, L_EXPECTED)
191230

231+
def test_lf_applier_dask_fault(self) -> None:
    """Check ``DaskLFApplier`` with a failing LF, with and without fault tolerance."""
    frame = dd.from_pandas(pd.DataFrame({"num": DATA}), npartitions=2)
    dask_applier = DaskLFApplier([f, f_bad])

    # Default mode: applying the failing LF raises.
    with self.assertRaises(Exception):
        dask_applier.apply(frame)

    # Fault-tolerant mode: the failing LF's column is all -1.
    label_matrix = dask_applier.apply(frame, fault_tolerant=True)
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)
239+
192240
def test_lf_applier_dask_preprocessor(self) -> None:
193241
df = pd.DataFrame(dict(num=DATA))
194242
df = dd.from_pandas(df, npartitions=2)

0 commit comments

Comments
 (0)