
Commit 3080bb4

mht-sharma and lewtun authored
Add onnx support for VisionEncoderDecoder (#19254)
* Add onnx support for VisionEncoderDecoder
* Add onnx support for VisionEncoderDecoder
* Removed unused import
* Rename encoder hidden state
  Co-authored-by: lewtun <[email protected]>
* Update docstrings and removed redundant code
* Added test function for enc-dec models
* Update doc string text
  Co-authored-by: lewtun <[email protected]>
* fixed code style
  Co-authored-by: lewtun <[email protected]>
1 parent 298f6a9 commit 3080bb4


7 files changed: +305 -35 lines changed


docs/source/en/serialization.mdx

Lines changed: 8 additions & 0 deletions
@@ -96,6 +96,7 @@ Ready-made configurations include the following architectures:
 - SqueezeBERT
 - Swin Transformer
 - T5
+- Vision Encoder decoder
 - ViT
 - XLM
 - XLM-RoBERTa
@@ -294,6 +295,13 @@ that can be used for fast autoregressive decoding.
 
 </Tip>
 
+<Tip>
+
+For `VisionEncoderDecoder` type models, the encoder and decoder parts are
+exported separately as two ONNX files named `encoder_model.onnx` and `decoder_model.onnx` respectively.
+
+</Tip>
+
 
 ## Exporting a model for an unsupported architecture
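As context for the new tip (not part of this commit's diff), here is a minimal sketch of how the two exported files could be driven with onnxruntime. The 3×224×224 input size and the start-token id are assumptions about a particular checkpoint; the input/output names follow the ONNX configs added further below.

```python
import numpy as np
import onnxruntime as ort

encoder_session = ort.InferenceSession("onnx/encoder_model.onnx")
decoder_session = ort.InferenceSession("onnx/decoder_model.onnx")

# Encoder graph: pixel_values -> last_hidden_state (axes are declared dynamic in the config).
pixel_values = np.random.rand(1, 3, 224, 224).astype(np.float32)  # assumed ViT-style input size
(encoder_hidden_states,) = encoder_session.run(["last_hidden_state"], {"pixel_values": pixel_values})

# Decoder graph: input_ids + attention_mask + encoder_hidden_states -> logits.
input_ids = np.array([[0]], dtype=np.int64)  # hypothetical decoder start token id
attention_mask = np.ones_like(input_ids)
(logits,) = decoder_session.run(
    ["logits"],
    {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "encoder_hidden_states": encoder_hidden_states,
    },
)
print(logits.shape)  # (batch, sequence, vocab_size)
```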

src/transformers/models/vision_encoder_decoder/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,9 @@
 )
 
 
-_import_structure = {"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"]}
+_import_structure = {
+    "configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
+}
 
 try:
     if not is_torch_available():
@@ -54,7 +56,7 @@
     _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
 
 if TYPE_CHECKING:
-    from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
+    from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig
 
     try:
         if not is_torch_available():
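The widened `_import_structure` entry makes the ONNX config importable next to the model config. A small illustrative sketch; the checkpoint name is only an example:

```python
from transformers.models.vision_encoder_decoder import (
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderOnnxConfig,
)

# Any vision-encoder-decoder configuration works the same way.
config = VisionEncoderDecoderConfig.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
onnx_config = VisionEncoderDecoderOnnxConfig(config)
```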

src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py

Lines changed: 101 additions & 0 deletions
@@ -15,12 +15,19 @@
 # limitations under the License.
 
 import copy
+from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
+
+from packaging import version
 
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
 from ..auto.configuration_auto import AutoConfig
 
 
+if TYPE_CHECKING:
+    from ... import PreTrainedTokenizerBase, TensorType
+
 logger = logging.get_logger(__name__)
 
 
@@ -119,3 +126,97 @@ def to_dict(self):
         output["decoder"] = self.decoder.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})
+
+
+class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict()
+        common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}
+
+        return common_inputs
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: "PreTrainedTokenizerBase",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+        import torch
+
+        common_inputs = OrderedDict()
+
+        dummy_input = super().generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
+
+        batch, encoder_sequence = dummy_input["input_ids"].shape
+        encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
+        common_inputs["input_ids"] = dummy_input.pop("input_ids")
+        common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
+        common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)
+
+        return common_inputs
+
+
+class VisionEncoderDecoderOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> None:
+        pass
+
+    def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
+        r"""
+        Returns ONNX encoder config for `VisionEncoderDecoder` model.
+
+        Args:
+            encoder_config (`PretrainedConfig`):
+                The encoder model's configuration to use when exporting to ONNX.
+
+        Returns:
+            [`VisionEncoderDecoderEncoderOnnxConfig`]: An instance of the ONNX configuration object
+        """
+        return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)
+
+    def get_decoder_config(
+        self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
+    ) -> OnnxConfig:
+        r"""
+        Returns ONNX decoder config for `VisionEncoderDecoder` model.
+
+        Args:
+            encoder_config (`PretrainedConfig`):
+                The encoder model's configuration to use when exporting to ONNX.
+            decoder_config (`PretrainedConfig`):
+                The decoder model's configuration to use when exporting to ONNX.
+            feature (`str`, *optional*):
+                The type of feature to export the model with.
+
+        Returns:
+            [`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
+        """
+        decoder_config.encoder_hidden_size = encoder_config.hidden_size
+        return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
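A sketch of how the composite `VisionEncoderDecoderOnnxConfig` above hands out the per-component configs and which dynamic axes each one declares; the checkpoint name is illustrative only:

```python
from transformers import AutoConfig
from transformers.models.vision_encoder_decoder import VisionEncoderDecoderOnnxConfig

config = AutoConfig.from_pretrained("nlpconnect/vit-gpt2-image-captioning")  # example checkpoint
onnx_config = VisionEncoderDecoderOnnxConfig(config)

encoder_onnx_config = onnx_config.get_encoder_config(config.encoder)
decoder_onnx_config = onnx_config.get_decoder_config(config.encoder, config.decoder, feature="vision2seq-lm")

print(encoder_onnx_config.inputs)   # pixel_values with dynamic batch/num_channels/height/width axes
print(encoder_onnx_config.outputs)  # last_hidden_state
print(decoder_onnx_config.inputs)   # input_ids, attention_mask, encoder_hidden_states
```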

src/transformers/onnx/__main__.py

Lines changed: 95 additions & 30 deletions
@@ -22,6 +22,9 @@
 from .features import FeaturesManager
 
 
+ENCODER_DECODER_MODELS = ["vision-encoder-decoder"]
+
+
 def main():
     parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
     parser.add_argument(
@@ -65,48 +68,110 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)
 
-    # Instantiate the appropriate preprocessor
-    if args.preprocessor == "auto":
-        preprocessor = get_preprocessor(args.model)
-    elif args.preprocessor == "tokenizer":
-        preprocessor = AutoTokenizer.from_pretrained(args.model)
-    elif args.preprocessor == "feature_extractor":
-        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
-    elif args.preprocessor == "processor":
-        preprocessor = AutoProcessor.from_pretrained(args.model)
-    else:
-        raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
-
     # Allocate the model
     model = FeaturesManager.get_model_from_feature(
         args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
     )
+
     model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
     onnx_config = model_onnx_config(model.config)
 
-    # Ensure the requested opset is sufficient
-    if args.opset is None:
-        args.opset = onnx_config.default_onnx_opset
+    if model_kind in ENCODER_DECODER_MODELS:
+        encoder_model = model.get_encoder()
+        decoder_model = model.get_decoder()
 
-    if args.opset < onnx_config.default_onnx_opset:
-        raise ValueError(
-            f"Opset {args.opset} is not sufficient to export {model_kind}. "
-            f"At least {onnx_config.default_onnx_opset} is required."
+        encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
+        decoder_onnx_config = onnx_config.get_decoder_config(
+            encoder_model.config, decoder_model.config, feature=args.feature
         )
 
-    onnx_inputs, onnx_outputs = export(
-        preprocessor,
-        model,
-        onnx_config,
-        args.opset,
-        args.output,
-    )
+        if args.opset is None:
+            args.opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
+
+        if args.opset < min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset):
+            raise ValueError(
+                f"Opset {args.opset} is not sufficient to export {model_kind}. At least "
+                f" {min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)} is required."
+            )
+
+        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+
+        onnx_inputs, onnx_outputs = export(
+            preprocessor,
+            encoder_model,
+            encoder_onnx_config,
+            args.opset,
+            args.output.parent.joinpath("encoder_model.onnx"),
+        )
+
+        validate_model_outputs(
+            encoder_onnx_config,
+            preprocessor,
+            encoder_model,
+            args.output.parent.joinpath("encoder_model.onnx"),
+            onnx_outputs,
+            args.atol if args.atol else encoder_onnx_config.atol_for_validation,
+        )
+
+        preprocessor = AutoTokenizer.from_pretrained(args.model)
+
+        onnx_inputs, onnx_outputs = export(
+            preprocessor,
+            decoder_model,
+            decoder_onnx_config,
+            args.opset,
+            args.output.parent.joinpath("decoder_model.onnx"),
+        )
+
+        validate_model_outputs(
+            decoder_onnx_config,
+            preprocessor,
+            decoder_model,
+            args.output.parent.joinpath("decoder_model.onnx"),
+            onnx_outputs,
+            args.atol if args.atol else decoder_onnx_config.atol_for_validation,
+        )
+        logger.info(
+            f"All good, model saved at: {args.output.parent.joinpath('encoder_model.onnx').as_posix()},"
+            f" {args.output.parent.joinpath('decoder_model.onnx').as_posix()}"
+        )
+
+    else:
+        # Instantiate the appropriate preprocessor
+        if args.preprocessor == "auto":
+            preprocessor = get_preprocessor(args.model)
+        elif args.preprocessor == "tokenizer":
+            preprocessor = AutoTokenizer.from_pretrained(args.model)
+        elif args.preprocessor == "feature_extractor":
+            preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+        elif args.preprocessor == "processor":
+            preprocessor = AutoProcessor.from_pretrained(args.model)
+        else:
+            raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
+
+        # Ensure the requested opset is sufficient
+        if args.opset is None:
+            args.opset = onnx_config.default_onnx_opset
+
+        if args.opset < onnx_config.default_onnx_opset:
+            raise ValueError(
+                f"Opset {args.opset} is not sufficient to export {model_kind}. "
+                f"At least {onnx_config.default_onnx_opset} is required."
+            )
+
+        onnx_inputs, onnx_outputs = export(
+            preprocessor,
+            model,
+            onnx_config,
+            args.opset,
+            args.output,
+        )
 
-    if args.atol is None:
-        args.atol = onnx_config.atol_for_validation
+        if args.atol is None:
+            args.atol = onnx_config.atol_for_validation
 
-    validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
-    logger.info(f"All good, model saved at: {args.output.as_posix()}")
+        validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
+        logger.info(f"All good, model saved at: {args.output.as_posix()}")
 
 
 if __name__ == "__main__":
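The encoder/decoder branch above can also be mirrored programmatically. A sketch under the assumption that the chosen checkpoint (here an example TrOCR model) reports the `vision-encoder-decoder` model type:

```python
from pathlib import Path

from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.onnx import FeaturesManager, export, validate_model_outputs

model_name = "microsoft/trocr-base-handwritten"  # example checkpoint
feature = "vision2seq-lm"
output_dir = Path("onnx")
output_dir.mkdir(exist_ok=True)

# Load the model for the feature and build its composite ONNX config.
model = FeaturesManager.get_model_from_feature(feature, model_name)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = model_onnx_config(model.config)

encoder_model, decoder_model = model.get_encoder(), model.get_decoder()
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature=feature)
opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)

# Encoder: exported and validated with the feature extractor as preprocessor.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
_, onnx_outputs = export(feature_extractor, encoder_model, encoder_onnx_config, opset, output_dir / "encoder_model.onnx")
validate_model_outputs(
    encoder_onnx_config, feature_extractor, encoder_model, output_dir / "encoder_model.onnx",
    onnx_outputs, encoder_onnx_config.atol_for_validation,
)

# Decoder: exported and validated with the tokenizer as preprocessor.
tokenizer = AutoTokenizer.from_pretrained(model_name)
_, onnx_outputs = export(tokenizer, decoder_model, decoder_onnx_config, opset, output_dir / "decoder_model.onnx")
validate_model_outputs(
    decoder_onnx_config, tokenizer, decoder_model, output_dir / "decoder_model.onnx",
    onnx_outputs, decoder_onnx_config.atol_for_validation,
)
```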

src/transformers/onnx/config.py

Lines changed: 1 addition & 2 deletions
@@ -103,6 +103,7 @@ class OnnxConfig(ABC):
         "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
         "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
     }
 
     def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -451,7 +452,6 @@ def generate_dummy_inputs(
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-
         # TODO: should we set seq_length = 1 when self.use_past = True?
         common_inputs = super().generate_dummy_inputs(
             tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
@@ -577,7 +577,6 @@ def generate_dummy_inputs(
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-
        encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )
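The new `"vision2seq-lm"` entry is what gives the exported decoder graph its `logits` output. A tiny sketch, using `GPT2Config` purely as a stand-in for an arbitrary decoder config:

```python
from transformers import GPT2Config
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
    VisionEncoderDecoderDecoderOnnxConfig,
)

# The task name selects the common output spec declared in OnnxConfig.
decoder_onnx_config = VisionEncoderDecoderDecoderOnnxConfig(GPT2Config(), "vision2seq-lm")
print(decoder_onnx_config.outputs)  # OrderedDict([('logits', {0: 'batch', 1: 'sequence'})])
```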

src/transformers/onnx/features.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,7 @@
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification,
+        AutoModelForVision2Seq,
     )
 if is_tf_available():
     from transformers.models.auto import (
@@ -98,6 +99,7 @@ class FeaturesManager:
         "image-segmentation": AutoModelForImageSegmentation,
         "masked-im": AutoModelForMaskedImageModeling,
         "semantic-segmentation": AutoModelForSemanticSegmentation,
+        "vision2seq-lm": AutoModelForVision2Seq,
     }
     if is_tf_available():
         _TASKS_TO_TF_AUTOMODELS = {
@@ -481,6 +483,9 @@ class FeaturesManager:
             "seq2seq-lm-with-past",
             onnx_config_cls="models.t5.T5OnnxConfig",
         ),
+        "vision-encoder-decoder": supported_features_mapping(
+            "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
+        ),
         "vit": supported_features_mapping(
             "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
         ),
@@ -582,6 +587,7 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
             raise KeyError(
                 f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
             )
+
         return task_to_automodel[task]
 
     @staticmethod
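A short sketch of how this registration is looked up through `FeaturesManager`:

```python
from transformers.onnx import FeaturesManager

# The model type now advertises the "vision2seq-lm" feature...
print(FeaturesManager.get_supported_features_for_model_type("vision-encoder-decoder"))

# ...and that feature resolves to AutoModelForVision2Seq for the PyTorch backend.
print(FeaturesManager.get_model_class_for_feature("vision2seq-lm"))
```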
