Merged

33 commits
24f8f6f
move encoder below
geetu040 Feb 25, 2025
7d60223
auto modeling
geetu040 Feb 25, 2025
60d77b3
Merge branch 'main' into sam-vision-encoder
geetu040 Feb 26, 2025
6a65a88
write SamVisionTester
geetu040 Feb 26, 2025
d2a4083
fix vision attention shape
geetu040 Feb 26, 2025
1cdf8c7
fix SamVisionTest
geetu040 Feb 26, 2025
61626ca
minor changes to SamVisionTest
geetu040 Feb 26, 2025
a1258f3
Revert "fix vision attention shape"
geetu040 Mar 2, 2025
cd42ffb
fix attention output shape in new tests
geetu040 Mar 2, 2025
2af72b5
remove encoder examples
geetu040 Mar 2, 2025
bdac520
run modular on got_ocr2
geetu040 Mar 2, 2025
d5ff273
code formatting
geetu040 Mar 2, 2025
9ae0b98
fix got_ocr2
geetu040 Mar 2, 2025
a4a60fc
ruff fixes
geetu040 Mar 2, 2025
6934b1c
code quality
geetu040 Mar 2, 2025
aaf6c53
add sam_vision in auto modeling and auto configuration
geetu040 Mar 3, 2025
760d2d2
remove composite test
geetu040 Mar 3, 2025
5336bac
updated index.md
geetu040 Mar 3, 2025
d0e7d18
Merge branch "main" and resolve conflicts
geetu040 Mar 4, 2025
1766a0f
add TFSamVisionEncoder to __init__
geetu040 Mar 4, 2025
06cce05
Merge remote-tracking branch 'origin/main' into sam-vision-encoder
geetu040 Mar 7, 2025
88ea9d4
fix public TFSamVisionEncoder
geetu040 Mar 7, 2025
dc8ea1b
remove outdated todo comment
geetu040 Mar 7, 2025
f791dc4
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 17, 2025
f148ff3
set test_torch_exportable
geetu040 Mar 17, 2025
caca906
rename: VisionEncoder -> VisionModel
geetu040 Mar 17, 2025
6ed683e
bring back original SamVisionEncoder
geetu040 Mar 17, 2025
159022f
rename back: VisionEncoderOutput -> VisionModelOutput
geetu040 Mar 17, 2025
a1652c6
undo changes in SamModelTester
geetu040 Mar 17, 2025
08e1e5d
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 18, 2025
51de71d
reuse SamVisionEncoder in SamVisionModel
geetu040 Mar 18, 2025
26af99c
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 26, 2025
64080f4
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 27, 2025
12 changes: 12 additions & 0 deletions docs/source/en/model_doc/sam.md
@@ -149,12 +149,24 @@ alt="drawing" width="900"/>
[[autodoc]] SamImageProcessor


## SamVisionModel

[[autodoc]] SamVisionModel
- forward


## SamModel

[[autodoc]] SamModel
- forward


## TFSamVisionModel

[[autodoc]] TFSamVisionModel
- call


## TFSamModel

[[autodoc]] TFSamModel
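The doc page gains only the autodoc directives above; as orientation for the new entry, here is a minimal usage sketch. It assumes `SamVisionModel.from_pretrained` can extract the vision tower from a full SAM checkpoint such as `facebook/sam-vit-huge` (the checkpoint choice and the image URL are illustrative, not taken from this diff):

```python
import torch
import requests
from PIL import Image
from transformers import SamProcessor, SamVisionModel

# Assumption: the vision-only model loads from the composite SAM checkpoint
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamVisionModel.from_pretrained("facebook/sam-vit-huge")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# The processor also returns original/reshaped sizes, so pass pixel_values explicitly
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"])

print(outputs.last_hidden_state.shape)  # the encoder's neck feature map
```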
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -3560,6 +3560,7 @@
[
"SamModel",
"SamPreTrainedModel",
"SamVisionModel",
]
)
_import_structure["models.seamless_m4t"].extend(
@@ -4728,6 +4729,7 @@
[
"TFSamModel",
"TFSamPreTrainedModel",
"TFSamVisionModel",
]
)
_import_structure["models.segformer"].extend(
@@ -8377,6 +8379,7 @@
from .models.sam import (
SamModel,
SamPreTrainedModel,
SamVisionModel,
)
from .models.seamless_m4t import (
SeamlessM4TCodeHifiGan,
@@ -9318,6 +9321,7 @@
from .models.sam import (
TFSamModel,
TFSamPreTrainedModel,
TFSamVisionModel,
)
from .models.segformer import (
TFSegformerDecodeHead,
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -270,6 +270,7 @@
("rt_detr_v2", "RTDetrV2Config"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("sam_vision_model", "SamVisionConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
("seamless_m4t_v2", "SeamlessM4Tv2Config"),
("segformer", "SegformerConfig"),
@@ -624,6 +625,7 @@
("rt_detr_v2", "RT-DETRv2"),
("rwkv", "RWKV"),
("sam", "SAM"),
("sam_vision_model", "SamVisionModel"),
("seamless_m4t", "SeamlessM4T"),
("seamless_m4t_v2", "SeamlessM4Tv2"),
("segformer", "SegFormer"),
@@ -767,6 +769,7 @@
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
]
)

1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -246,6 +246,7 @@
("rt_detr_v2", "RTDetrV2Model"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("sam_vision_model", "SamVisionModel"),
("seamless_m4t", "SeamlessM4TModel"),
("seamless_m4t_v2", "SeamlessM4Tv2Model"),
("segformer", "SegformerModel"),
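Registering the `sam_vision_model` key in both the configuration and model mappings lets the auto classes dispatch on the new model type. A minimal sketch of that round trip, using random weights and no checkpoint:

```python
from transformers import AutoModel, SamVisionConfig

# model_type "sam_vision_model" is looked up in the auto mappings added above
config = SamVisionConfig()
model = AutoModel.from_config(config)
print(type(model).__name__)  # SamVisionModel
```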
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -80,6 +80,7 @@
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("sam_vision_model", "TFSamVisionModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swiftformer", "TFSwiftFormerModel"),
20 changes: 19 additions & 1 deletion src/transformers/models/sam/configuration_sam.py
@@ -183,9 +183,27 @@ class SamVisionConfig(PretrainedConfig):
mlp_dim (`int`, *optional*):
The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
hidden_size`.
"""

Example:

```python
>>> from transformers import (
... SamVisionConfig,
... SamVisionModel,
... )

>>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
>>> configuration = SamVisionConfig()

>>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
>>> model = SamVisionModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

base_config_key = "vision_config"
model_type = "sam_vision_model"

def __init__(
self,
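Beyond the doc example, the two new class attributes are what make the standalone config usable: `model_type = "sam_vision_model"` ties into the auto mappings, and `base_config_key = "vision_config"` tells `from_pretrained` which nested section of a composite SAM config to read. A short sketch of the intended behavior (the checkpoint name is illustrative):

```python
from transformers import SamConfig, SamVisionConfig

# A full SAM config nests the vision settings under "vision_config"
full_config = SamConfig()
print(type(full_config.vision_config).__name__)  # SamVisionConfig

# Because base_config_key is "vision_config", loading the standalone config
# from a composite checkpoint should pick out just that nested section:
# vision_config = SamVisionConfig.from_pretrained("facebook/sam-vit-huge")
```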
65 changes: 63 additions & 2 deletions src/transformers/models/sam/modeling_sam.py
@@ -27,7 +27,13 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


@@ -1280,6 +1286,61 @@ def _init_weights(self, module):
"""


SAM_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""The vision model from Sam without any head or projection on top.""",
SAM_START_DOCSTRING,
)
class SamVisionModel(SamPreTrainedModel):
config_class = SamVisionConfig
main_input_name = "pixel_values"

def __init__(self, config: SamVisionConfig):
super().__init__(config)
self.vision_encoder = SamVisionEncoder(config)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_encoder.patch_embed

@add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SamVisionEncoderOutput]:
r"""
Returns:

"""
return self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)


@add_start_docstrings(
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
" optional 2D location and bounding boxes.",
@@ -1522,4 +1583,4 @@
)


__all__ = ["SamModel", "SamPreTrainedModel"]
__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
74 changes: 72 additions & 2 deletions src/transformers/models/sam/modeling_tf_sam.py
@@ -30,7 +30,13 @@
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
from ...tf_utils import flatten, functional_layernorm
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


@@ -1400,6 +1406,70 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
"""


SAM_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""The vision model from Sam without any head or projection on top.""",
SAM_START_DOCSTRING,
)
class TFSamVisionModel(TFSamPreTrainedModel):
config_class = SamVisionConfig
main_input_name = "pixel_values"

def __init__(self, config: SamVisionConfig, **kwargs):
super().__init__(config, **kwargs)
self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")

def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_encoder", None) is not None:
with tf.name_scope(self.vision_encoder.name):
self.vision_encoder.build(None)

def get_input_embeddings(self):
return self.vision_encoder.patch_embed

@unpack_inputs
@add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
def call(
self,
pixel_values: TFModelInputType | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
training: bool = False,
**kwargs,
) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
r"""
Returns:

"""
return self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)


@add_start_docstrings(
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
" optional 2D location and bounding boxes.",
@@ -1653,4 +1723,4 @@ def build(self, input_shape=None):
self.mask_decoder.build(None)


__all__ = ["TFSamModel", "TFSamPreTrainedModel"]
__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -8731,6 +8731,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class SamVisionModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class SeamlessM4TCodeHifiGan(metaclass=DummyObject):
_backends = ["torch"]

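The dummy entries keep the top-level import path intact in torch-free installs: `from transformers import SamVisionModel` still succeeds, and only instantiation raises. A sketch of that failure mode, assuming PyTorch is not installed:

```python
# Import succeeds even without torch, because the dummy class stands in
from transformers import SamVisionModel

try:
    SamVisionModel()  # requires_backends raises here, not at import time
except ImportError as err:
    print(err)  # message points at the missing PyTorch backend
```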
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
@@ -2375,6 +2375,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFSamVisionModel(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFSegformerDecodeHead(metaclass=DummyObject):
_backends = ["tf"]
