Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
552ac3c
add utilities till TFData2VecVisionLayer.
sayakpaul Apr 29, 2022
65e9d3b
chore: pass window_size to attention layer.
sayakpaul Apr 29, 2022
59e244d
feat: add TFData2VecVisionRelativePositionBias.
sayakpaul Apr 29, 2022
5b99dc5
feat: initial implementation ready for tf data2vec.
sayakpaul Apr 29, 2022
caac111
fix: relative position bias index, table to be fixed.
sayakpaul Apr 30, 2022
26a54e3
chore: implementation added, tests remaining.
sayakpaul May 1, 2022
fa58b27
add: tests, other PR files.
sayakpaul May 1, 2022
060b42a
fix: code quality.
sayakpaul May 1, 2022
4b0eb6f
fix: import structure in init.
sayakpaul May 1, 2022
5f6572c
chore: run make fix-copies.
sayakpaul May 2, 2022
2048224
chore: address PR feedback (round I).
sayakpaul May 3, 2022
bae3243
chore: styling nit.
sayakpaul May 3, 2022
4fc2e3e
fix: tests due to removal of to_2tuple().
sayakpaul May 3, 2022
ddd6b1c
chore: rebase with upstream main and move the test.
sayakpaul May 3, 2022
d24cd4e
Update src/transformers/models/auto/modeling_tf_auto.py
sayakpaul May 3, 2022
7765ed0
Update src/transformers/models/auto/modeling_tf_auto.py
sayakpaul May 3, 2022
77822c0
fix: layer call.
sayakpaul May 3, 2022
b662294
chore: remove from_pt=True and rerun test.
sayakpaul May 3, 2022
1c1efc6
chore: remove cast and tf.divide.
sayakpaul May 4, 2022
c96ffc0
chore: minor edits to the test script.
sayakpaul May 4, 2022
6b526e5
Update src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
sayakpaul May 4, 2022
497a54d
fix: expand() on TF tensors with broadcast_to().
sayakpaul May 4, 2022
d004de3
fix: test import.
sayakpaul May 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ Flax), PyTorch, and/or TensorFlow.
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| Data2VecAudio | ❌ | ❌ | ✅ | ❌ | ❌ |
| Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ |
| Data2VecVision | ❌ | ❌ | ✅ | | ❌ |
| Data2VecVision | ❌ | ❌ | ✅ | | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | ✅ | ✅ | ✅ | ❌ |
| Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
Expand Down
16 changes: 14 additions & 2 deletions docs/source/en/model_doc/data2vec.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ Tips:
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.

This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten)
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.

The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).


## Data2VecTextConfig
Expand Down Expand Up @@ -130,3 +132,13 @@ The original code can be found [here](https://github.com/pytorch/fairseq/tree/ma

[[autodoc]] Data2VecVisionForSemanticSegmentation
- forward

## TFData2VecVisionModel

[[autodoc]] TFData2VecVisionModel
- call

## TFData2VecVisionForImageClassification

[[autodoc]] TFData2VecVisionForImageClassification
- call
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,13 @@
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.data2vec"].extend(
[
"TFData2VecVisionForImageClassification",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -4029,6 +4036,11 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.data2vec import (
TFData2VecVisionForImageClassification,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
Expand Down Expand Up @@ -163,6 +164,7 @@
# Model for Image-classsification
("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
]
)

Expand Down
15 changes: 15 additions & 0 deletions src/transformers/models/data2vec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from typing import TYPE_CHECKING

from transformers.utils.import_utils import is_tf_available

from ...utils import _LazyModule, is_torch_available


Expand Down Expand Up @@ -68,6 +70,13 @@
"Data2VecVisionPreTrainedModel",
]

if is_tf_available():
_import_structure["modeling_tf_data2vec_vision"] = [
"TFData2VecVisionForImageClassification",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_data2vec_audio import DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig
from .configuration_data2vec_text import (
Expand Down Expand Up @@ -110,6 +119,12 @@
Data2VecVisionModel,
Data2VecVisionPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_data2vec_vision import (
TFData2VecVisionForImageClassification,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)

else:
import sys
Expand Down
Loading