Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions snorkel/labeling/apply/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import List, Tuple, Union
from typing import DefaultDict, Dict, List, NamedTuple, Tuple, Union

import numpy as np
from tqdm import tqdm
Expand All @@ -11,6 +11,28 @@
RowData = List[Tuple[int, int, int]]


class ApplierMetadata(NamedTuple):
    """Metadata about Applier call.

    Attributes
    ----------
    faults
        Map from LF name to number of faults in apply call
    """

    # Map from LF name to number of faults in apply call
    faults: Dict[str, int]


class _FunctionCaller:
    """Call labeling functions, optionally absorbing and counting failures.

    Parameters
    ----------
    fault_tolerant
        If ``True``, an exception raised by an LF is caught, counted against
        that LF's name, and replaced by the label ``-1``. If ``False``,
        the exception propagates to the caller.

    Attributes
    ----------
    fault_tolerant
        See above
    fault_counts
        Map from LF name to number of calls that raised an exception
    """

    def __init__(self, fault_tolerant: bool):
        # ``typing.DefaultDict`` is an annotation alias; build the container
        # from the concrete ``collections.defaultdict`` class instead.
        from collections import defaultdict

        self.fault_tolerant = fault_tolerant
        self.fault_counts: DefaultDict[str, int] = defaultdict(int)

    def __call__(self, f: "LabelingFunction", x: "DataPoint") -> int:
        """Apply ``f`` to ``x``.

        Parameters
        ----------
        f
            Labeling function to execute; its ``name`` keys the fault count
        x
            Data point to label

        Returns
        -------
        int
            Output of ``f(x)``, or ``-1`` if ``f`` raised while in fault
            tolerant mode
        """
        if not self.fault_tolerant:
            return f(x)
        try:
            return f(x)
        except Exception:
            # Deliberate best-effort: record the failure and emit ``-1``
            # rather than aborting the whole apply call.
            self.fault_counts[f.name] += 1
            return -1


class BaseLFApplier:
"""Base class for LF applier objects.

Expand Down Expand Up @@ -60,7 +82,7 @@ def __repr__(self) -> str:


def apply_lfs_to_data_point(
x: DataPoint, index: int, lfs: List[LabelingFunction]
x: DataPoint, index: int, lfs: List[LabelingFunction], f_caller: _FunctionCaller
) -> RowData:
"""Label a single data point with a set of LFs.

Expand All @@ -72,6 +94,8 @@ def apply_lfs_to_data_point(
Index of the data point
lfs
Set of LFs to label ``x`` with
f_caller
A ``_FunctionCaller`` to record failed LF executions

Returns
-------
Expand All @@ -80,7 +104,7 @@ def apply_lfs_to_data_point(
"""
labels = []
for j, lf in enumerate(lfs):
y = lf(x)
y = f_caller(lf, x)
if y >= 0:
labels.append((index, j, y))
return labels
Expand Down Expand Up @@ -114,8 +138,12 @@ class LFApplier(BaseLFApplier):
"""

def apply(
self, data_points: Union[DataPoints, np.ndarray], progress_bar: bool = True
) -> np.ndarray:
self,
data_points: Union[DataPoints, np.ndarray],
progress_bar: bool = True,
fault_tolerant: bool = False,
return_meta: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
"""Label list of data points or a NumPy array with LFs.

Parameters
Expand All @@ -124,13 +152,23 @@ def apply(
List of data points or NumPy array to be labeled by LFs
progress_bar
Display a progress bar?
fault_tolerant
Output ``-1`` if LF execution fails?
return_meta
Return metadata from apply call?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
ApplierMetadata
Metadata, such as fault counts, for the apply call
"""
labels = []
f_caller = _FunctionCaller(fault_tolerant)
for i, x in tqdm(enumerate(data_points), disable=(not progress_bar)):
labels.append(apply_lfs_to_data_point(x, i, self._lfs))
return self._numpy_from_row_data(labels)
labels.append(apply_lfs_to_data_point(x, i, self._lfs, f_caller))
L = self._numpy_from_row_data(labels)
if return_meta:
return L, ApplierMetadata(f_caller.fault_counts)
return L
24 changes: 19 additions & 5 deletions snorkel/labeling/apply/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask import dataframe as dd
from dask.distributed import Client

from .core import BaseLFApplier
from .core import BaseLFApplier, _FunctionCaller
from .pandas import apply_lfs_to_data_point, rows_to_triplets

Scheduler = Union[str, Client]
Expand All @@ -20,7 +20,12 @@ class DaskLFApplier(BaseLFApplier):
For more information, see https://docs.dask.org/en/stable/dataframe.html
"""

def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
def apply(
self,
df: dd.DataFrame,
scheduler: Scheduler = "processes",
fault_tolerant: bool = False,
) -> np.ndarray:
"""Label Dask DataFrame of data points with LFs.

Parameters
Expand All @@ -31,13 +36,16 @@ def apply(self, df: dd, scheduler: Scheduler = "processes") -> np.ndarray:
A Dask scheduling configuration: either a string option or
a ``Client``. For more information, see
https://docs.dask.org/en/stable/scheduling.html#
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
"""
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
f_caller = _FunctionCaller(fault_tolerant)
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
labels = map_fn.compute(scheduler=scheduler)
labels_with_index = rows_to_triplets(labels)
Expand All @@ -52,7 +60,11 @@ class PandasParallelLFApplier(DaskLFApplier):
"""

def apply( # type: ignore
self, df: pd.DataFrame, n_parallel: int = 2, scheduler: Scheduler = "processes"
self,
df: pd.DataFrame,
n_parallel: int = 2,
scheduler: Scheduler = "processes",
fault_tolerant: bool = False,
) -> np.ndarray:
"""Label Pandas DataFrame of data points with LFs in parallel using Dask.

Expand All @@ -69,6 +81,8 @@ def apply( # type: ignore
A Dask scheduling configuration: either a string option or
a ``Client``. For more information, see
https://docs.dask.org/en/stable/scheduling.html#
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
Expand All @@ -81,4 +95,4 @@ def apply( # type: ignore
"For single process Pandas, use PandasLFApplier."
)
df = dd.from_pandas(df, npartitions=n_parallel)
return super().apply(df, scheduler=scheduler)
return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)
34 changes: 27 additions & 7 deletions snorkel/labeling/apply/pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -8,12 +8,14 @@
from snorkel.labeling.lf import LabelingFunction
from snorkel.types import DataPoint

from .core import BaseLFApplier, RowData
from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller

PandasRowData = List[Tuple[int, int]]


def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> PandasRowData:
def apply_lfs_to_data_point(
x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
) -> PandasRowData:
"""Label a single data point with a set of LFs.

Parameters
Expand All @@ -22,6 +24,8 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
Data point to label
lfs
Set of LFs to label ``x`` with
f_caller
A ``_FunctionCaller`` to record failed LF executions

Returns
-------
Expand All @@ -30,7 +34,7 @@ def apply_lfs_to_data_point(x: DataPoint, lfs: List[LabelingFunction]) -> Pandas
"""
labels = []
for j, lf in enumerate(lfs):
y = lf(x)
y = f_caller(lf, x)
if y >= 0:
labels.append((j, y))
return labels
Expand Down Expand Up @@ -68,7 +72,13 @@ class PandasLFApplier(BaseLFApplier):
array([[0], [1]])
"""

def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
def apply(
self,
df: pd.DataFrame,
progress_bar: bool = True,
fault_tolerant: bool = False,
return_meta: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
"""Label Pandas DataFrame of data points with LFs.

Parameters
Expand All @@ -77,17 +87,27 @@ def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> np.ndarray:
Pandas DataFrame containing data points to be labeled by LFs
progress_bar
Display a progress bar?
fault_tolerant
Output ``-1`` if LF execution fails?
return_meta
Return metadata from apply call?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
ApplierMetadata
Metadata, such as fault counts, for the apply call
"""
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs)
f_caller = _FunctionCaller(fault_tolerant)
apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
call_fn = df.apply
if progress_bar:
tqdm.pandas()
call_fn = df.progress_apply
labels = call_fn(apply_fn, axis=1)
labels_with_index = rows_to_triplets(labels)
return self._numpy_from_row_data(labels_with_index)
L = self._numpy_from_row_data(labels_with_index)
if return_meta:
return L, ApplierMetadata(f_caller.fault_counts)
return L
9 changes: 6 additions & 3 deletions snorkel/labeling/apply/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from snorkel.types import DataPoint

from .core import BaseLFApplier, RowData, apply_lfs_to_data_point
from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point


class SparkLFApplier(BaseLFApplier):
Expand All @@ -18,22 +18,25 @@ class SparkLFApplier(BaseLFApplier):
``test/labeling/apply/lf_applier_spark_test_script.py``.
"""

def apply(self, data_points: RDD) -> np.ndarray:
def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
"""Label PySpark RDD of data points with LFs.

Parameters
----------
data_points
PySpark RDD containing data points to be labeled by LFs
fault_tolerant
Output ``-1`` if LF execution fails?

Returns
-------
np.ndarray
Matrix of labels emitted by LFs
"""
f_caller = _FunctionCaller(fault_tolerant)

def map_fn(args: Tuple[DataPoint, int]) -> RowData:
return apply_lfs_to_data_point(*args, lfs=self._lfs)
return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)

labels = data_points.zipWithIndex().map(map_fn).collect()
return self._numpy_from_row_data(labels)
27 changes: 2 additions & 25 deletions snorkel/labeling/lf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class LabelingFunction:
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run on data points before LF execution
fault_tolerant
Output ``-1`` if LF execution fails?

Raises
------
Expand All @@ -39,8 +37,6 @@ class LabelingFunction:
----------
name
See above
fault_tolerant
See above
"""

def __init__(
Expand All @@ -49,10 +45,8 @@ def __init__(
f: Callable[..., int],
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
) -> None:
self.name = name
self.fault_tolerant = fault_tolerant
self._f = f
self._resources = resources or {}
self._pre = pre or []
Expand All @@ -67,9 +61,7 @@ def _preprocess_data_point(self, x: DataPoint) -> DataPoint:
def __call__(self, x: DataPoint) -> int:
"""Label data point.

Runs all preprocessors, then passes to LF. If an exception
is encountered and the LF is in fault tolerant mode,
the LF abstains from voting.
Runs all preprocessors, then passes preprocessed data point to LF.

Parameters
----------
Expand All @@ -82,11 +74,6 @@ def __call__(self, x: DataPoint) -> int:
Label for data point
"""
x = self._preprocess_data_point(x)
if self.fault_tolerant:
try:
return self._f(x, **self._resources)
except Exception:
return -1
return self._f(x, **self._resources)

def __repr__(self) -> str:
Expand All @@ -105,8 +92,6 @@ class labeling_function:
Labeling resources passed in to ``f`` via ``kwargs``
preprocessors
Preprocessors to run on data points before LF execution
fault_tolerant
Output ``-1`` if LF execution fails?

Examples
--------
Expand All @@ -132,14 +117,12 @@ def __init__(
name: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
) -> None:
if callable(name):
raise ValueError("Looks like this decorator is missing parentheses!")
self.name = name
self.resources = resources
self.pre = pre
self.fault_tolerant = fault_tolerant

def __call__(self, f: Callable[..., int]) -> LabelingFunction:
"""Wrap a function to create a ``LabelingFunction``.
Expand All @@ -155,10 +138,4 @@ def __call__(self, f: Callable[..., int]) -> LabelingFunction:
New ``LabelingFunction`` executing logic in wrapped function
"""
name = self.name or f.__name__
return LabelingFunction(
name=name,
f=f,
resources=self.resources,
pre=self.pre,
fault_tolerant=self.fault_tolerant,
)
return LabelingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
Loading