Skip to content

Commit 5a6df8b

Browse files
committed
Add fault tolerance to appliers
1 parent e592620 commit 5a6df8b

File tree

7 files changed: 165 additions & 23 deletions

snorkel/labeling/apply/core.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from itertools import chain
2-
from typing import List, Tuple, Union
2+
from typing import DefaultDict, Dict, List, NamedTuple, Tuple, Union
33

44
import numpy as np
55
from tqdm import tqdm
@@ -11,6 +11,28 @@
1111
RowData = List[Tuple[int, int, int]]
1212

1313

14+
class ApplierMetadata(NamedTuple):
    """Metadata about Applier call.

    Returned together with the label matrix by ``apply`` methods that are
    called with ``return_meta=True``.
    """

    # Map from LF name to number of faults in apply call. A name is only
    # added when that LF raised while running in fault-tolerant mode, so an
    # apply call without faults yields an empty mapping.
    faults: Dict[str, int]
19+
20+
21+
class _FunctionCaller:
22+
def __init__(self, fault_tolerant: bool):
23+
self.fault_tolerant = fault_tolerant
24+
self.fault_counts: DefaultDict[str, int] = DefaultDict(int)
25+
26+
def __call__(self, f: LabelingFunction, x: DataPoint) -> int:
27+
if not self.fault_tolerant:
28+
return f(x)
29+
try:
30+
return f(x)
31+
except Exception:
32+
self.fault_counts[f.name] += 1
33+
return -1
34+
35+
1436
class BaseLFApplier:
1537
"""Base class for LF applier objects.
1638
@@ -60,7 +82,7 @@ def __repr__(self) -> str:
6082

6183

6284
def apply_lfs_to_data_point(
63-
x: DataPoint, index: int, lfs: List[LabelingFunction]
85+
x: DataPoint, index: int, lfs: List[LabelingFunction], f_caller: _FunctionCaller
6486
) -> RowData:
6587
"""Label a single data point with a set of LFs.
6688
@@ -72,6 +94,8 @@ def apply_lfs_to_data_point(
7294
Index of the data point
7395
lfs
7496
Set of LFs to label ``x`` with
97+
f_caller
98+
A ``_FunctionCaller`` to record failed LF executions
7599
76100
Returns
77101
-------
@@ -80,7 +104,7 @@ def apply_lfs_to_data_point(
80104
"""
81105
labels = []
82106
for j, lf in enumerate(lfs):
83-
y = lf(x)
107+
y = f_caller(lf, x)
84108
if y >= 0:
85109
labels.append((index, j, y))
86110
return labels
@@ -114,8 +138,12 @@ class LFApplier(BaseLFApplier):
114138
"""
115139

116140
def apply(
117-
self, data_points: Union[DataPoints, np.ndarray], progress_bar: bool = True
118-
) -> np.ndarray:
141+
self,
142+
data_points: Union[DataPoints, np.ndarray],
143+
progress_bar: bool = True,
144+
fault_tolerant: bool = False,
145+
return_meta: bool = False,
146+
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
119147
"""Label list of data points or a NumPy array with LFs.
120148
121149
Parameters
@@ -124,13 +152,23 @@ def apply(
124152
List of data points or NumPy array to be labeled by LFs
125153
progress_bar
126154
Display a progress bar?
155+
fault_tolerant
156+
Output ``-1`` if LF execution fails?
157+
return_meta
158+
Return metadata from apply call?
127159
128160
Returns
129161
-------
130162
np.ndarray
131163
Matrix of labels emitted by LFs
164+
ApplierMetadata
165+
Metadata, such as fault counts, for the apply call
132166
"""
133167
labels = []
168+
f_caller = _FunctionCaller(fault_tolerant)
134169
for i, x in tqdm(enumerate(data_points), disable=(not progress_bar)):
135-
labels.append(apply_lfs_to_data_point(x, i, self._lfs))
136-
return self._numpy_from_row_data(labels)
170+
labels.append(apply_lfs_to_data_point(x, i, self._lfs, f_caller))
171+
L = self._numpy_from_row_data(labels)
172+
if return_meta:
173+
return L, ApplierMetadata(f_caller.fault_counts)
174+
return L

snorkel/labeling/apply/dask.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dask import dataframe as dd
77
from dask.distributed import Client
88

9-
from .core import BaseLFApplier
9+
from .core import BaseLFApplier, _FunctionCaller
1010
from .pandas import apply_lfs_to_data_point, rows_to_triplets
1111

1212
Scheduler = Union[str, Client]
@@ -20,7 +20,12 @@ class DaskLFApplier(BaseLFApplier):
2020
For more information, see https://docs.dask.org/en/stable/dataframe.html
2121
"""
2222

23-
def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
23+
def apply(
24+
self,
25+
df: dd.DataFrame,
26+
scheduler: Scheduler = "processes",
27+
fault_tolerant: bool = False,
28+
) -> np.ndarray:
2429
"""Label Dask DataFrame of data points with LFs.
2530
2631
Parameters
@@ -31,13 +36,16 @@ def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
3136
A Dask scheduling configuration: either a string option or
3237
a ``Client``. For more information, see
3338
https://docs.dask.org/en/stable/scheduling.html#
39+
fault_tolerant
40+
Output ``-1`` if LF execution fails?
3441
3542
Returns
3643
-------
3744
np.ndarray
3845
Matrix of labels emitted by LFs
3946
"""
40-
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
47+
f_caller = _FunctionCaller(fault_tolerant)
48+
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
4149
map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
4250
labels = map_fn.compute(scheduler=scheduler)
4351
labels_with_index = rows_to_triplets(labels)
@@ -52,7 +60,11 @@ class PandasParallelLFApplier(DaskLFApplier):
5260
"""
5361

5462
def apply( # type: ignore
55-
self, df: pd.DataFrame, n_parallel: int = 2, scheduler: Scheduler = "processes"
63+
self,
64+
df: pd.DataFrame,
65+
n_parallel: int = 2,
66+
scheduler: Scheduler = "processes",
67+
fault_tolerant: bool = False,
5668
) -> np.ndarray:
5769
"""Label Pandas DataFrame of data points with LFs in parallel using Dask.
5870
@@ -69,6 +81,8 @@ def apply( # type: ignore
6981
A Dask scheduling configuration: either a string option or
7082
a ``Client``. For more information, see
7183
https://docs.dask.org/en/stable/scheduling.html#
84+
fault_tolerant
85+
Output ``-1`` if LF execution fails?
7286
7387
Returns
7488
-------
@@ -81,4 +95,4 @@ def apply( # type: ignore
8195
"For single process Pandas, use PandasLFApplier."
8296
)
8397
df = dd.from_pandas(df, npartitions=n_parallel)
84-
return super().apply(df, scheduler=scheduler)
98+
return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)

snorkel/labeling/apply/pandas.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import List, Tuple
2+
from typing import List, Tuple, Union
33

44
import numpy as np
55
import pandas as pd
@@ -8,12 +8,14 @@
88
from snorkel.labeling.lf import LabelingFunction
99
from snorkel.types import DataPoint
1010

11-
from .core import BaseLFApplier, RowData
11+
from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller
1212

1313
PandasRowData = List[Tuple[int, int]]
1414

1515

16-
def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> PandasRowData:
16+
def apply_lfs_to_data_point(
17+
x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
18+
) -> PandasRowData:
1719
"""Label a single data point with a set of LFs.
1820
1921
Parameters
@@ -22,6 +24,8 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
2224
Data point to label
2325
lfs
2426
Set of LFs to label ``x`` with
27+
f_caller
28+
A ``_FunctionCaller`` to record failed LF executions
2529
2630
Returns
2731
-------
@@ -30,7 +34,7 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
3034
"""
3135
labels = []
3236
for j, lf in enumerate(lfs):
33-
y = lf(x)
37+
y = f_caller(lf, x)
3438
if y >= 0:
3539
labels.append((j, y))
3640
return labels
@@ -68,7 +72,13 @@ class PandasLFApplier(BaseLFApplier):
6872
array([[0], [1]])
6973
"""
7074

71-
def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
75+
def apply(
76+
self,
77+
df: pd.DataFrame,
78+
progress_bar: bool = True,
79+
fault_tolerant: bool = False,
80+
return_meta: bool = False,
81+
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
7282
"""Label Pandas DataFrame of data points with LFs.
7383
7484
Parameters
@@ -77,17 +87,27 @@ def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
7787
Pandas DataFrame containing data points to be labeled by LFs
7888
progress_bar
7989
Display a progress bar?
90+
fault_tolerant
91+
Output ``-1`` if LF execution fails?
92+
return_meta
93+
Return metadata from apply call?
8094
8195
Returns
8296
-------
8397
np.ndarray
8498
Matrix of labels emitted by LFs
99+
ApplierMetadata
100+
Metadata, such as fault counts, for the apply call
85101
"""
86-
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
102+
f_caller = _FunctionCaller(fault_tolerant)
103+
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
87104
call_fn = df.apply
88105
if progress_bar:
89106
tqdm.pandas()
90107
call_fn = df.progress_apply
91108
labels = call_fn(apply_fn, axis=1)
92109
labels_with_index = rows_to_triplets(labels)
93-
return self._numpy_from_row_data(labels_with_index)
110+
L = self._numpy_from_row_data(labels_with_index)
111+
if return_meta:
112+
return L, ApplierMetadata(f_caller.fault_counts)
113+
return L

snorkel/labeling/apply/spark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from snorkel.types import DataPoint
77

8-
from .core import BaseLFApplier, RowData, apply_lfs_to_data_point
8+
from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point
99

1010

1111
class SparkLFApplier(BaseLFApplier):
@@ -18,22 +18,25 @@ class SparkLFApplier(BaseLFApplier):
1818
``test/labeling/apply/lf_applier_spark_test_script.py``.
1919
"""
2020

21-
def apply(self, data_points: RDD) -> np.ndarray:
21+
def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
2222
"""Label PySpark RDD of data points with LFs.
2323
2424
Parameters
2525
----------
2626
data_points
2727
PySpark RDD containing data points to be labeled by LFs
28+
fault_tolerant
29+
Output ``-1`` if LF execution fails?
2830
2931
Returns
3032
-------
3133
np.ndarray
3234
Matrix of labels emitted by LFs
3335
"""
36+
f_caller = _FunctionCaller(fault_tolerant)
3437

3538
def map_fn(args: Tuple[DataPoint, int]) -> RowData:
36-
return apply_lfs_to_data_point(*args, lfs=self._lfs)
39+
return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)
3740

3841
labels = data_points.zipWithIndex().map(map_fn).collect()
3942
return self._numpy_from_row_data(labels)

snorkel/slicing/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ def slice_dataframe(
2727
S = PandasSFApplier([slicing_function]).apply(df)
2828

2929
# Index into the SF labels by name
30-
df_idx = np.where(S[slicing_function.name])[0]
30+
df_idx = np.where(S[slicing_function.name])[0] # type: ignore
3131
return df.iloc[df_idx]

test/labeling/apply/test_lf_applier.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dask import dataframe as dd
99

1010
from snorkel.labeling import LFApplier, PandasLFApplier, labeling_function
11+
from snorkel.labeling.apply.core import ApplierMetadata
1112
from snorkel.labeling.apply.dask import DaskLFApplier, PandasParallelLFApplier
1213
from snorkel.preprocess import preprocessor
1314
from snorkel.preprocess.nlp import SpacyPreprocessor
@@ -59,8 +60,14 @@ def g_np(x: DataPoint, db: List[int]) -> int:
5960
return 0 if x[1] in db else -1
6061

6162

63+
@labeling_function()
def f_bad(x: DataPoint) -> int:
    """Deliberately faulty LF used by the fault-tolerance tests."""
    # ``mum`` is intentional, not a typo for ``num``: the fault tests build
    # their data points with only a ``num`` field and expect this LF to fail.
    return 0 if x.mum > 42 else -1
66+
67+
6268
DATA = [3, 43, 12, 9, 3]
6369
L_EXPECTED = np.array([[-1, 0], [0, -1], [-1, -1], [-1, 0], [-1, 0]])
70+
L_EXPECTED_BAD = np.array([[-1, -1], [0, -1], [-1, -1], [-1, -1], [-1, -1]])
6471
L_PREPROCESS_EXPECTED = np.array([[-1, -1], [0, 0], [-1, 0], [-1, 0], [-1, -1]])
6572

6673
TEXT_DATA = ["Jane", "Jane plays soccer.", "Jane plays soccer."]
@@ -75,6 +82,22 @@ def test_lf_applier(self) -> None:
7582
np.testing.assert_equal(L, L_EXPECTED)
7683
L = applier.apply(data_points, progress_bar=True)
7784
np.testing.assert_equal(L, L_EXPECTED)
85+
L, meta = applier.apply(data_points, return_meta=True)
86+
np.testing.assert_equal(L, L_EXPECTED)
87+
self.assertEqual(meta, ApplierMetadata(dict()))
88+
89+
def test_lf_applier_fault(self) -> None:
    """Check strict and fault-tolerant ``LFApplier.apply`` with a failing LF."""
    data_points = [SimpleNamespace(num=num) for num in DATA]
    applier = LFApplier([f, f_bad])
    # Default mode is not fault tolerant: the LF's exception propagates.
    with self.assertRaises(AttributeError):
        applier.apply(data_points, progress_bar=False)
    # In fault-tolerant mode the failing LF's column is all -1.
    L = applier.apply(data_points, progress_bar=False, fault_tolerant=True)
    np.testing.assert_equal(L, L_EXPECTED_BAD)
    L, meta = applier.apply(
        data_points, progress_bar=False, fault_tolerant=True, return_meta=True
    )
    np.testing.assert_equal(L, L_EXPECTED_BAD)
    # One fault per data point; DATA has five entries.
    self.assertEqual(meta, ApplierMetadata(dict(f_bad=5)))
78101

79102
def test_lf_applier_preprocessor(self) -> None:
80103
data_points = [SimpleNamespace(num=num) for num in DATA]
@@ -121,6 +144,22 @@ def test_lf_applier_pandas(self) -> None:
121144
np.testing.assert_equal(L, L_EXPECTED)
122145
L = applier.apply(df, progress_bar=True)
123146
np.testing.assert_equal(L, L_EXPECTED)
147+
L, meta = applier.apply(df, return_meta=True)
148+
np.testing.assert_equal(L, L_EXPECTED)
149+
self.assertEqual(meta, ApplierMetadata(dict()))
150+
151+
def test_lf_applier_pandas_fault(self) -> None:
    """Check strict and fault-tolerant ``PandasLFApplier.apply`` with a failing LF."""
    df = pd.DataFrame(dict(num=DATA))
    applier = PandasLFApplier([f, f_bad])
    # Default mode is not fault tolerant: the LF's exception propagates.
    with self.assertRaises(AttributeError):
        applier.apply(df, progress_bar=False)
    # In fault-tolerant mode the failing LF's column is all -1.
    L = applier.apply(df, progress_bar=False, fault_tolerant=True)
    np.testing.assert_equal(L, L_EXPECTED_BAD)
    L, meta = applier.apply(
        df, progress_bar=False, fault_tolerant=True, return_meta=True
    )
    np.testing.assert_equal(L, L_EXPECTED_BAD)
    # One fault per row; DATA has five entries.
    self.assertEqual(meta, ApplierMetadata(dict(f_bad=5)))
124163

125164
def test_lf_applier_pandas_preprocessor(self) -> None:
126165
df = pd.DataFrame(dict(num=DATA))
@@ -189,6 +228,15 @@ def test_lf_applier_dask(self) -> None:
189228
L = applier.apply(df)
190229
np.testing.assert_equal(L, L_EXPECTED)
191230

231+
def test_lf_applier_dask_fault(self) -> None:
    """Check strict and fault-tolerant ``DaskLFApplier.apply`` with a failing LF."""
    df = pd.DataFrame(dict(num=DATA))
    df = dd.from_pandas(df, npartitions=2)
    applier = DaskLFApplier([f, f_bad])
    # NOTE(review): this asserts the broad ``Exception`` where the non-Dask
    # fault tests assert ``AttributeError`` -- presumably because of how the
    # error surfaces through the Dask scheduler; confirm.
    with self.assertRaises(Exception):
        applier.apply(df)
    # Only the label matrix is checked here; the Dask ``apply`` call is made
    # without requesting fault-count metadata.
    L = applier.apply(df, fault_tolerant=True)
    np.testing.assert_equal(L, L_EXPECTED_BAD)
239+
192240
def test_lf_applier_dask_preprocessor(self) -> None:
193241
df = pd.DataFrame(dict(num=DATA))
194242
df = dd.from_pandas(df, npartitions=2)

0 commit comments

Comments
 (0)