@@ -461,21 +461,6 @@ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> No
461461 window_size if isinstance (window_size , collections .abc .Iterable ) else (window_size , window_size )
462462 )
463463
464- # get pair-wise relative position index for each token inside the window
465- coords_h = tf .range (self .window_size [0 ])
466- coords_w = tf .range (self .window_size [1 ])
467- coords = tf .stack (tf .meshgrid (coords_h , coords_w , indexing = "ij" ))
468- coords_flatten = tf .reshape (coords , (shape_list (coords )[0 ], - 1 ))
469- relative_coords = coords_flatten [:, :, None ] - coords_flatten [:, None , :]
470- relative_coords = tf .transpose (relative_coords , (1 , 2 , 0 ))
471-
472- stack_0 , stack_1 = tf .unstack (relative_coords , axis = 2 )
473- stack_0 += self .window_size [0 ] - 1
474- stack_0 *= 2 * self .window_size [1 ] - 1
475- stack_1 += self .window_size [1 ] - 1
476- relative_coords = tf .stack ([stack_0 , stack_1 ], axis = 2 )
477- self .relative_position_index = tf .reduce_sum (relative_coords , axis = - 1 )
478-
479464 self .query = tf .keras .layers .Dense (
480465 self .all_head_size ,
481466 kernel_initializer = get_initializer (config .initializer_range ),
@@ -503,6 +488,28 @@ def build(self, input_shape: tf.TensorShape) -> None:
503488 initializer = "zeros" ,
504489 name = "relative_position_bias_table" ,
505490 )
491+ self .relative_position_index = self .add_weight (
492+ shape = (self .window_size [0 ] ** 2 , self .window_size [1 ] ** 2 ),
493+ trainable = False ,
494+ dtype = tf .int32 ,
495+ name = "relative_position_index" ,
496+ )
497+
498+ # get pair-wise relative position index for each token inside the window
499+ coords_h = tf .range (self .window_size [0 ])
500+ coords_w = tf .range (self .window_size [1 ])
501+ coords = tf .stack (tf .meshgrid (coords_h , coords_w , indexing = "ij" ))
502+ coords_flatten = tf .reshape (coords , (shape_list (coords )[0 ], - 1 ))
503+ relative_coords = coords_flatten [:, :, None ] - coords_flatten [:, None , :]
504+ relative_coords = tf .transpose (relative_coords , (1 , 2 , 0 ))
505+
506+ stack_0 , stack_1 = tf .unstack (relative_coords , axis = 2 )
507+ stack_0 += self .window_size [0 ] - 1
508+ stack_0 *= 2 * self .window_size [1 ] - 1
509+ stack_1 += self .window_size [1 ] - 1
510+ relative_coords = tf .stack ([stack_0 , stack_1 ], axis = 2 )
511+
512+ self .relative_position_index .assign (tf .cast (tf .reduce_sum (relative_coords , axis = - 1 ), tf .int32 ))
506513 super ().build (input_shape )
507514
508515 def transpose_for_scores (self , x : tf .Tensor ) -> tf .Tensor :
0 commit comments