23 changes: 11 additions & 12 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -114,9 +114,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:

return rel_pos_resized[relative_coords.long()]

def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: torch.Tensor,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
@@ -128,8 +127,6 @@ def add_decomposed_rel_pos(
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
@@ -142,8 +139,8 @@ def add_decomposed_rel_pos(
spatial sequence size of key k with (key_height, key_width).

Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`torch.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@@ -154,10 +151,10 @@ def add_decomposed_rel_pos(
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
return attn

decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]

return decomposed_rel_pos

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
@@ -173,9 +170,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
attn_weights = (query * self.scale) @ key.transpose(-2, -1)

if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
attn_weights = attn_weights + decomposed_rel_pos

attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

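For intuition, here is a minimal, self-contained sketch of the refactor (shapes and random tensors are assumed for illustration; the real module gathers the rel-pos tables with get_rel_pos). The helper now returns the decomposed bias instead of mutating the attention map, and the caller flattens it onto the scores:

import torch

B, H, W, C = 2, 4, 4, 8                    # batch*heads, grid height/width, head dim
query = torch.randn(B, H * W, C)
Rh = torch.randn(H, H, C)                  # assumed pre-gathered rel-pos table (height)
Rw = torch.randn(W, W, C)                  # assumed pre-gathered rel-pos table (width)

rq = query.reshape(B, H, W, C)
rel_h = torch.einsum("bhwc,hkc->bhwk", rq, Rh)              # (B, H, W, key_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", rq, Rw)              # (B, H, W, key_w)
bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]    # (B, H, W, key_h, key_w)

attn_weights = torch.randn(B, H * W, H * W)
# The old helper reshaped attn to 5D and added the two terms in place;
# adding the flattened bias is numerically identical.
attn_weights = attn_weights + bias.reshape_as(attn_weights)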
8 changes: 5 additions & 3 deletions src/transformers/models/sam/image_processing_sam.py
@@ -1381,7 +1381,7 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
out.append({"size": [height, width], "counts": counts})
return out

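As a quick sanity check of the .item() fix (a usage sketch; _mask_to_rle_pytorch is the private helper from this file), the counts now come back as plain Python ints rather than ending in a 0-d tensor:

import torch
from transformers.models.sam.image_processing_sam import _mask_to_rle_pytorch

# Shape (1, 2, 2); flattened in Fortran order -> [0, 1, 1, 1] -> RLE counts [1, 3].
mask = torch.tensor([[[0, 1], [1, 1]]], dtype=torch.long)
rle = _mask_to_rle_pytorch(mask)
assert rle == [{"size": [2, 2], "counts": [1, 3]}]
assert all(isinstance(c, int) for c in rle[0]["counts"])  # no stray tensors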
@@ -1401,7 +1401,7 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
if len(cur_idxs) == 0:
# No changes => either all 0 or all 1
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
@@ -1412,7 +1412,9 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
counts += (
[cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
)
Comment on lines +1415 to +1417

Contributor: Why do we need to add .numpy()?

Contributor Author: Because TensorFlow tensors don't have a direct item() method; to access the scalar value, we need to call numpy().item().
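A minimal sketch of the difference (assuming eager-mode TF 2.x; the exact error text can vary by version):

import tensorflow as tf
import torch

t = tf.constant([3, 1, 4])
# t[0].item() raises AttributeError: EagerTensor has no 'item' method
scalar = t[0].numpy().item()         # 3 -- go through numpy to get a Python int

torch.tensor([3, 1, 4])[0].item()    # 3 -- torch tensors expose item() directly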

Contributor: OK! Just wondering if it was broken before.

Contributor Author: There are a bunch of other TensorFlow bugs that, for some reason, did not appear in the workflows before, except in #36493. I have checked different versions of TensorFlow, and they seem to be genuine bugs across all versions.

Contributor (@qubvel, Mar 3, 2025): OK, not sure if there is any usage of TF SAM actually, so there might be some bugs indeed!

out.append({"size": [height, width], "counts": counts})
return out

105 changes: 31 additions & 74 deletions src/transformers/models/sam/modeling_sam.py
@@ -820,9 +820,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:

return rel_pos_resized[relative_coords.long()]

def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: torch.Tensor,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
@@ -834,8 +833,6 @@ def add_decomposed_rel_pos(
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
@@ -848,8 +845,8 @@ def add_decomposed_rel_pos(
spatial sequence size of key k with (key_height, key_width).

Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`torch.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@@ -860,10 +857,10 @@ def add_decomposed_rel_pos(
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
return attn

decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]

return decomposed_rel_pos

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
@@ -879,9 +876,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
attn_weights = (query * self.scale) @ key.transpose(-2, -1)

if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
attn_weights = attn_weights + decomposed_rel_pos

attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

@@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
def __init__(self, config, window_size):
super().__init__(config, window_size)

def add_decomposed_rel_pos(
self,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
This method is reimplemented to follow the implementation in:
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
This implementation is more memory efficient when using SDPA in the forward method.
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)

return rel_h, rel_w

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
if output_attentions:
logger.warning_once(
"`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
output_attentions=output_attentions,
)

batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
@@ -960,25 +931,21 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
# q, k, v with shape (B * nHead, H * W, C)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

rel_h, rel_w = None, None
attn_bias = None
if self.use_rel_pos:
rel_h, rel_w = self.add_decomposed_rel_pos(
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = decomposed_rel_pos.reshape(
batch_size, self.num_attention_heads, height * width, height * width
)
attn_bias = decomposed_rel_pos

query = query.view(batch_size, self.num_attention_heads, height * width, -1)
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
value = value.view(batch_size, self.num_attention_heads, height * width, -1)

if self.use_rel_pos:
rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(
batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)

attn_output = (
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
@@ -988,17 +955,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:

attn_output = self.proj(attn_output)

if output_attentions:
# For output_attentions, calculate the attention weights
attn_weights = (query @ key.transpose(-2, -1)) * self.scale
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)

return outputs
return attn_output, None

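For reference, a minimal sketch (assumed shapes, random data) of why the reshaped bias can be passed as attn_mask: SDPA adds a floating-point mask to the scaled scores before the softmax, matching the eager path above:

import torch
import torch.nn.functional as F

B, nH, L, C = 2, 4, 16, 8
q, k, v = (torch.randn(B, nH, L, C) for _ in range(3))
bias = torch.randn(B, nH, L, L)      # e.g. the reshaped decomposed rel-pos

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Eager equivalent (SDPA defaults to scale = 1/sqrt(C)):
ref = F.softmax(q @ k.transpose(-2, -1) * C**-0.5 + bias, dim=-1) @ v
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-4)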

SAM_VISION_ATTENTION_CLASSES = {
25 changes: 13 additions & 12 deletions src/transformers/models/sam/modeling_tf_sam.py
@@ -982,9 +982,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:

return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))

def add_decomposed_rel_pos(
def get_decomposed_rel_pos(
self,
attn: tf.Tensor,
query: tf.Tensor,
rel_pos_h: tf.Tensor,
rel_pos_w: tf.Tensor,
@@ -996,8 +995,6 @@ def add_decomposed_rel_pos(
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Args:
attn (`tf.Tensor`):
attention map.
query (`tf.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`tf.Tensor`):
@@ -1010,8 +1007,8 @@ def add_decomposed_rel_pos(
spatial sequence size of key k with (key_height, key_width).

Returns:
attn (`tf.Tensor`):
attention map with added relative positional embeddings.
decomposed_rel_pos (`tf.Tensor`):
decomposed relative position embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
@@ -1022,10 +1019,12 @@ def add_decomposed_rel_pos(
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
return attn

rel_h = tf.expand_dims(rel_h, axis=-1)
rel_w = tf.expand_dims(rel_w, axis=-2)
decomposed_rel_pos = rel_h + rel_w

return decomposed_rel_pos

def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
batch_size, height, width, _ = shape_list(hidden_states)
@@ -1039,9 +1038,11 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)

if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
decomposed_rel_pos = self.get_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
attn_weights = attn_weights + decomposed_rel_pos

attn_weights = tf.nn.softmax(attn_weights, axis=-1)

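The same flattening step, sketched in TF with assumed shapes: the 5D bias is reshaped to match the (batch, q_len, k_len) attention matrix before the add:

import tensorflow as tf

B, H, W = 1, 3, 3
bias = tf.random.normal((B, H, W, H, W))          # (batch, q_h, q_w, k_h, k_w)
attn_weights = tf.random.normal((B, H * W, H * W))
attn_weights = attn_weights + tf.reshape(bias, tf.shape(attn_weights))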
2 changes: 1 addition & 1 deletion src/transformers/processing_utils.py
@@ -979,7 +979,7 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs
kwarg_value = kwargs.get(modality_key, "__empty__")
else:
kwarg_value = "__empty__"
if kwarg_value != "__empty__":
if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
Member: Wondering why this was needed? Even if kwarg_value is a float, checking kwarg_value != "__empty__" is enough, isn't it?

Contributor Author: This change was necessary because kwarg_value cannot be directly compared to "__empty__" if it is a TensorFlow tensor. Attempting such a comparison results in a TypeError:

tf.Variable([1, 2, 3, 4]) == "__empty__"
# TypeError: Cannot convert '__empty__' to EagerTensor of dtype int32

This issue occurred in the test suite test_modeling_tf_sam.py::TFSamModelIntegrationTest (the slow tests for SAM) in this workflow.

Member: Interesting, thanks for the explanation!

output_kwargs[modality][modality_key] = kwarg_value
used_keys.add(modality_key)

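To illustrate the guard (a minimal sketch; the sentinel string mirrors the surrounding loop): the isinstance test short-circuits the or, so a tensor-valued kwarg is never compared against the sentinel:

import tensorflow as tf

kwarg_value = tf.constant([1, 2, 3])
# Old check: kwarg_value != "__empty__"  -> TypeError for TF tensors.
if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
    print("kwarg_value was explicitly provided")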
2 changes: 1 addition & 1 deletion tests/models/sam/test_processor_sam.py
@@ -312,7 +312,7 @@ def test_rle_encoding(self):
# This is shape (1, 2, 2).
# Flattened in Fortran order -> [0, 1, 1, 1].
# The RLE for [0,1,1,1] is [1, 3].
input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
input_mask = tf.constant([[[0, 1], [1, 1]]], dtype=tf.int64)
rle = _mask_to_rle_tf(input_mask)

self.assertEqual(len(rle), 1)