Skip to content

Commit c469258

Browse files
authored
Add GPU option for spaCy preprocessor (#1523)
1 parent e316d57 commit c469258

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

snorkel/labeling/lf/nlp.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ class SpacyPreprocessorParameters(NamedTuple):
1818
disable: Optional[List[str]]
1919
pre: List[BasePreprocessor]
2020
memoize: bool
21+
gpu: bool
2122

2223

2324
class SpacyPreprocessorConfig(NamedTuple):
@@ -47,6 +48,7 @@ def _create_or_check_preprocessor(
4748
disable: Optional[List[str]],
4849
pre: List[BasePreprocessor],
4950
memoize: bool,
51+
gpu: bool,
5052
) -> None:
5153
# Create a SpacyPreprocessor if one has not yet been instantiated.
5254
# Otherwise, check that configuration matches already instantiated one.
@@ -57,6 +59,7 @@ def _create_or_check_preprocessor(
5759
disable=disable,
5860
pre=pre,
5961
memoize=memoize,
62+
gpu=gpu,
6063
)
6164
if not hasattr(cls, "_nlp_config"):
6265
nlp = cls._create_preprocessor(parameters)
@@ -78,9 +81,10 @@ def __init__(
7881
language: str = EN_CORE_WEB_SM,
7982
disable: Optional[List[str]] = None,
8083
memoize: bool = True,
84+
gpu: bool = False,
8185
) -> None:
8286
self._create_or_check_preprocessor(
83-
text_field, doc_field, language, disable, pre or [], memoize
87+
text_field, doc_field, language, disable, pre or [], memoize, gpu
8488
)
8589
super().__init__(name, f, resources=resources, pre=[self._nlp_config.nlp])
8690

@@ -128,6 +132,8 @@ class NLPLabelingFunction(BaseNLPLabelingFunction):
128132
See https://spacy.io/usage/processing-pipelines#disabling
129133
memoize
130134
Memoize preprocessor outputs?
135+
gpu
136+
Prefer Spacy GPU processing?
131137
132138
Raises
133139
------
@@ -176,13 +182,15 @@ def __init__(
176182
language: str = EN_CORE_WEB_SM,
177183
disable: Optional[List[str]] = None,
178184
memoize: bool = True,
185+
gpu: bool = False,
179186
) -> None:
180187
super().__init__(name, resources, pre)
181188
self.text_field = text_field
182189
self.doc_field = doc_field
183190
self.language = language
184191
self.disable = disable
185192
self.memoize = memoize
193+
self.gpu = gpu
186194

187195
def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
188196
"""Wrap a function to create an ``BaseNLPLabelingFunction``.
@@ -210,6 +218,7 @@ def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
210218
language=self.language,
211219
disable=self.disable,
212220
memoize=self.memoize,
221+
gpu=self.gpu,
213222
)
214223

215224

snorkel/labeling/lf/nlp_spark.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ class SparkNLPLabelingFunction(BaseNLPLabelingFunction):
3636
See https://spacy.io/usage/processing-pipelines#disabling
3737
memoize
3838
Memoize preprocessor outputs?
39+
gpu
40+
Prefer Spacy GPU processing?
3941
4042
Raises
4143
------

snorkel/preprocess/nlp.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@ class SpacyPreprocessor(Preprocessor):
4040
Preprocessors to run before this preprocessor is executed
4141
memoize
4242
Memoize preprocessor outputs?
43+
gpu
44+
Prefer Spacy GPU processing?
4345
"""
4446

4547
def __init__(
@@ -50,6 +52,7 @@ def __init__(
5052
disable: Optional[List[str]] = None,
5153
pre: Optional[List[BasePreprocessor]] = None,
5254
memoize: bool = False,
55+
gpu: bool = False,
5356
) -> None:
5457
name = type(self).__name__
5558
super().__init__(
@@ -59,6 +62,9 @@ def __init__(
5962
pre=pre,
6063
memoize=memoize,
6164
)
65+
self.gpu = gpu
66+
if self.gpu:
67+
spacy.prefer_gpu()
6268
self._nlp = spacy.load(language, disable=disable or [])
6369

6470
def run(self, text: str) -> FieldMap: # type: ignore

0 commit comments

Comments
 (0)