
Commit b8e47c2

Remove fault tolerant mode from LF/SF (#1481)

* Remove fault tolerant mode from LF/SF
* Add fault tolerance to appliers (#1480)

1 parent f9a4382 · commit b8e47c2

File tree: 13 files changed (+175 −108 lines)

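The net effect of this commit: fault tolerance is no longer a property of an individual LabelingFunction and instead becomes an option on the applier's apply call. A minimal before/after sketch (the LF and DataFrame below are hypothetical, and assume a Snorkel build that includes this commit):

import pandas as pd
from snorkel.labeling import PandasLFApplier, labeling_function

@labeling_function()  # previously: @labeling_function(fault_tolerant=True)
def lf_keyword(x):
    # Hypothetical LF: raises AttributeError when x.text is None
    return 1 if "spam" in x.text.lower() else -1

df = pd.DataFrame({"text": ["buy spam now", None]})
applier = PandasLFApplier([lf_keyword])
L = applier.apply(df, fault_tolerant=True)  # failed LF calls now abstain (-1) here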

snorkel/labeling/apply/core.py
Lines changed: 45 additions & 7 deletions

@@ -1,5 +1,5 @@
 from itertools import chain
-from typing import List, Tuple, Union
+from typing import DefaultDict, Dict, List, NamedTuple, Tuple, Union
 
 import numpy as np
 from tqdm import tqdm
@@ -11,6 +11,28 @@
 RowData = List[Tuple[int, int, int]]
 
 
+class ApplierMetadata(NamedTuple):
+    """Metadata about Applier call."""
+
+    # Map from LF name to number of faults in apply call
+    faults: Dict[str, int]
+
+
+class _FunctionCaller:
+    def __init__(self, fault_tolerant: bool):
+        self.fault_tolerant = fault_tolerant
+        self.fault_counts: DefaultDict[str, int] = DefaultDict(int)
+
+    def __call__(self, f: LabelingFunction, x: DataPoint) -> int:
+        if not self.fault_tolerant:
+            return f(x)
+        try:
+            return f(x)
+        except Exception:
+            self.fault_counts[f.name] += 1
+            return -1
+
+
 class BaseLFApplier:
     """Base class for LF applier objects.
 
@@ -60,7 +82,7 @@ def __repr__(self) -> str:
 
 
 def apply_lfs_to_data_point(
-    x: DataPoint, index: int, lfs: List[LabelingFunction]
+    x: DataPoint, index: int, lfs: List[LabelingFunction], f_caller: _FunctionCaller
 ) -> RowData:
     """Label a single data point with a set of LFs.
 
@@ -72,6 +94,8 @@ def apply_lfs_to_data_point(
         Index of the data point
     lfs
         Set of LFs to label ``x`` with
+    f_caller
+        A ``_FunctionCaller`` to record failed LF executions
 
     Returns
     -------
@@ -80,7 +104,7 @@ def apply_lfs_to_data_point(
     """
     labels = []
     for j, lf in enumerate(lfs):
-        y = lf(x)
+        y = f_caller(lf, x)
         if y >= 0:
             labels.append((index, j, y))
     return labels
@@ -114,8 +138,12 @@ class LFApplier(BaseLFApplier):
     """
 
     def apply(
-        self, data_points: Union[DataPoints, np.ndarray], progress_bar: bool = True
-    ) -> np.ndarray:
+        self,
+        data_points: Union[DataPoints, np.ndarray],
+        progress_bar: bool = True,
+        fault_tolerant: bool = False,
+        return_meta: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
         """Label list of data points or a NumPy array with LFs.
 
         Parameters
@@ -124,13 +152,23 @@ def apply(
             List of data points or NumPy array to be labeled by LFs
         progress_bar
            Display a progress bar?
+        fault_tolerant
+            Output ``-1`` if LF execution fails?
+        return_meta
+            Return metadata from apply call?
 
         Returns
         -------
         np.ndarray
            Matrix of labels emitted by LFs
+        ApplierMetadata
+            Metadata, such as fault counts, for the apply call
         """
         labels = []
+        f_caller = _FunctionCaller(fault_tolerant)
         for i, x in tqdm(enumerate(data_points), disable=(not progress_bar)):
-            labels.append(apply_lfs_to_data_point(x, i, self._lfs))
-        return self._numpy_from_row_data(labels)
+            labels.append(apply_lfs_to_data_point(x, i, self._lfs, f_caller))
+        L = self._numpy_from_row_data(labels)
+        if return_meta:
+            return L, ApplierMetadata(f_caller.fault_counts)
+        return L
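A hedged sketch of the new LFApplier surface defined above, using return_meta to inspect ApplierMetadata.faults; the LF and data points are made up for illustration, and the top-level imports assume the usual snorkel.labeling exports:

from types import SimpleNamespace
from snorkel.labeling import LFApplier, labeling_function

@labeling_function()
def lf_short_text(x):
    # Hypothetical LF: len(None) raises TypeError on the second data point
    return 0 if len(x.text) < 20 else -1

data_points = [SimpleNamespace(text="ok"), SimpleNamespace(text=None)]
applier = LFApplier([lf_short_text])

L, meta = applier.apply(data_points, fault_tolerant=True, return_meta=True)
print(dict(meta.faults))  # e.g. {'lf_short_text': 1}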

snorkel/labeling/apply/dask.py
Lines changed: 19 additions & 5 deletions

@@ -6,7 +6,7 @@
 from dask import dataframe as dd
 from dask.distributed import Client
 
-from .core import BaseLFApplier
+from .core import BaseLFApplier, _FunctionCaller
 from .pandas import apply_lfs_to_data_point, rows_to_triplets
 
 Scheduler = Union[str, Client]
@@ -20,7 +20,12 @@ class DaskLFApplier(BaseLFApplier):
     For more information, see https://docs.dask.org/en/stable/dataframe.html
     """
 
-    def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
+    def apply(
+        self,
+        df: dd.DataFrame,
+        scheduler: Scheduler = "processes",
+        fault_tolerant: bool = False,
+    ) -> np.ndarray:
         """Label Dask DataFrame of data points with LFs.
 
         Parameters
@@ -31,13 +36,16 @@ def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
             A Dask scheduling configuration: either a string option or
             a ``Client``. For more information, see
             https://docs.dask.org/en/stable/scheduling.html#
+        fault_tolerant
+            Output ``-1`` if LF execution fails?
 
         Returns
         -------
         np.ndarray
             Matrix of labels emitted by LFs
         """
-        apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
+        f_caller = _FunctionCaller(fault_tolerant)
+        apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
         map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
         labels = map_fn.compute(scheduler=scheduler)
         labels_with_index = rows_to_triplets(labels)
@@ -52,7 +60,11 @@ class PandasParallelLFApplier(DaskLFApplier):
     """
 
     def apply(  # type: ignore
-        self, df: pd.DataFrame, n_parallel: int = 2, scheduler: Scheduler = "processes"
+        self,
+        df: pd.DataFrame,
+        n_parallel: int = 2,
+        scheduler: Scheduler = "processes",
+        fault_tolerant: bool = False,
     ) -> np.ndarray:
         """Label Pandas DataFrame of data points with LFs in parallel using Dask.
 
@@ -69,6 +81,8 @@ def apply(  # type: ignore
             A Dask scheduling configuration: either a string option or
             a ``Client``. For more information, see
             https://docs.dask.org/en/stable/scheduling.html#
+        fault_tolerant
+            Output ``-1`` if LF execution fails?
 
         Returns
         -------
@@ -81,4 +95,4 @@ def apply(  # type: ignore
                 "For single process Pandas, use PandasLFApplier."
             )
         df = dd.from_pandas(df, npartitions=n_parallel)
-        return super().apply(df, scheduler=scheduler)
+        return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)
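A sketch of the parallel Pandas path with the new flag; note that in this commit the Dask-based appliers accept fault_tolerant but still return only the label matrix (no ApplierMetadata). The LF and DataFrame are hypothetical, and the dask extra is assumed to be installed:

import pandas as pd
from snorkel.labeling import labeling_function
from snorkel.labeling.apply.dask import PandasParallelLFApplier

@labeling_function()
def lf_keyword(x):
    # Hypothetical LF that fails on missing text
    return 1 if "spam" in x.text.lower() else -1

df = pd.DataFrame({"text": ["buy spam now", None, "hello"]})
applier = PandasParallelLFApplier([lf_keyword])
L = applier.apply(df, n_parallel=2, fault_tolerant=True)  # failed calls abstain with -1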

snorkel/labeling/apply/pandas.py
Lines changed: 27 additions & 7 deletions

@@ -1,5 +1,5 @@
 from functools import partial
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -8,12 +8,14 @@
 from snorkel.labeling.lf import LabelingFunction
 from snorkel.types import DataPoint
 
-from .core import BaseLFApplier, RowData
+from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller
 
 PandasRowData = List[Tuple[int, int]]
 
 
-def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> PandasRowData:
+def apply_lfs_to_data_point(
+    x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
+) -> PandasRowData:
     """Label a single data point with a set of LFs.
 
     Parameters
@@ -22,6 +24,8 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
         Data point to label
     lfs
         Set of LFs to label ``x`` with
+    f_caller
+        A ``_FunctionCaller`` to record failed LF executions
 
     Returns
     -------
@@ -30,7 +34,7 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
     """
     labels = []
     for j, lf in enumerate(lfs):
-        y = lf(x)
+        y = f_caller(lf, x)
         if y >= 0:
             labels.append((j, y))
     return labels
@@ -68,7 +72,13 @@ class PandasLFApplier(BaseLFApplier):
     array([[0], [1]])
     """
 
-    def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
+    def apply(
+        self,
+        df: pd.DataFrame,
+        progress_bar: bool = True,
+        fault_tolerant: bool = False,
+        return_meta: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
         """Label Pandas DataFrame of data points with LFs.
 
         Parameters
@@ -77,17 +87,27 @@ def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
             Pandas DataFrame containing data points to be labeled by LFs
         progress_bar
             Display a progress bar?
+        fault_tolerant
+            Output ``-1`` if LF execution fails?
+        return_meta
+            Return metadata from apply call?
 
         Returns
        -------
         np.ndarray
             Matrix of labels emitted by LFs
+        ApplierMetadata
+            Metadata, such as fault counts, for the apply call
         """
-        apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
+        f_caller = _FunctionCaller(fault_tolerant)
+        apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
         call_fn = df.apply
         if progress_bar:
             tqdm.pandas()
             call_fn = df.progress_apply
         labels = call_fn(apply_fn, axis=1)
         labels_with_index = rows_to_triplets(labels)
-        return self._numpy_from_row_data(labels_with_index)
+        L = self._numpy_from_row_data(labels_with_index)
+        if return_meta:
+            return L, ApplierMetadata(f_caller.fault_counts)
+        return L
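The single-process Pandas applier gains the same two keywords; a minimal sketch with a hypothetical LF and DataFrame, assuming a build containing this commit:

import pandas as pd
from snorkel.labeling import PandasLFApplier, labeling_function

@labeling_function()
def lf_keyword(x):
    # Hypothetical LF: fails when x.text is None
    return 1 if "spam" in x.text.lower() else -1

df = pd.DataFrame({"text": ["buy spam now", None]})
applier = PandasLFApplier([lf_keyword])
L, meta = applier.apply(df, progress_bar=False, fault_tolerant=True, return_meta=True)
print(dict(meta.faults))  # e.g. {'lf_keyword': 1}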

snorkel/labeling/apply/spark.py
Lines changed: 6 additions & 3 deletions

@@ -5,7 +5,7 @@
 
 from snorkel.types import DataPoint
 
-from .core import BaseLFApplier, RowData, apply_lfs_to_data_point
+from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point
 
 
 class SparkLFApplier(BaseLFApplier):
@@ -18,22 +18,25 @@ class SparkLFApplier(BaseLFApplier):
     ``test/labeling/apply/lf_applier_spark_test_script.py``.
     """
 
-    def apply(self, data_points: RDD) -> np.ndarray:
+    def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
         """Label PySpark RDD of data points with LFs.
 
         Parameters
         ----------
         data_points
             PySpark RDD containing data points to be labeled by LFs
+        fault_tolerant
+            Output ``-1`` if LF execution fails?
 
         Returns
         -------
         np.ndarray
             Matrix of labels emitted by LFs
         """
+        f_caller = _FunctionCaller(fault_tolerant)
 
         def map_fn(args: Tuple[DataPoint, int]) -> RowData:
-            return apply_lfs_to_data_point(*args, lfs=self._lfs)
+            return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)
 
         labels = data_points.zipWithIndex().map(map_fn).collect()
         return self._numpy_from_row_data(labels)
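A sketch of the Spark path. The SparkContext ``sc``, the Row data, and the LF are placeholders; only the applier's fault_tolerant flag is taken from the diff above (no return_meta here):

from pyspark.sql import Row
from snorkel.labeling import labeling_function
from snorkel.labeling.apply.spark import SparkLFApplier

@labeling_function()
def lf_keyword(x):
    # Hypothetical LF that fails on missing text
    return 1 if "spam" in x.text.lower() else -1

# `sc` is assumed to be an existing pyspark.SparkContext (not created here)
rdd = sc.parallelize([Row(text="buy spam now"), Row(text=None)])
applier = SparkLFApplier([lf_keyword])
L = applier.apply(rdd, fault_tolerant=True)  # failed LF calls abstain with -1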

snorkel/labeling/lf/core.py
Lines changed: 2 additions & 25 deletions

@@ -27,8 +27,6 @@ class LabelingFunction:
         Labeling resources passed in to ``f`` via ``kwargs``
     pre
         Preprocessors to run on data points before LF execution
-    fault_tolerant
-        Output ``-1`` if LF execution fails?
 
     Raises
     ------
@@ -39,8 +37,6 @@ class LabelingFunction:
     ----------
     name
         See above
-    fault_tolerant
-        See above
     """
 
     def __init__(
@@ -49,10 +45,8 @@ def __init__(
         f: Callable[..., int],
         resources: Optional[Mapping[str, Any]] = None,
         pre: Optional[List[BasePreprocessor]] = None,
-        fault_tolerant: bool = False,
     ) -> None:
         self.name = name
-        self.fault_tolerant = fault_tolerant
         self._f = f
         self._resources = resources or {}
         self._pre = pre or []
@@ -67,9 +61,7 @@ def _preprocess_data_point(self, x: DataPoint) -> DataPoint:
     def __call__(self, x: DataPoint) -> int:
         """Label data point.
 
-        Runs all preprocessors, then passes to LF. If an exception
-        is encountered and the LF is in fault tolerant mode,
-        the LF abstains from voting.
+        Runs all preprocessors, then passes preprocessed data point to LF.
 
         Parameters
         ----------
@@ -82,11 +74,6 @@ def __call__(self, x: DataPoint) -> int:
             Label for data point
         """
         x = self._preprocess_data_point(x)
-        if self.fault_tolerant:
-            try:
-                return self._f(x, **self._resources)
-            except Exception:
-                return -1
         return self._f(x, **self._resources)
 
     def __repr__(self) -> str:
@@ -105,8 +92,6 @@ class labeling_function:
         Labeling resources passed in to ``f`` via ``kwargs``
     preprocessors
         Preprocessors to run on data points before LF execution
-    fault_tolerant
-        Output ``-1`` if LF execution fails?
 
     Examples
     --------
@@ -132,14 +117,12 @@ def __init__(
         name: Optional[str] = None,
         resources: Optional[Mapping[str, Any]] = None,
         pre: Optional[List[BasePreprocessor]] = None,
-        fault_tolerant: bool = False,
     ) -> None:
         if callable(name):
             raise ValueError("Looks like this decorator is missing parentheses!")
         self.name = name
         self.resources = resources
         self.pre = pre
-        self.fault_tolerant = fault_tolerant
 
     def __call__(self, f: Callable[..., int]) -> LabelingFunction:
         """Wrap a function to create a ``LabelingFunction``.
@@ -155,10 +138,4 @@ def __call__(self, f: Callable[..., int]) -> LabelingFunction:
             New ``LabelingFunction`` executing logic in wrapped function
         """
         name = self.name or f.__name__
-        return LabelingFunction(
-            name=name,
-            f=f,
-            resources=self.resources,
-            pre=self.pre,
-            fault_tolerant=self.fault_tolerant,
-        )
+        return LabelingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
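On the LF side, the removal means an LF call simply raises again; swallowing failures is now opt-in at the applier. A quick sketch of the changed behavior (the LF and data point are made up):

from types import SimpleNamespace
from snorkel.labeling import labeling_function

@labeling_function()  # passing fault_tolerant=... here now raises TypeError
def lf_short_text(x):
    return 0 if len(x.text) < 20 else -1

lf_short_text(SimpleNamespace(text=None))  # raises TypeError; previously abstained (-1) when fault_tolerant=True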
