
Commit 01b67ef

KennethEnevoldsen and co-author authored
added transformers_config for passing arguments to the transformer (#268)
* added transformers config
* changed def config to include new transformers config
* fixed quotation marks in config
* removed weird symbol
* added attention in data classes
* fixed keyerror
* ibid
* added pass of config to forward
* ibid
* fix for init
* fixed tensors in forward
* removed default for attention and added to_doc fix for attn
* reformatted to black (accidentally reformatted via autopep8)
* added def to transformerdata
* bugfixes - don't get why this does not use the default argument here though
* removed default trfconfig from trfmodel
* updated dummy transformer
* fixed tests
* added Tok2VecTransformer.v2
* changed typing
* fixed type hint
* fixed name of transformer_tok2vec_v2, added def
* fixed default config to match name change
* remove ds
* renamed transformers_config

Co-authored-by: KennethEnevoldsen <[email protected]>
1 parent 21e80a6 commit 01b67ef

File tree

7 files changed: +99 -24 lines

spacy_transformers/architectures.py
spacy_transformers/data_classes.py
spacy_transformers/layers/transformer_model.py
spacy_transformers/pipeline_component.py
spacy_transformers/tests/test_model_wrapper.py
spacy_transformers/tests/util.py
spacy_transformers/util.py

spacy_transformers/architectures.py

Lines changed: 38 additions & 2 deletions
@@ -13,7 +13,7 @@ def transformer_listener_tok2vec_v1(
 ) -> Model[List[Doc], List[Floats2d]]:
     """Create a 'TransformerListener' layer, which will connect to a Transformer
     component earlier in the pipeline.
-
+
     The layer takes a list of Doc objects as input, and produces a list of
     2d arrays as output, with each array having one row per token. Most spaCy
     models expect a sublayer with this signature, making it easy to connect them
@@ -46,7 +46,7 @@ def transformer_listener_tok2vec_v1(
 def transformer_tok2vec_v1(
     name: str,
     get_spans,
-    tokenizer_config,
+    tokenizer_config: dict,
     pooling: Model[Ragged, Floats2d],
     grad_factor: float = 1.0,
 ) -> Model[List[Doc], List[Floats2d]]:
@@ -74,6 +74,42 @@ def transformer_tok2vec_v1(
     )


+@registry.architectures.register("spacy-transformers.Tok2VecTransformer.v2")
+def transformer_tok2vec_v2(
+    name: str,
+    get_spans,
+    tokenizer_config: dict,
+    transformer_config: dict,
+    pooling: Model[Ragged, Floats2d],
+    grad_factor: float = 1.0,
+) -> Model[List[Doc], List[Floats2d]]:
+    """Use a transformer as a "Tok2Vec" layer directly. This does not allow
+    multiple components to share the transformer weights, and does not allow
+    the transformer to set annotations into the `Doc` object, but it's a
+    simpler solution if you only need the transformer within one component.
+
+    get_spans (Callable[[List[Doc]], List[List[Span]]]): A function to extract
+        spans from the batch of Doc objects. See the "TransformerModel" layer
+        for details.
+    tokenizer_config (dict): Settings to pass to the transformers tokenizer.
+    transformer_config (dict): Settings to pass to the transformers forward pass
+        of the transformer.
+    pooling (Model[Ragged, Floats2d]): A reduction layer used to calculate
+        the token vectors based on zero or more wordpiece vectors. If in doubt,
+        mean pooling (see `thinc.layers.reduce_mean`) is usually a good choice.
+    grad_factor (float): Reweight gradients from the component before passing
+        them to the transformer. You can set this to 0 to "freeze" the transformer
+        weights with respect to the component, or to make it learn more slowly.
+        Leaving it at 1.0 is usually fine.
+    """
+    return chain(
+        TransformerModel(name, get_spans, tokenizer_config, transformer_config),
+        split_trf_batch(),
+        trfs2arrays(pooling, grad_factor),
+    )
+
+
 registry.architectures.register(
     "spacy-transformers.TransformerModel.v1", func=TransformerModel
 )
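For illustration, here is a hedged sketch (not part of the commit) of assembling the new Tok2VecTransformer.v2 layer directly in Python. The reduce_mean pooling and the whole-doc span getter are my own choices, not prescribed by the diff.

# Sketch only: wiring transformer_tok2vec_v2 by hand, assuming spacy-transformers
# and thinc are installed and the model weights can be downloaded.
from thinc.api import reduce_mean
from spacy_transformers.architectures import transformer_tok2vec_v2


def get_doc_spans(docs):
    # Simplest possible span getter: one span covering each whole Doc.
    return [[doc[:]] for doc in docs]


tok2vec = transformer_tok2vec_v2(
    name="roberta-base",
    get_spans=get_doc_spans,
    tokenizer_config={"use_fast": True},
    transformer_config={"output_attentions": True},  # forwarded to the HF forward pass
    pooling=reduce_mean(),
    grad_factor=1.0,
)
# tok2vec is a thinc Model[List[Doc], List[Floats2d]]; initialize and use it like
# any other tok2vec sublayer, e.g. tok2vec.initialize(X=list_of_docs).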

spacy_transformers/data_classes.py

Lines changed: 12 additions & 2 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Tuple
 import torch
 import numpy
 from transformers.tokenization_utils import BatchEncoding
@@ -155,11 +155,14 @@ class TransformerData:
     wordpieces: WordpieceBatch
     tensors: List[FloatsXd]
     align: Ragged
+    attention: Optional[Tuple[FloatsXd, ...]] = None

     @classmethod
     def empty(cls) -> "TransformerData":
         align = Ragged(numpy.zeros((0,), dtype="i"), numpy.zeros((0,), dtype="i"))
-        return cls(wordpieces=WordpieceBatch.empty(), tensors=[], align=align)
+        return cls(
+            wordpieces=WordpieceBatch.empty(), tensors=[], align=align, attention=None
+        )

     @classmethod
     def zeros(cls, length: int, width: int, *, xp=numpy) -> "TransformerData":
@@ -247,6 +250,7 @@ class FullTransformerBatch:
     wordpieces: WordpieceBatch
     tensors: List[torch.Tensor]
     align: Ragged
+    attention: Optional[Tuple[torch.Tensor]] = None
     cached_doc_data: Optional[List[TransformerData]] = None

     @classmethod
@@ -259,6 +263,7 @@ def empty(cls, nr_docs) -> "FullTransformerBatch":
             wordpieces=WordpieceBatch.empty(),
             tensors=[],
             align=align,
+            attention=None,
             cached_doc_data=doc_data,
         )
@@ -312,11 +317,16 @@ def split_by_doc(self) -> List[TransformerData]:
             doc_tokens = self.wordpieces[start:end]
             doc_align = self.align[start_i:end_i]
             doc_align.data = doc_align.data - prev_tokens
+            if self.attention:
+                attn = [torch2xp(t[start:end]) for t in self.attention]
+            else:
+                attn = None
             outputs.append(
                 TransformerData(
                     wordpieces=doc_tokens,
                     tensors=[torch2xp(t[start:end]) for t in self.tensors],
                     align=doc_align,
+                    attention=attn,
                 )
             )
             prev_tokens += doc_tokens.input_ids.size
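A quick, hedged illustration of the new field (my own example, not from the commit): `attention` defaults to None, so existing pipelines are unaffected; it is only populated when the model is built with `output_attentions` enabled, in which case `split_by_doc` slices one array per transformer layer for each Doc.

from spacy_transformers.data_classes import TransformerData

# The new field defaults to None, so code that never enables attention sees no change.
data = TransformerData.empty()
assert data.attention is None

# In a pipeline built with transformer_config={"output_attentions": True}, the per-doc
# data (e.g. doc._.trf_data) would instead carry one array per layer, each shaped
# roughly (spans, heads, wordpieces, wordpieces).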

spacy_transformers/layers/transformer_model.py

Lines changed: 21 additions & 11 deletions
@@ -15,7 +15,7 @@


 def TransformerModel(
-    name: str, get_spans: Callable, tokenizer_config: dict
+    name: str, get_spans: Callable, tokenizer_config: dict = {}, transformer_config: dict = {}
 ) -> Model[List[Doc], FullTransformerBatch]:
     """
     get_spans (Callable[[List[Doc]], List[Span]]):
@@ -25,6 +25,7 @@ def TransformerModel(
         overlap, and you can also omit sections of the Doc if they are not
         relevant.
     tokenizer_config (dict): Settings to pass to the transformers tokenizer.
+    transformer_config (dict): Settings to pass to the transformers forward pass.
     """

     return Model(
@@ -38,6 +39,7 @@ def TransformerModel(
             "get_spans": get_spans,
             "name": name,
             "tokenizer_config": tokenizer_config,
+            "transformer_config": transformer_config,
             "set_transformer": set_pytorch_transformer,
             "has_transformer": False,
             "flush_cache_chance": 0.0,
@@ -75,7 +77,8 @@ def init(model: Model, X=None, Y=None):
         return
     name = model.attrs["name"]
     tok_cfg = model.attrs["tokenizer_config"]
-    tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg)
+    trf_cfg = model.attrs["transformer_config"]
+    tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg, trf_cfg)
     model.attrs["tokenizer"] = tokenizer
     model.attrs["set_transformer"](model, transformer)
     # Call the model with a batch of inputs to infer the width
@@ -89,26 +92,23 @@ def init(model: Model, X=None, Y=None):
         for doc_spans in nested_spans:
             flat_spans.extend(doc_spans)
         token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            [span.text for span in flat_spans]
+            model.attrs["tokenizer"], [span.text for span in flat_spans]
         )
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
         align = get_alignment(
-            flat_spans,
-            wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
+            flat_spans, wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
         )
         wordpieces, align = truncate_oversize_splits(
             wordpieces, align, tokenizer.model_max_length
         )
     else:
         texts = ["hello world", "foo bar"]
-        token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            texts
-        )
+        token_data = huggingface_tokenize(model.attrs["tokenizer"], texts)
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
     model.layers[0].initialize(X=wordpieces)
     tensors = model.layers[0].predict(wordpieces)
+    if trf_cfg["output_attentions"] is True:
+        tensors = tensors[:-1]  # remove attention
     t_i = find_last_hidden(tensors)
     model.set_dim("nO", tensors[t_i].shape[-1])

@@ -118,6 +118,7 @@ def forward(
 ) -> Tuple[FullTransformerBatch, Callable]:
     tokenizer = model.attrs["tokenizer"]
     get_spans = model.attrs["get_spans"]
+    trf_config = model.attrs["transformer_config"]
     transformer = model.layers[0]

     nested_spans = get_spans(docs)
@@ -142,8 +143,17 @@ def forward(
     tensors, bp_tensors = transformer(wordpieces, is_train)
     if "logger" in model.attrs:
         log_gpu_memory(model.attrs["logger"], "after forward")
+    if ("output_attentions" in trf_config) and (trf_config["output_attentions"] is True):
+        attn = tensors[-1]
+        tensors = tensors[:-1]
+    else:
+        attn = None
     output = FullTransformerBatch(
-        spans=nested_spans, wordpieces=wordpieces, tensors=tensors, align=align
+        spans=nested_spans,
+        wordpieces=wordpieces,
+        tensors=tensors,
+        align=align,
+        attention=attn,
     )
     if "logger" in model.attrs:
         log_gpu_memory(model.attrs["logger"], "return from forward")
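To make the data flow concrete, here is a hedged usage sketch (not part of the diff). It assumes network access to download the weights, and that the get_doc_spans helper used in the tests further down is importable from spacy_transformers.span_getters; the exact import path is an assumption.

# Hedged sketch: the transformer_config dict reaches the HF forward pass, and when
# output_attentions is enabled the attention tensors should end up on the
# FullTransformerBatch returned by the model.
import spacy
from spacy_transformers.layers.transformer_model import TransformerModel
from spacy_transformers.span_getters import get_doc_spans  # assumed import path

model = TransformerModel(
    "distilbert-base-uncased",
    get_doc_spans,
    tokenizer_config={"use_fast": True},
    transformer_config={"output_attentions": True},
)
model.initialize()  # downloads weights and infers the output width

nlp = spacy.blank("en")
trf_batch = model.predict([nlp("hello world")])
# Expected to be non-None here, assuming the HF output places the attention tuple
# last, as the forward pass above assumes.
print(trf_batch.attention is not None)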

spacy_transformers/pipeline_component.py

Lines changed: 11 additions & 4 deletions
@@ -30,6 +30,7 @@
 @architectures = "spacy-transformers.TransformerModel.v1"
 name = "roberta-base"
 tokenizer_config = {"use_fast": true}
+transformer_config = {"output_attentions": false}

 [transformer.model.get_spans]
 @span_getters = "spacy-transformers.strided_spans.v1"
@@ -143,7 +144,9 @@ def add_listener(self, listener: TransformerListener, component_name: str) -> None:
         if self.model.has_dim("nO") and listener.has_dim("nO") is None:
             listener.set_dim("nO", self.model.get_dim("nO"))

-    def remove_listener(self, listener: TransformerListener, component_name: str) -> bool:
+    def remove_listener(
+        self, listener: TransformerListener, component_name: str
+    ) -> bool:
         """Remove a listener for a downstream component. Usually internals."""
         if component_name in self.listener_map:
             if listener in self.listener_map[component_name]:
@@ -167,7 +170,10 @@ def find_listeners(self, component) -> None:
         names = ("*", self.name)
         if isinstance(getattr(component, "model", None), Model):
             for node in component.model.walk():
-                if isinstance(node, TransformerListener) and node.upstream_name in names:
+                if (
+                    isinstance(node, TransformerListener)
+                    and node.upstream_name in names
+                ):
                     self.add_listener(node, component.name)

     def __call__(self, doc: Doc) -> Doc:
@@ -296,7 +302,8 @@ def accumulate_gradient(d_trf_datas: List[TransformerData]):
             nonlocal d_tensors
             for i, d_trf_data in enumerate(d_trf_datas):
                 for d_tensor in d_trf_data.tensors:
-                    losses[self.name] += float((d_tensor ** 2).sum())  # type: ignore
+                    # type: ignore
+                    losses[self.name] += float((d_tensor ** 2).sum())
                 if i >= len(d_tensors):
                     d_tensors.append(d_trf_data.tensors)
                 else:
@@ -389,7 +396,7 @@ def from_disk(
         def load_model(p):
             p = Path(p).absolute()
             tokenizer, transformer = huggingface_from_pretrained(
-                p, self.model.attrs["tokenizer_config"]
+                p, self.model.attrs["tokenizer_config"], self.model.attrs["transformer_config"]
             )
             self.model.attrs["tokenizer"] = tokenizer
             self.model.attrs["set_transformer"](self.model, transformer)
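As a usage note (my sketch, not from the commit), the new key can be overridden when adding the component, assuming spaCy merges this partial config over the component defaults shown above.

import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "transformer",
    config={
        "model": {
            "name": "roberta-base",
            "tokenizer_config": {"use_fast": True},
            # New in this commit: passed through to the HF forward pass.
            "transformer_config": {"output_attentions": True},
        }
    },
)
nlp.initialize()  # downloads weights; processed docs then carry doc._.trf_data,
                  # including the new attention field when it is enabled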

spacy_transformers/tests/test_model_wrapper.py

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@ def name(request):

 @pytest.fixture(scope="session")
 def trf_model(name):
-    model = TransformerModel(name, get_doc_spans, {"use_fast": True})
+    model = TransformerModel(
+        name, get_doc_spans, {"use_fast": True}, {"output_attentions": False}
+    )
     model.initialize()
     return model

spacy_transformers/tests/util.py

Lines changed: 6 additions & 1 deletion
@@ -116,7 +116,11 @@ def _forward(model, tokens, is_train):
             tensors.append(torch.zeros(*shape))
         return tensors, lambda d_tensors: tokens

-    return Model("dummy-transformer", _forward, attrs={"width": width, "depth": depth})
+    return Model(
+        "dummy-transformer",
+        _forward,
+        attrs={"width": width, "depth": depth},
+    )


 def DummyTransformer(
@@ -132,6 +136,7 @@ def DummyTransformer(
             "tokenizer": DummyTokenizer(),
             "grad_factor": 1.0,
             "flush_cache_chance": 0.0,
+            "transformer_config": {},
         },
         dims={"nO": width},
     )

spacy_transformers/util.py

Lines changed: 8 additions & 3 deletions
@@ -1,5 +1,6 @@
 from typing import List, Dict, Union
 from pathlib import Path
+from functools import partial
 import random
 from transformers import AutoModel, AutoTokenizer
 from transformers.tokenization_utils import BatchEncoding
@@ -16,20 +17,24 @@
 # fmt: on


-def huggingface_from_pretrained(source: Union[Path, str], config: Dict):
+def huggingface_from_pretrained(
+    source: Union[Path, str], tok_config: Dict, trf_config: Dict
+):
     """Create a Huggingface transformer model from pretrained weights. Will
     download the model if it is not already downloaded.

     source (Union[str, Path]): The name of the model or a path to it, such as
         'bert-base-cased'.
-    config (dict): Settings to pass to the tokenizer.
+    tok_config (dict): Settings to pass to the tokenizer.
+    trf_config (dict): Settings to pass to the transformer.
     """
     if hasattr(source, "absolute"):
         str_path = str(source.absolute())
     else:
         str_path = source
-    tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
+    tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
     transformer = AutoModel.from_pretrained(str_path)
+    transformer.forward = partial(transformer.forward, **trf_config)
     ops = get_current_ops()
     if isinstance(ops, CupyOps):
         transformer.cuda()
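The interesting move here is binding the transformer settings onto transformer.forward with functools.partial, so every later forward call picks them up without the caller having to thread the config through. A self-contained toy version of that trick (my example, not library code):

from functools import partial

def forward(inputs, output_attentions=False):
    # Stand-in for a HuggingFace model's forward pass.
    return {"inputs": inputs, "attentions": "..." if output_attentions else None}

# Bind the config once, the same way huggingface_from_pretrained now does:
forward = partial(forward, output_attentions=True)

print(forward("hello"))  # {'inputs': 'hello', 'attentions': '...'}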
