Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion snorkel/labeling/lf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class labeling_function:
Name of the LF
resources
Labeling resources passed in to ``f`` via ``kwargs``
preprocessors
pre
Preprocessors to run on data points before LF execution

Examples
Expand Down
22 changes: 20 additions & 2 deletions snorkel/labeling/lf/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from snorkel.preprocess import BasePreprocessor
from snorkel.preprocess.nlp import EN_CORE_WEB_SM, SpacyPreprocessor
from snorkel.types import HashingFunction

from .core import LabelingFunction, labeling_function

Expand All @@ -18,6 +19,7 @@ class SpacyPreprocessorParameters(NamedTuple):
disable: Optional[List[str]]
pre: List[BasePreprocessor]
memoize: bool
memoize_key: Optional[HashingFunction]
gpu: bool


Expand Down Expand Up @@ -48,6 +50,7 @@ def _create_or_check_preprocessor(
disable: Optional[List[str]],
pre: List[BasePreprocessor],
memoize: bool,
memoize_key: Optional[HashingFunction],
gpu: bool,
) -> None:
# Create a SpacyPreprocessor if one has not yet been instantiated.
Expand All @@ -59,6 +62,7 @@ def _create_or_check_preprocessor(
disable=disable,
pre=pre,
memoize=memoize,
memoize_key=memoize_key,
gpu=gpu,
)
if not hasattr(cls, "_nlp_config"):
Expand All @@ -81,10 +85,18 @@ def __init__(
language: str = EN_CORE_WEB_SM,
disable: Optional[List[str]] = None,
memoize: bool = True,
memoize_key: Optional[HashingFunction] = None,
gpu: bool = False,
) -> None:
self._create_or_check_preprocessor(
text_field, doc_field, language, disable, pre or [], memoize, gpu
text_field,
doc_field,
language,
disable,
pre or [],
memoize,
memoize_key,
gpu,
)
super().__init__(name, f, resources=resources, pre=[self._nlp_config.nlp])

Expand Down Expand Up @@ -132,6 +144,8 @@ class NLPLabelingFunction(BaseNLPLabelingFunction):
See https://spacy.io/usage/processing-pipelines#disabling
memoize
Memoize preprocessor outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)
gpu
Prefer Spacy GPU processing?

Expand Down Expand Up @@ -182,6 +196,7 @@ def __init__(
language: str = EN_CORE_WEB_SM,
disable: Optional[List[str]] = None,
memoize: bool = True,
memoize_key: Optional[HashingFunction] = None,
gpu: bool = False,
) -> None:
super().__init__(name, resources, pre)
Expand All @@ -190,6 +205,7 @@ def __init__(
self.language = language
self.disable = disable
self.memoize = memoize
self.memoize_key = memoize_key
self.gpu = gpu

def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
Expand Down Expand Up @@ -218,6 +234,7 @@ def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
language=self.language,
disable=self.disable,
memoize=self.memoize,
memoize_key=self.memoize_key,
gpu=self.gpu,
)

Expand Down Expand Up @@ -245,7 +262,8 @@ class nlp_labeling_function(base_nlp_labeling_function):
See https://spacy.io/usage/processing-pipelines#disabling
memoize
Memoize preprocessor outputs?

memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Example
-------
Expand Down
5 changes: 4 additions & 1 deletion snorkel/labeling/lf/nlp_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class SparkNLPLabelingFunction(BaseNLPLabelingFunction):
See https://spacy.io/usage/processing-pipelines#disabling
memoize
Memoize preprocessor outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)
gpu
Prefer Spacy GPU processing?

Expand Down Expand Up @@ -82,7 +84,8 @@ class spark_nlp_labeling_function(base_nlp_labeling_function):
See https://spacy.io/usage/processing-pipelines#disabling
memoize
Memoize preprocessor outputs?

memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Example
-------
Expand Down
41 changes: 34 additions & 7 deletions snorkel/map/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pandas as pd

from snorkel.types import DataPoint, FieldMap
from snorkel.types import DataPoint, FieldMap, HashingFunction

MapFunction = Callable[[DataPoint], Optional[DataPoint]]

Expand Down Expand Up @@ -94,6 +94,8 @@ class BaseMapper:
Mappers to run before this mapper is executed
memoize
Memoize mapper outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Raises
------
Expand All @@ -106,9 +108,18 @@ class BaseMapper:
Memoize mapper outputs?
"""

def __init__(self, name: str, pre: List["BaseMapper"], memoize: bool) -> None:
def __init__(
self,
name: str,
pre: List["BaseMapper"],
memoize: bool,
memoize_key: Optional[HashingFunction] = None,
Copy link
Member

@henryre henryre Jun 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to prefer None as the default instead of using get_hashable as the default? We'd then be able to avoid the Optional everywhere

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see two reasons:

  • Functions are mutable, and the "good practice" is to avoid mutable objects as default parameters. I don't really see a situation where we would mutate this function, though.
  • We would have to import get_hashable in all the subclasses and all the wrappers of BaseMapper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point about mutating functions in modules, hadn't thought of that route. Sounds good!

) -> None:
if memoize_key is None:
memoize_key = get_hashable
self.name = name
self._pre = pre
self._memoize_key = memoize_key
self.memoize = memoize
self.reset_cache()

Expand Down Expand Up @@ -140,7 +151,7 @@ def __call__(self, x: DataPoint) -> Optional[DataPoint]:
"""
if self.memoize:
# NB: don't do ``self._cache.get(...)`` first in case cached value is ``None``
x_hashable = get_hashable(x)
x_hashable = self._memoize_key(x)
if x_hashable in self._cache:
return self._cache[x_hashable]
# NB: using pickle roundtrip as a more robust deepcopy
Expand Down Expand Up @@ -199,6 +210,8 @@ class Mapper(BaseMapper):
Mappers to run before this mapper is executed
memoize
Memoize mapper outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Raises
------
Expand All @@ -222,13 +235,14 @@ def __init__(
mapped_field_names: Optional[Mapping[str, str]] = None,
pre: Optional[List[BaseMapper]] = None,
memoize: bool = False,
memoize_key: Optional[HashingFunction] = None,
) -> None:
if field_names is None:
# Parse field names from ``run(...)`` if not provided
field_names = {k: k for k in get_parameters(self.run)[1:]}
self.field_names = field_names
self.mapped_field_names = mapped_field_names
super().__init__(name, pre or [], memoize)
super().__init__(name, pre or [], memoize, memoize_key)

def run(self, **kwargs: Any) -> Optional[FieldMap]:
"""Run the mapping operation using the input fields.
Expand Down Expand Up @@ -280,14 +294,16 @@ class LambdaMapper(BaseMapper):

Parameters
----------
name:
name
Name of mapper
f
Function executing the mapping operation
pre
Mappers to run before this mapper is executed
memoize
Memoize mapper outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)
"""

def __init__(
Expand All @@ -296,9 +312,10 @@ def __init__(
f: MapFunction,
pre: Optional[List[BaseMapper]] = None,
memoize: bool = False,
memoize_key: Optional[HashingFunction] = None,
) -> None:
self._f = f
super().__init__(name, pre or [], memoize)
super().__init__(name, pre or [], memoize, memoize_key)

def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
return self._f(x)
Expand Down Expand Up @@ -328,6 +345,8 @@ class lambda_mapper:
Mappers to run before this mapper is executed
memoize
Memoize mapper outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Attributes
----------
Expand All @@ -340,12 +359,14 @@ def __init__(
name: Optional[str] = None,
pre: Optional[List[BaseMapper]] = None,
memoize: bool = False,
memoize_key: Optional[HashingFunction] = None,
) -> None:
if callable(name):
raise ValueError("Looks like this decorator is missing parentheses!")
self.name = name
self.pre = pre
self.memoize = memoize
self.memoize_key = memoize_key

def __call__(self, f: MapFunction) -> LambdaMapper:
"""Wrap a function to create a ``LambdaMapper``.
Expand All @@ -361,4 +382,10 @@ def __call__(self, f: MapFunction) -> LambdaMapper:
New ``LambdaMapper`` executing operation in wrapped function
"""
name = self.name or f.__name__
return LambdaMapper(name=name, f=f, pre=self.pre, memoize=self.memoize)
return LambdaMapper(
name=name,
f=f,
pre=self.pre,
memoize=self.memoize,
memoize_key=self.memoize_key,
)
2 changes: 1 addition & 1 deletion snorkel/preprocess/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Preprocessor(Mapper):


class LambdaPreprocessor(LambdaMapper):
"""Convenience class for definining preprocessors from functions.
"""Convenience class for defining preprocessors from functions.

See ``snorkel.map.core.LambdaMapper`` for details.
"""
Expand Down
6 changes: 5 additions & 1 deletion snorkel/preprocess/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import spacy

from snorkel.types import FieldMap
from snorkel.types import FieldMap, HashingFunction

from .core import BasePreprocessor, Preprocessor

Expand Down Expand Up @@ -40,6 +40,8 @@ class SpacyPreprocessor(Preprocessor):
Preprocessors to run before this preprocessor is executed
memoize
Memoize preprocessor outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)
gpu
Prefer Spacy GPU processing?
"""
Expand All @@ -52,6 +54,7 @@ def __init__(
disable: Optional[List[str]] = None,
pre: Optional[List[BasePreprocessor]] = None,
memoize: bool = False,
memoize_key: Optional[HashingFunction] = None,
gpu: bool = False,
) -> None:
name = type(self).__name__
Expand All @@ -61,6 +64,7 @@ def __init__(
mapped_field_names=dict(doc=doc_field),
pre=pre,
memoize=memoize,
memoize_key=memoize_key,
)
self.gpu = gpu
if self.gpu:
Expand Down
2 changes: 2 additions & 0 deletions snorkel/slicing/sf/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class NLPSlicingFunction(BaseNLPLabelingFunction):
See https://spacy.io/usage/processing-pipelines#disabling
memoize
Memoize preprocessor outputs?
memoize_key
Hashing function used to compute memoization cache keys (defaults to ``snorkel.map.core.get_hashable``)

Raises
------
Expand Down
1 change: 1 addition & 0 deletions snorkel/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .classifier import Config # noqa: F401
from .data import DataPoint, DataPoints, Field, FieldMap # noqa: F401
from .hashing import HashingFunction # noqa: F401
4 changes: 4 additions & 0 deletions snorkel/types/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# NB: ``Hashable`` must come from ``collections.abc``; the alias in
# ``collections`` was deprecated and then removed in Python 3.10.
from collections.abc import Hashable
from typing import Any, Callable

# Signature of a memoization key function: takes any object and returns a
# hashable value.
HashingFunction = Callable[[Any], Hashable]
23 changes: 23 additions & 0 deletions test/map/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,29 @@ def square(x: DataPoint) -> DataPoint:
self.assertIsNone(x21_mapped)
self.assertEqual(square_hit_tracker.n_hits, 1)

def test_decorator_mapper_memoized_use_memoize_key(self) -> None:
# Check that a custom ``memoize_key`` decides what counts as a cache hit.
square_hit_tracker = SquareHitTracker()

# Cache entries are keyed on ``uid`` only.
@lambda_mapper(memoize=True, memoize_key=lambda x: x.uid)
def square(x: DataPoint) -> DataPoint:
x.num_squared = square_hit_tracker(x.num)
return x

# NOTE(review): the DataFrame field is named ``unhashable`` - presumably the
# default key function could not hash it; confirm.
x1 = SimpleNamespace(
uid="id1", num=8, not_used=0, unhashable=pd.DataFrame({"value": [5]})
)
# First call: nothing cached yet, so the tracker records one hit.
x1_mapped = square(x1)
assert x1_mapped is not None
self.assertEqual(x1_mapped.num_squared, 64)
self.assertEqual(square_hit_tracker.n_hits, 1)
# Same ``uid`` as x1 but a different ``not_used`` value.
x2 = SimpleNamespace(
uid="id1", num=8, not_used=1, unhashable=pd.DataFrame({"value": [5]})
)
# Second call: the key matches x1, so the hit count is expected to stay at 1.
x2_mapped = square(x2)
assert x2_mapped is not None
self.assertEqual(x2_mapped.num_squared, 64)
self.assertEqual(square_hit_tracker.n_hits, 1)

def test_decorator_mapper_not_memoized(self) -> None:
square_hit_tracker = SquareHitTracker()

Expand Down