Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions snorkel/labeling/apply/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import List, Tuple, Union
from typing import DefaultDict, Dict, List, NamedTuple, Tuple, Union

import numpy as np
from tqdm import tqdm
Expand All @@ -11,6 +11,28 @@
RowData = List[Tuple[int, int, int]]


class ApplierMetadata(NamedTuple):
    """Metadata about Applier call.

    Returned alongside the label matrix by ``apply`` when ``return_meta=True``.
    """

    # Map from LF name to number of faults in apply call. Only LFs that
    # raised at least once while ``fault_tolerant=True`` get an entry.
    faults: Dict[str, int]


class _FunctionCaller:
    """Call labeling functions, optionally absorbing and counting failures.

    Parameters
    ----------
    fault_tolerant
        If ``True``, an exception raised by an LF is caught, tallied in
        ``fault_counts`` under the LF's name, and reported as ``-1``.
        If ``False``, the exception propagates to the caller unchanged.

    Attributes
    ----------
    fault_tolerant
        The mode this caller was created with
    fault_counts
        Map from LF name to number of calls that raised in fault-tolerant mode
    """

    def __init__(self, fault_tolerant: bool):
        # ``typing.DefaultDict`` is an annotation helper (and a deprecated
        # alias); build the runtime container from ``collections`` directly.
        from collections import defaultdict

        self.fault_tolerant = fault_tolerant
        self.fault_counts: DefaultDict[str, int] = defaultdict(int)

    def __call__(self, f: LabelingFunction, x: DataPoint) -> int:
        """Apply ``f`` to ``x`` and return its label.

        Parameters
        ----------
        f
            Labeling function to execute
        x
            Data point to label

        Returns
        -------
        int
            Output of ``f(x)``, or ``-1`` if ``f`` raised and this caller
            is fault tolerant
        """
        if not self.fault_tolerant:
            return f(x)
        try:
            return f(x)
        except Exception:
            # Deliberately broad: fault-tolerant mode treats any LF failure
            # as "no label" and records it instead of aborting the run.
            self.fault_counts[f.name] += 1
            return -1


class BaseLFApplier:
"""Base class for LF applier objects.

Expand Down Expand Up @@ -60,7 +82,7 @@ def __repr__(self) -> str:


def apply_lfs_to_data_point(
x: DataPoint, index: int, lfs: List[LabelingFunction]
x: DataPoint, index: int, lfs: List[LabelingFunction], f_caller: _FunctionCaller
) -> RowData:
"""Label a single data point with a set of LFs.

Expand All @@ -72,6 +94,8 @@ def apply_lfs_to_data_point(
Index of the data point
lfs
Set of LFs to label ``x`` with
f_caller
A ``_FunctionCaller`` to record failed LF executions

Returns
-------
Expand All @@ -80,7 +104,7 @@ def apply_lfs_to_data_point(
"""
labels = []
for j, lf in enumerate(lfs):
y = lf(x)
y = f_caller(lf, x)
if y >= 0:
labels.append((index, j, y))
return labels
Expand Down Expand Up @@ -114,8 +138,12 @@ class LFApplier(BaseLFApplier):
"""

def apply(
self, data_points: Union[DataPoints, np.ndarray], progress_bar: bool = True
) -> np.ndarray:
self,
data_points: Union[DataPoints, np.ndarray],
progress_bar: bool = True,
fault_tolerant: bool = False,
return_meta: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
"""Label list of data points or a NumPy array with LFs.

Parameters
Expand All @@ -124,13 +152,23 @@ def apply(
List of data points or NumPy array to be labeled by LFs
progress_bar
Display a progress bar?
fault_tolerant
Output ``-1`` if LF execution fails?
return_meta
Return metadata from apply call?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
ApplierMetadata
Metadata, such as fault counts, for the apply call
"""
labels = []
f_caller = _FunctionCaller(fault_tolerant)
for i, x in tqdm(enumerate(data_points), disable=(not progress_bar)):
labels.append(apply_lfs_to_data_point(x, i, self._lfs))
return self._numpy_from_row_data(labels)
labels.append(apply_lfs_to_data_point(x, i, self._lfs, f_caller))
L = self._numpy_from_row_data(labels)
if return_meta:
return L, ApplierMetadata(f_caller.fault_counts)
return L
24 changes: 19 additions & 5 deletions snorkel/labeling/apply/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask import dataframe as dd
from dask.distributed import Client

from .core import BaseLFApplier
from .core import BaseLFApplier, _FunctionCaller
from .pandas import apply_lfs_to_data_point, rows_to_triplets

Scheduler = Union[str, Client]
Expand All @@ -20,7 +20,12 @@ class DaskLFApplier(BaseLFApplier):
For more information, see https://docs.dask.org/en/stable/dataframe.html
"""

def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
def apply(
self,
df: dd.DataFrame,
scheduler: Scheduler = "processes",
fault_tolerant: bool = False,
) -> np.ndarray:
"""Label Dask DataFrame of data points with LFs.

Parameters
Expand All @@ -31,13 +36,16 @@ def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
A Dask scheduling configuration: either a string option or
a ``Client``. For more information, see
https://docs.dask.org/en/stable/scheduling.html#
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
"""
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
f_caller = _FunctionCaller(fault_tolerant)
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
labels = map_fn.compute(scheduler=scheduler)
labels_with_index = rows_to_triplets(labels)
Expand All @@ -52,7 +60,11 @@ class PandasParallelLFApplier(DaskLFApplier):
"""

def apply( # type: ignore
self, df: pd.DataFrame, n_parallel: int = 2, scheduler: Scheduler = "processes"
self,
df: pd.DataFrame,
n_parallel: int = 2,
scheduler: Scheduler = "processes",
fault_tolerant: bool = False,
) -> np.ndarray:
"""Label Pandas DataFrame of data points with LFs in parallel using Dask.

Expand All @@ -69,6 +81,8 @@ def apply( # type: ignore
A Dask scheduling configuration: either a string option or
a ``Client``. For more information, see
https://docs.dask.org/en/stable/scheduling.html#
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
Expand All @@ -81,4 +95,4 @@ def apply( # type: ignore
"For single process Pandas, use PandasLFApplier."
)
df = dd.from_pandas(df, npartitions=n_parallel)
return super().apply(df, scheduler=scheduler)
return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)
34 changes: 27 additions & 7 deletions snorkel/labeling/apply/pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -8,12 +8,14 @@
from snorkel.labeling.lf import LabelingFunction
from snorkel.types import DataPoint

from .core import BaseLFApplier, RowData
from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller

PandasRowData = List[Tuple[int, int]]


def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> PandasRowData:
def apply_lfs_to_data_point(
x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
) -> PandasRowData:
"""Label a single data point with a set of LFs.

Parameters
Expand All @@ -22,6 +24,8 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
Data point to label
lfs
Set of LFs to label ``x`` with
f_caller
A ``_FunctionCaller`` to record failed LF executions

Returns
-------
Expand All @@ -30,7 +34,7 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
"""
labels = []
for j, lf in enumerate(lfs):
y = lf(x)
y = f_caller(lf, x)
if y >= 0:
labels.append((j, y))
return labels
Expand Down Expand Up @@ -68,7 +72,13 @@ class PandasLFApplier(BaseLFApplier):
array([[0], [1]])
"""

def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
def apply(
self,
df: pd.DataFrame,
progress_bar: bool = True,
fault_tolerant: bool = False,
return_meta: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
"""Label Pandas DataFrame of data points with LFs.

Parameters
Expand All @@ -77,17 +87,27 @@ def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
Pandas DataFrame containing data points to be labeled by LFs
progress_bar
Display a progress bar?
fault_tolerant
Output ``-1`` if LF execution fails?
return_meta
Return metadata from apply call?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
ApplierMetadata
Metadata, such as fault counts, for the apply call
"""
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
f_caller = _FunctionCaller(fault_tolerant)
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
call_fn = df.apply
if progress_bar:
tqdm.pandas()
call_fn = df.progress_apply
labels = call_fn(apply_fn, axis=1)
labels_with_index = rows_to_triplets(labels)
return self._numpy_from_row_data(labels_with_index)
L = self._numpy_from_row_data(labels_with_index)
if return_meta:
return L, ApplierMetadata(f_caller.fault_counts)
return L
9 changes: 6 additions & 3 deletions snorkel/labeling/apply/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from snorkel.types import DataPoint

from .core import BaseLFApplier, RowData, apply_lfs_to_data_point
from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point


class SparkLFApplier(BaseLFApplier):
Expand All @@ -18,22 +18,25 @@ class SparkLFApplier(BaseLFApplier):
``test/labeling/apply/lf_applier_spark_test_script.py``.
"""

def apply(self, data_points: RDD) -> np.ndarray:
def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
"""Label PySpark RDD of data points with LFs.

Parameters
----------
data_points
PySpark RDD containing data points to be labeled by LFs
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
"""
f_caller = _FunctionCaller(fault_tolerant)

def map_fn(args: Tuple[DataPoint, int]) -> RowData:
return apply_lfs_to_data_point(*args, lfs=self._lfs)
return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)

labels = data_points.zipWithIndex().map(map_fn).collect()
return self._numpy_from_row_data(labels)
2 changes: 1 addition & 1 deletion snorkel/slicing/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ def slice_dataframe(
S = PandasSFApplier([slicing_function]).apply(df)

# Index into the SF labels by name
df_idx = np.where(S[slicing_function.name])[0]
df_idx = np.where(S[slicing_function.name])[0] # type: ignore
return df.iloc[df_idx]
48 changes: 48 additions & 0 deletions test/labeling/apply/test_lf_applier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dask import dataframe as dd

from snorkel.labeling import LFApplier, PandasLFApplier, labeling_function
from snorkel.labeling.apply.core import ApplierMetadata
from snorkel.labeling.apply.dask import DaskLFApplier, PandasParallelLFApplier
from snorkel.preprocess import preprocessor
from snorkel.preprocess.nlp import SpacyPreprocessor
Expand Down Expand Up @@ -59,8 +60,14 @@ def g_np(x: DataPoint, db: List[int]) -> int:
return 0 if x[1] in db else -1


@labeling_function()
def f_bad(x: DataPoint) -> int:
    """LF fixture that faults on the test data points.

    It reads ``x.mum`` while the fixtures only provide ``num``; the fault
    tests rely on the resulting ``AttributeError``.
    """
    if x.mum > 42:
        return 0
    return -1


# Raw values; the tests expose each one as the ``num`` field of a data point.
DATA = [3, 43, 12, 9, 3]
# Expected label matrix compared against in the non-faulting applier tests.
L_EXPECTED = np.array([[-1, 0], [0, -1], [-1, -1], [-1, 0], [-1, 0]])
# Expected label matrix for ``[f, f_bad]`` with ``fault_tolerant=True``:
# the ``f_bad`` column is all ``-1`` because every call to it faults.
L_EXPECTED_BAD = np.array([[-1, -1], [0, -1], [-1, -1], [-1, -1], [-1, -1]])
# NOTE(review): presumably the expected output of the preprocessor tests;
# the LFs involved are not visible here -- confirm.
L_PREPROCESS_EXPECTED = np.array([[-1, -1], [0, 0], [-1, 0], [-1, 0], [-1, -1]])

TEXT_DATA = ["Jane", "Jane plays soccer.", "Jane plays soccer."]
Expand All @@ -75,6 +82,22 @@ def test_lf_applier(self) -> None:
np.testing.assert_equal(L, L_EXPECTED)
L = applier.apply(data_points, progress_bar=True)
np.testing.assert_equal(L, L_EXPECTED)
L, meta = applier.apply(data_points, return_meta=True)
np.testing.assert_equal(L, L_EXPECTED)
self.assertEqual(meta, ApplierMetadata(dict()))

def test_lf_applier_fault(self) -> None:
    """Check strict vs. fault-tolerant behavior of ``LFApplier``."""
    points = [SimpleNamespace(num=value) for value in DATA]
    faulty_applier = LFApplier([f, f_bad])

    # Default (strict) mode: the failing LF's error must surface.
    with self.assertRaises(AttributeError):
        faulty_applier.apply(points, progress_bar=False)

    # Fault-tolerant mode: failures are reported as -1 labels.
    tolerant_kwargs = dict(progress_bar=False, fault_tolerant=True)
    label_matrix = faulty_applier.apply(points, **tolerant_kwargs)
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)

    # With metadata: one fault per data point is tallied for ``f_bad``.
    label_matrix, meta = faulty_applier.apply(
        points, return_meta=True, **tolerant_kwargs
    )
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)
    expected_meta = ApplierMetadata({"f_bad": 5})
    self.assertEqual(meta, expected_meta)

def test_lf_applier_preprocessor(self) -> None:
data_points = [SimpleNamespace(num=num) for num in DATA]
Expand Down Expand Up @@ -121,6 +144,22 @@ def test_lf_applier_pandas(self) -> None:
np.testing.assert_equal(L, L_EXPECTED)
L = applier.apply(df, progress_bar=True)
np.testing.assert_equal(L, L_EXPECTED)
L, meta = applier.apply(df, return_meta=True)
np.testing.assert_equal(L, L_EXPECTED)
self.assertEqual(meta, ApplierMetadata(dict()))

def test_lf_applier_pandas_fault(self) -> None:
    """Check strict vs. fault-tolerant behavior of ``PandasLFApplier``."""
    frame = pd.DataFrame({"num": DATA})
    faulty_applier = PandasLFApplier([f, f_bad])

    # Default (strict) mode: the failing LF's error must surface.
    with self.assertRaises(AttributeError):
        faulty_applier.apply(frame, progress_bar=False)

    # Fault-tolerant mode: failures are reported as -1 labels.
    tolerant_kwargs = dict(progress_bar=False, fault_tolerant=True)
    label_matrix = faulty_applier.apply(frame, **tolerant_kwargs)
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)

    # With metadata: one fault per row is tallied for ``f_bad``.
    label_matrix, meta = faulty_applier.apply(
        frame, return_meta=True, **tolerant_kwargs
    )
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)
    expected_meta = ApplierMetadata({"f_bad": 5})
    self.assertEqual(meta, expected_meta)

def test_lf_applier_pandas_preprocessor(self) -> None:
df = pd.DataFrame(dict(num=DATA))
Expand Down Expand Up @@ -189,6 +228,15 @@ def test_lf_applier_dask(self) -> None:
L = applier.apply(df)
np.testing.assert_equal(L, L_EXPECTED)

def test_lf_applier_dask_fault(self) -> None:
    """Check strict vs. fault-tolerant behavior of ``DaskLFApplier``."""
    pandas_frame = pd.DataFrame({"num": DATA})
    dask_frame = dd.from_pandas(pandas_frame, npartitions=2)
    faulty_applier = DaskLFApplier([f, f_bad])

    # Default (strict) mode: the failing LF makes the apply call raise.
    with self.assertRaises(Exception):
        faulty_applier.apply(dask_frame)

    # Fault-tolerant mode: failures are reported as -1 labels.
    label_matrix = faulty_applier.apply(dask_frame, fault_tolerant=True)
    np.testing.assert_equal(label_matrix, L_EXPECTED_BAD)

def test_lf_applier_dask_preprocessor(self) -> None:
df = pd.DataFrame(dict(num=DATA))
df = dd.from_pandas(df, npartitions=2)
Expand Down
Loading