@@ -18,6 +18,7 @@ class SpacyPreprocessorParameters(NamedTuple):
18
18
disable : Optional [List [str ]]
19
19
pre : List [BasePreprocessor ]
20
20
memoize : bool
21
+ gpu : bool
21
22
22
23
23
24
class SpacyPreprocessorConfig (NamedTuple ):
@@ -47,6 +48,7 @@ def _create_or_check_preprocessor(
47
48
disable : Optional [List [str ]],
48
49
pre : List [BasePreprocessor ],
49
50
memoize : bool ,
51
+ gpu : bool ,
50
52
) -> None :
51
53
# Create a SpacyPreprocessor if one has not yet been instantiated.
52
54
# Otherwise, check that configuration matches already instantiated one.
@@ -57,6 +59,7 @@ def _create_or_check_preprocessor(
57
59
disable = disable ,
58
60
pre = pre ,
59
61
memoize = memoize ,
62
+ gpu = gpu ,
60
63
)
61
64
if not hasattr (cls , "_nlp_config" ):
62
65
nlp = cls ._create_preprocessor (parameters )
@@ -78,9 +81,10 @@ def __init__(
78
81
language : str = EN_CORE_WEB_SM ,
79
82
disable : Optional [List [str ]] = None ,
80
83
memoize : bool = True ,
84
+ gpu : bool = False ,
81
85
) -> None :
82
86
self ._create_or_check_preprocessor (
83
- text_field , doc_field , language , disable , pre or [], memoize
87
+ text_field , doc_field , language , disable , pre or [], memoize , gpu
84
88
)
85
89
super ().__init__ (name , f , resources = resources , pre = [self ._nlp_config .nlp ])
86
90
@@ -128,6 +132,8 @@ class NLPLabelingFunction(BaseNLPLabelingFunction):
128
132
See https://spacy.io/usage/processing-pipelines#disabling
129
133
memoize
130
134
Memoize preprocessor outputs?
135
+ gpu
136
+ Prefer Spacy GPU processing?
131
137
132
138
Raises
133
139
------
@@ -176,13 +182,15 @@ def __init__(
176
182
language : str = EN_CORE_WEB_SM ,
177
183
disable : Optional [List [str ]] = None ,
178
184
memoize : bool = True ,
185
+ gpu : bool = False ,
179
186
) -> None :
180
187
super ().__init__ (name , resources , pre )
181
188
self .text_field = text_field
182
189
self .doc_field = doc_field
183
190
self .language = language
184
191
self .disable = disable
185
192
self .memoize = memoize
193
+ self .gpu = gpu
186
194
187
195
def __call__ (self , f : Callable [..., int ]) -> BaseNLPLabelingFunction :
188
196
"""Wrap a function to create an ``BaseNLPLabelingFunction``.
@@ -210,6 +218,7 @@ def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
210
218
language = self .language ,
211
219
disable = self .disable ,
212
220
memoize = self .memoize ,
221
+ gpu = self .gpu ,
213
222
)
214
223
215
224
0 commit comments