Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- Blenderbot
- BlenderbotSmall
- CamemBERT
- ConvBERT
- Data2VecText
- Data2VecVision
- DistilBERT
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/convbert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


_import_structure = {
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"],
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
"tokenization_convbert": ["ConvBertTokenizer"],
}

Expand Down Expand Up @@ -58,7 +58,7 @@


if TYPE_CHECKING:
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
from .tokenization_convbert import ConvBertTokenizer

if is_tokenizers_available():
Expand Down
21 changes: 21 additions & 0 deletions src/transformers/models/convbert/configuration_convbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# limitations under the License.
""" ConvBERT model configuration"""

from collections import OrderedDict
from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -138,3 +142,20 @@ def __init__(
self.conv_kernel_size = conv_kernel_size
self.num_groups = num_groups
self.classifier_dropout = classifier_dropout


# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class ConvBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)
10 changes: 10 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
Expand Down Expand Up @@ -187,6 +188,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
),
"convbert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ConvBertOnnxConfig,
),
"distilbert": supported_features_mapping(
"default",
"masked-lm",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def test_values_override(self):
("bigbird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("roberta", "roberta-base"),
Expand Down