Skip to content

Commit 575aa6e

Browse files
authored
Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226)
Signed-off-by: Seunghwan Hong <[email protected]>
1 parent 586dcf6 commit 575aa6e

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

src/transformers/models/swin/modeling_tf_swin.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -461,21 +461,6 @@ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> No
461461
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
462462
)
463463

464-
# get pair-wise relative position index for each token inside the window
465-
coords_h = tf.range(self.window_size[0])
466-
coords_w = tf.range(self.window_size[1])
467-
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
468-
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
469-
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
470-
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
471-
472-
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
473-
stack_0 += self.window_size[0] - 1
474-
stack_0 *= 2 * self.window_size[1] - 1
475-
stack_1 += self.window_size[1] - 1
476-
relative_coords = tf.stack([stack_0, stack_1], axis=2)
477-
self.relative_position_index = tf.reduce_sum(relative_coords, axis=-1)
478-
479464
self.query = tf.keras.layers.Dense(
480465
self.all_head_size,
481466
kernel_initializer=get_initializer(config.initializer_range),
@@ -503,6 +488,28 @@ def build(self, input_shape: tf.TensorShape) -> None:
503488
initializer="zeros",
504489
name="relative_position_bias_table",
505490
)
491+
self.relative_position_index = self.add_weight(
492+
shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
493+
trainable=False,
494+
dtype=tf.int32,
495+
name="relative_position_index",
496+
)
497+
498+
# get pair-wise relative position index for each token inside the window
499+
coords_h = tf.range(self.window_size[0])
500+
coords_w = tf.range(self.window_size[1])
501+
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
502+
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
503+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
504+
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
505+
506+
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
507+
stack_0 += self.window_size[0] - 1
508+
stack_0 *= 2 * self.window_size[1] - 1
509+
stack_1 += self.window_size[1] - 1
510+
relative_coords = tf.stack([stack_0, stack_1], axis=2)
511+
512+
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
506513
super().build(input_shape)
507514

508515
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:

0 commit comments

Comments
 (0)