from functools import partial
from inspect import isfunction

- from einops import rearrange, repeat
+ from einops import rearrange, repeat, reduce
from entmax import entmax15

from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -20,6 +20,11 @@ def default(val, d):
        return val
    return d() if isfunction(d) else d

+ def residualize(f):
+     def fn(x, *args, **kwargs):
+         return f(x, *args, **kwargs) + x
+     return fn
+
# keyword argument helpers

def pick_and_pop(keys, d):
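Note: residualize simply wraps a callable with a skip connection, which lets the funnel encoder further down switch the residual off on its downsampling attention layers. A minimal sketch of the behaviour (module and shapes here are arbitrary):

# sketch only: residualize(f)(x) is the same as f(x) + x
import torch
from torch import nn

f = nn.Linear(512, 512)              # any shape-preserving module works
x = torch.randn(2, 128, 512)
assert torch.allclose(residualize(f)(x), f(x) + x)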
@@ -234,8 +239,19 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

+ class AttentionWithDownsample(Attention):
+     def forward(self, x, num_memory_tokens = 0, downsample = False, **kwargs):
+         if downsample:
+             b, n, *_ = x.shape
+             mem, x = x[:, :num_memory_tokens], x[:, num_memory_tokens:]
+             x, remainder = (x[:, :-1], x[:, -1:]) if (n % 2) == 1 else (x[:, :], x[:, 0:0])
+             x = reduce(x, 'b (n c) d -> b n d', 'mean', c = 2)
+             x = torch.cat((mem, x, remainder), dim = 1)
+
+         return super().forward(x, **kwargs)
+
class Encoder(nn.Module):
-     def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
+     def __init__(self, dim, depth, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])
@@ -249,7 +265,7 @@ def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False,

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
-                 prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads, **attn_kwargs)),
+                 prenorm_fn(Attention(dim, heads = heads, **attn_kwargs)),
                prenorm_fn(FeedForward(dim, **ff_kwargs))
            ]))
    def forward(self, x, context = None, mask = None):
@@ -258,6 +274,57 @@ def forward(self, x, context = None, mask = None):
            x = ff(x) + x
        return x

+ class FunnelEncoder(nn.Module):
+     def __init__(self, dim, depths, heads = 8, use_scalenorm = False, rel_pos_bias = False, num_memory_tokens = None, **kwargs):
+         super().__init__()
+         assert isinstance(depths, tuple), 'depths must be a tuple, where each element specifies the number of layers before the next bottleneck'
+         assert len(depths) > 1, 'there must be at least 1 bottleneck'
+
+         self.dim = dim
+         self.num_memory_tokens = num_memory_tokens
+         self.rel_pos = RelativePositionBias() if rel_pos_bias else None
+         self.bottlenecks = nn.ModuleList([])
+
+         norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+         prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)
+
+         ff_kwargs, kwargs = group_by_key_prefix_and_trim('ff_', kwargs)
+         attn_kwargs, _ = group_by_key_prefix_and_trim('attn_', kwargs)
+
+         for depth in depths:
+             layers = nn.ModuleList([])
+             for _ in range(depth):
+                 layers.append(nn.ModuleList([
+                     prenorm_fn(AttentionWithDownsample(dim, heads = heads, **attn_kwargs)),
+                     prenorm_fn(FeedForward(dim, **ff_kwargs))
+                 ]))
+             self.bottlenecks.append(layers)
+
+     def forward(self, x, context = None, mask = None):
+         n = x.shape[1]
+         num_mem = self.num_memory_tokens
+         num_downsamples = len(self.bottlenecks)
+
+         for layer_ind, layers in enumerate(self.bottlenecks):
+             if layer_ind == 1:
+                 res = x
+
+             for ind, (self_attn, ff) in enumerate(layers):
+                 downsample = layer_ind != 0 and ind == 0
+                 self_attn = residualize(self_attn) if not downsample else self_attn
+
+                 x = self_attn(x, mask = mask, rel_pos = self.rel_pos, downsample = downsample, num_memory_tokens = num_mem)
+                 x = ff(x) + x
+
+         mem, x = x[:, :num_mem], x[:, num_mem:]
+         # upsample by repeating tokens as specified in paper
+         x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))
+         # curtail any excessive tokens
+         x = x[:, :(n - num_mem)]
+         x = torch.cat((mem, x), dim = 1)
+         # add to residual before start of first downsample
+         return x + res
+
class Decoder(nn.Module):
    def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):
        super().__init__()
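For reference, a rough usage sketch of the new FunnelEncoder on its own (all hyperparameters below are arbitrary; depths = (2, 2, 2) means two pooling steps, so the sequence is mean-pooled down 4x inside the encoder and then repeated back up to the original length):

# sketch only, assuming the definitions in this file are importable
import torch

enc = FunnelEncoder(dim = 512, depths = (2, 2, 2), heads = 8, num_memory_tokens = 0)
x = torch.randn(1, 1024, 512)        # (batch, seq, dim)
out = enc(x)                         # (1, 1024, 512); pooled to 256 tokens internally, then upsampled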
@@ -361,6 +428,10 @@ def __init__(
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

+         # let funnel encoder know number of memory tokens, if specified
+         if isinstance(attn_layers, FunnelEncoder):
+             attn_layers.num_memory_tokens = num_memory_tokens
+
    def init_(self):
        nn.init.normal_(self.token_emb.weight, std = 0.02)
        nn.init.normal_(self.pos_emb.weight, std = 0.02)
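The isinstance check above is what hands num_memory_tokens to a FunnelEncoder passed in as attn_layers. A rough end-to-end sketch of that wiring (values are arbitrary, and this assumes the branch containing FunnelEncoder is installed):

# sketch only: funnel encoder as the attention layers of TransformerWrapper
import torch
from x_transformers.x_transformers import TransformerWrapper, FunnelEncoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 4,                                    # propagated to attn_layers.num_memory_tokens
    attn_layers = FunnelEncoder(dim = 512, depths = (2, 2, 2), heads = 8)
)

seq = torch.randint(0, 20000, (1, 1024))
logits = model(seq)                                           # (1, 1024, 20000)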