Merged
Changes from 15 commits
Commits (33)
24f8f6f
move encoder below
geetu040 Feb 25, 2025
7d60223
auto modeling
geetu040 Feb 25, 2025
60d77b3
Merge branch 'main' into sam-vision-encoder
geetu040 Feb 26, 2025
6a65a88
write SamVisionTester
geetu040 Feb 26, 2025
d2a4083
fix vision attention shape
geetu040 Feb 26, 2025
1cdf8c7
fix SamVisionTest
geetu040 Feb 26, 2025
61626ca
minor changes to SamVisionTest
geetu040 Feb 26, 2025
a1258f3
Revert "fix vision attention shape"
geetu040 Mar 2, 2025
cd42ffb
fix attention output shape in new tests
geetu040 Mar 2, 2025
2af72b5
remove encoder examples
geetu040 Mar 2, 2025
bdac520
run modular on got_ocr2
geetu040 Mar 2, 2025
d5ff273
code formatting
geetu040 Mar 2, 2025
9ae0b98
fix got_ocr2
geetu040 Mar 2, 2025
a4a60fc
ruff fixes
geetu040 Mar 2, 2025
6934b1c
code quality
geetu040 Mar 2, 2025
aaf6c53
add sam_vision in auto modeling and auto configuration
geetu040 Mar 3, 2025
760d2d2
remove composite test
geetu040 Mar 3, 2025
5336bac
updated index.md
geetu040 Mar 3, 2025
d0e7d18
Merge branch "main" and resolve conflicts
geetu040 Mar 4, 2025
1766a0f
add TFSamVisionEncoder to __init__
geetu040 Mar 4, 2025
06cce05
Merge remote-tracking branch 'origin/main' into sam-vision-encoder
geetu040 Mar 7, 2025
88ea9d4
fix public TFSamVisionEncoder
geetu040 Mar 7, 2025
dc8ea1b
remove outdated todo comment
geetu040 Mar 7, 2025
f791dc4
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 17, 2025
f148ff3
set test_torch_exportable
geetu040 Mar 17, 2025
caca906
rename: VisionEncoder -> VisionModel
geetu040 Mar 17, 2025
6ed683e
bring back original SamVisionEncoder
geetu040 Mar 17, 2025
159022f
rename back: VisionEncoderOutput -> VisionModelOutput
geetu040 Mar 17, 2025
a1652c6
undo changes in SamModelTester
geetu040 Mar 17, 2025
08e1e5d
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 18, 2025
51de71d
reuse SamVisionEncoder in SamVisionModel
geetu040 Mar 18, 2025
26af99c
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 26, 2025
64080f4
Merge branch 'main' into sam-vision-encoder
geetu040 Mar 27, 2025
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/sam.md
@@ -144,6 +144,12 @@ alt="drawing" width="900"/>
[[autodoc]] SamImageProcessor


## SamVisionEncoder

[[autodoc]] SamVisionEncoder
- forward


## SamModel

[[autodoc]] SamModel
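For context, here is a minimal usage sketch of the class documented above. It builds `SamVisionEncoder` from a bare `SamVisionConfig` rather than a released checkpoint, which is an assumption made purely to keep the snippet self-contained; the printed shape therefore comes from randomly initialized weights.

```python
import torch
from transformers import SamVisionConfig, SamVisionEncoder

# Instantiate the vision encoder from a default config; no pretrained weights
# are loaded here, so the output values come from random initialization.
config = SamVisionConfig()
encoder = SamVisionEncoder(config)

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    outputs = encoder(pixel_values)

# The encoder returns the neck's feature map; with the default config this is
# expected to be roughly (1, config.output_channels, 64, 64).
print(outputs.last_hidden_state.shape)
```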
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -3500,6 +3500,7 @@
[
"SamModel",
"SamPreTrainedModel",
"SamVisionEncoder",
]
)
_import_structure["models.seamless_m4t"].extend(
@@ -8261,6 +8262,7 @@
from .models.sam import (
SamModel,
SamPreTrainedModel,
SamVisionEncoder,
)
from .models.seamless_m4t import (
SeamlessM4TCodeHifiGan,
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -242,6 +242,7 @@
("rt_detr", "RTDetrModel"),
("rt_detr_v2", "RTDetrV2Model"),
("rwkv", "RwkvModel"),
("sam", "SamVisionEncoder"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
("seamless_m4t_v2", "SeamlessM4Tv2Model"),
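For reference, the auto-mapping above is what lets `AutoModel` resolve a modeling class from a config's `model_type`. A simplified sketch of that lookup follows (not the real `_LazyAutoMapping` machinery); note that an `OrderedDict` built from these tuples keeps only the last value for a repeated key, so with two `"sam"` entries the earlier one is shadowed:

```python
from collections import OrderedDict

# Simplified stand-in for MODEL_MAPPING_NAMES: an OrderedDict built from
# (model_type, class_name) tuples keeps only the last value for a repeated key.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("rwkv", "RwkvModel"),
        ("sam", "SamVisionEncoder"),
        ("sam", "SamModel"),
    ]
)


def resolve_model_class_name(model_type: str) -> str:
    # Toy version of the AutoModel lookup: config.model_type -> modeling class name.
    return MODEL_MAPPING_NAMES[model_type]


print(resolve_model_class_name("sam"))  # -> "SamModel"; the earlier "sam" entry is shadowed
```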
272 changes: 152 additions & 120 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -293,6 +293,123 @@ def forward(
return outputs


class GotOcr2MultiModalProjector(nn.Module):
def __init__(self, config: GotOcr2Config):
super().__init__()
vision_output_channels = config.vision_config.output_channels
language_hidden_size = config.text_config.hidden_size
self.conv_upsampler1 = nn.Conv2d(
vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.conv_upsampler2 = nn.Conv2d(
vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
)
self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)

def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv_upsampler1(vision_embeddings)
hidden_state = self.conv_upsampler2(hidden_state)
hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
hidden_state = self.multimodal_projector(hidden_state)
return hidden_state


@dataclass
class GotOcr2CausalLMOutputWithPast(ModelOutput):
"""
Base class for GotOcr2 causal language model (or autoregressive) outputs.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)

Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


GOT_OCR2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.

Parameters:
config ([`GotOcr2Config`] or [`GotOcr2VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2PreTrainedModel(PreTrainedModel):
config_class = GotOcr2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["GotOcr2VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/GotOcr2/tree/main/got_ocr2 should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)

if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


@dataclass
class GotOcr2VisionEncoderOutput(ModelOutput):
"""
@@ -408,9 +525,32 @@ def forward(self, hidden_states):
return hidden_states


class GotOcr2VisionEncoder(nn.Module):
GOT_OCR2_VISION_INPUTS_DOCSTRING = """
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`GotOcr2Processor`]. See [`GotOcr2Processor.__call__`] for
details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""The vision model from GotOcr2 without any head or projection on top.""",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):
config_class = GotOcr2VisionConfig
main_input_name = "pixel_values"

def __init__(self, config: GotOcr2VisionConfig):
super().__init__()
super().__init__(config)
self.config = config
self.image_size = config.image_size

@@ -440,16 +580,25 @@ def __init__(self, config: GotOcr2VisionConfig):

self.gradient_checkpointing = False

def get_input_embeddings(self):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.patch_embed

@add_start_docstrings_to_model_forward(GOT_OCR2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2VisionEncoderOutput, config_class=GotOcr2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, GotOcr2VisionEncoderOutput]:
r"""
Returns:

"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -503,123 +652,6 @@ def forward(
)


class GotOcr2MultiModalProjector(nn.Module):
def __init__(self, config: GotOcr2Config):
super().__init__()
vision_output_channels = config.vision_config.output_channels
language_hidden_size = config.text_config.hidden_size
self.conv_upsampler1 = nn.Conv2d(
vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.conv_upsampler2 = nn.Conv2d(
vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
)
self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)

def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv_upsampler1(vision_embeddings)
hidden_state = self.conv_upsampler2(hidden_state)
hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
hidden_state = self.multimodal_projector(hidden_state)
return hidden_state


@dataclass
class GotOcr2CausalLMOutputWithPast(ModelOutput):
"""
Base class for GotOcr2 causal language model (or autoregressive) outputs.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)

Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


GOT_OCR2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.

Parameters:
config ([`GotOcr2Config`] or [`GotOcr2VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2PreTrainedModel(PreTrainedModel):
config_class = GotOcr2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["GotOcr2VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/GotOcr2/tree/main/got_ocr2 should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)

if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()


GOT_OCR2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
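For reference, a self-contained sketch of the shape flow through the `GotOcr2MultiModalProjector` defined near the top of this file's diff. The channel counts and the 64x64 feature-map size are assumed values, not ones read from a real `GotOcr2Config`, and despite the `upsampler` names each stride-2 convolution halves the spatial resolution:

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration only: 256 vision output channels, 1024 language hidden size.
vision_output_channels, language_hidden_size = 256, 1024

conv_upsampler1 = nn.Conv2d(
    vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
)
conv_upsampler2 = nn.Conv2d(
    vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
)
multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)

x = torch.randn(1, vision_output_channels, 64, 64)  # e.g. a 64x64 feature map from the vision encoder
x = conv_upsampler1(x)                               # -> (1, 512, 32, 32)
x = conv_upsampler2(x)                               # -> (1, 1024, 16, 16)
x = x.flatten(2).permute(0, 2, 1)                    # -> (1, 256, 1024): one token per spatial position
x = multimodal_projector(x)                          # -> (1, 256, 1024), ready for the language model
print(x.shape)
```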
8 changes: 4 additions & 4 deletions src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -695,10 +695,6 @@ def __init__(self, config, window_size):
self.window_size = window_size


class GotOcr2VisionEncoder(SamVisionEncoder):
pass


class GotOcr2MultiModalProjector(nn.Module):
def __init__(self, config: GotOcr2Config):
super().__init__()
@@ -728,6 +724,10 @@ class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
pass


class GotOcr2VisionEncoder(SamVisionEncoder):
pass


GOT_OCR2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
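In the generated `modeling_got_ocr2.py` shown earlier, `GotOcr2VisionEncoder` now subclasses `GotOcr2PreTrainedModel`, which is presumably why the modular definition is moved below the pretrained base class here. A minimal sketch of that definition-order constraint follows (placeholder class bodies; whether the modular converter actually requires this ordering is an assumption):

```python
# Placeholder classes that only illustrate the definition-order constraint:
# a subclass must be declared after its base, otherwise Python raises NameError.
import torch.nn as nn


class GotOcr2PreTrainedModel(nn.Module):  # stand-in for the real PreTrainedModel base
    pass


class GotOcr2VisionEncoder(GotOcr2PreTrainedModel):  # would fail with NameError if defined first
    pass
```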