Skip to content

Commit b80070c

Browse files
committed
Fix initialize and update tests
1 parent 7586325 commit b80070c

File tree

4 files changed

+54
-23
lines changed

4 files changed

+54
-23
lines changed

sense2vec/component.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Tuple, Union, List, Dict
1+
from typing import Tuple, Union, List, Dict, Callable, Iterable, Optional
22
from spacy.language import Language
33
from spacy.tokens import Doc, Token, Span
4+
from spacy.training import Example
45
from spacy.vocab import Vocab
56
from spacy.util import SimpleFrozenDict
67
from pathlib import Path
@@ -215,6 +216,24 @@ def s2v_other_senses(self, obj: Union[Token, Span]) -> List[str]:
215216
key = self.s2v_key(obj)
216217
return obj.doc._._s2v.get_other_senses(key)
217218

219+
def initialize(
220+
self,
221+
get_examples: Callable[[], Iterable[Example]],
222+
*,
223+
nlp: Optional[Language] = None,
224+
data_path: Optional[str] = None
225+
):
226+
"""Initialize the component and load in data. Can be used to add the
227+
component with vectors to a pipeline before training.
228+
229+
get_examples (Callable[[], Iterable[Example]]): Function that
230+
returns a representative sample of gold-standard Example objects.
231+
nlp (Language): The current nlp object the component is part of.
232+
data_path (Optional[str]): Optional path to sense2vec model.
233+
"""
234+
if data_path is not None:
235+
self.from_disk(data_path)
236+
218237
def to_bytes(self) -> bytes:
219238
"""Serialize the component to a bytestring.
220239

sense2vec/sense2vec.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
from typing import Tuple, List, Union, Sequence, Dict, Callable, Any, Iterable
2-
from typing import Optional
1+
from typing import Tuple, List, Union, Sequence, Dict, Callable, Any
32
from pathlib import Path
4-
from spacy.language import Language
53
from spacy.vectors import Vectors
64
from spacy.strings import StringStore
75
from spacy.util import SimpleFrozenDict
@@ -297,24 +295,6 @@ def to_bytes(self, exclude: Sequence[str] = tuple()) -> bytes:
297295
data["cache"] = self.cache
298296
return srsly.msgpack_dumps(data)
299297

300-
def initialize(
301-
self,
302-
get_examples: Callable[[], Iterable],
303-
*,
304-
nlp: Optional[Language] = None,
305-
data_path: Optional[str] = None
306-
):
307-
"""Initialize the component and load in data. Can be used to add the
308-
component with vectors to a pipeline before training.
309-
310-
get_examples (Callable[[], Iterable[Example]]): Function that
311-
returns a representative sample of gold-standard Example objects.
312-
nlp (Language): The current nlp object the component is part of.
313-
data_path (Optional[str]): Optional path to sense2vec model.
314-
"""
315-
if data_path is not None:
316-
self.from_disk(data_path)
317-
318298
def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
319299
"""Load a Sense2Vec object from a bytestring.
320300

sense2vec/tests/test_component.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pytest
22
import numpy
3+
import spacy
34
from spacy.vocab import Vocab
45
from spacy.tokens import Doc, Span
56
from sense2vec import Sense2VecComponent
7+
from pathlib import Path
68

79

810
@pytest.fixture
@@ -103,3 +105,33 @@ def test_component_to_from_bytes(doc):
103105
assert doc[0]._.in_s2v is False
104106
new_doc = new_s2v(doc)
105107
assert new_doc[0]._.in_s2v is True
108+
109+
110+
def test_component_initialize():
111+
data_path = Path(__file__).parent / "data"
112+
# With from_disk
113+
nlp = spacy.blank("en")
114+
s2v = nlp.add_pipe("sense2vec")
115+
if Doc.has_extension("s2v_phrases"):
116+
s2v.first_run = False # don't set up extensions again
117+
s2v.from_disk(data_path)
118+
doc = Doc(nlp.vocab, words=["beekeepers"], pos=["NOUN"])
119+
s2v(doc)
120+
assert doc[0]._.s2v_key == "beekeepers|NOUN"
121+
most_similar = [item for item, score in doc[0]._.s2v_most_similar(2)]
122+
assert most_similar[0] == ("honey bees", "NOUN")
123+
assert most_similar[1] == ("Beekeepers", "NOUN")
124+
125+
# With initialize
126+
nlp = spacy.blank("en")
127+
s2v = nlp.add_pipe("sense2vec")
128+
s2v.first_run = False # don't set up extensions again
129+
init_cfg = {"sense2vec": {"data_path": str(data_path)}}
130+
nlp.config["initialize"]["components"] = init_cfg
131+
nlp.initialize()
132+
doc = Doc(nlp.vocab, words=["beekeepers"], pos=["NOUN"])
133+
s2v(doc)
134+
assert doc[0]._.s2v_key == "beekeepers|NOUN"
135+
most_similar = [item for item, score in doc[0]._.s2v_most_similar(2)]
136+
assert most_similar[0] == ("honey bees", "NOUN")
137+
assert most_similar[1] == ("Beekeepers", "NOUN")

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ install_requires =
3636

3737
[options.entry_points]
3838
spacy_factories =
39-
sense2vec = sense2vec:make_sense2vec
39+
sense2vec = sense2vec:component.make_sense2vec
4040
prodigy_recipes =
4141
sense2vec.teach = sense2vec:prodigy_recipes.teach
4242
sens2vec.to-patterns = sense2vec:prodigy_recipes.to_patterns

0 commit comments

Comments
 (0)