
Commit 21e83d8

add funnel transformer
1 parent f91d9a2 commit 21e83d8

4 files changed: +98 -4 lines changed

README.md

Lines changed: 22 additions & 0 deletions
@@ -155,6 +155,28 @@ caption = torch.randint(0, 20000, (1, 1024))
 encoded = encoder(img, return_embeddings = True)
 decoder(caption, context = encoded) # (1, 1024, 20000)
 ```
+
+Funnel Transformer for Encoder
+
+```python
+import torch
+from x_transformers import TransformerWrapper, FunnelEncoder
+
+model = TransformerWrapper(
+    num_tokens = 20000,
+    max_seq_len = 1024,
+    num_memory_tokens = 10,
+    attn_layers = FunnelEncoder(
+        dim = 512,
+        heads = 8,
+        depths = (4, 4, 4)
+    )
+)
+
+x = torch.randint(1, 20000, (1, 1024))
+model(x) # (1, 1024, 20000)
+```
+
 
 ## Citations
 
 ```bibtex

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.21',
+  version = '0.0.22',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',

x_transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 from x_transformers.x_transformers import XTransformer, Encoder, Decoder, TransformerWrapper, ViTransformerWrapper
+from x_transformers.x_transformers import FunnelEncoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

x_transformers/x_transformers.py

Lines changed: 74 additions & 3 deletions
@@ -5,7 +5,7 @@
 from functools import partial
 from inspect import isfunction
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from entmax import entmax15
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -20,6 +20,11 @@ def default(val, d):
         return val
     return d() if isfunction(d) else d
 
+def residualize(f):
+    def fn(x, *args, **kwargs):
+        return f(x, *args, **kwargs) + x
+    return fn
+
 # keyword argument helpers
 
 def pick_and_pop(keys, d):
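
As a quick aside on the `residualize` helper added above: it wraps a callable so that its input is added back to its output, i.e. a plain skip connection. A minimal standalone sketch (the `nn.Linear` and shapes are only illustrative):

```python
import torch
import torch.nn as nn

def residualize(f):                        # same definition as in the diff above
    def fn(x, *args, **kwargs):
        return f(x, *args, **kwargs) + x
    return fn

layer = residualize(nn.Linear(512, 512))  # layer(x) == linear(x) + x
x = torch.randn(1, 4, 512)
print(layer(x).shape)                      # torch.Size([1, 4, 512])
```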
@@ -234,8 +239,19 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
+class AttentionWithDownsample(Attention):
+    def forward(self, x, num_memory_tokens = 0, downsample = False, **kwargs):
+        if downsample:
+            b, n, *_ = x.shape
+            mem, x = x[:, :num_memory_tokens], x[:, num_memory_tokens:]
+            x, remainder = (x[:, :-1], x[:, -1:]) if (n % 2) == 1 else (x[:, :], x[:, 0:0])
+            x = reduce(x, 'b (n c) d -> b n d', 'mean', c = 2)
+            x = torch.cat((mem, x, remainder), dim = 1)
+
+        return super().forward(x, **kwargs)
+
 class Encoder(nn.Module):
-    def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
+    def __init__(self, dim, depth, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
         super().__init__()
         self.dim = dim
         self.layers = nn.ModuleList([])
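
For intuition, here is a minimal standalone sketch of the pooling step that `AttentionWithDownsample` performs before calling the parent `Attention.forward`: memory tokens are split off and left untouched, an odd trailing token is set aside, and the remaining tokens are averaged in adjacent pairs, halving the sequence. The tensor sizes below are only illustrative.

```python
import torch
from einops import reduce

num_memory_tokens = 2
x = torch.randn(1, 13, 512)                   # (batch, seq, dim): 2 memory tokens + 11 regular tokens
b, n, *_ = x.shape

mem, x = x[:, :num_memory_tokens], x[:, num_memory_tokens:]
# as in the code above, an odd length keeps its last token out of the pooling
x, remainder = (x[:, :-1], x[:, -1:]) if (n % 2) == 1 else (x[:, :], x[:, 0:0])
x = reduce(x, 'b (n c) d -> b n d', 'mean', c = 2)    # average adjacent pairs: 10 tokens -> 5
x = torch.cat((mem, x, remainder), dim = 1)

print(x.shape)                                # torch.Size([1, 8, 512]) -- 2 memory + 5 pooled + 1 remainder
```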
@@ -249,7 +265,7 @@ def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False,
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads, **attn_kwargs)),
+                prenorm_fn(Attention(dim, heads = heads, **attn_kwargs)),
                 prenorm_fn(FeedForward(dim, **ff_kwargs))
             ]))
     def forward(self, x, context = None, mask = None):
@@ -258,6 +274,57 @@ def forward(self, x, context = None, mask = None):
             x = ff(x) + x
         return x
 
+class FunnelEncoder(nn.Module):
+    def __init__(self, dim, depths, heads = 8, use_scalenorm = False, rel_pos_bias = False, num_memory_tokens = None, **kwargs):
+        super().__init__()
+        assert isinstance(depths, tuple), 'depths must be a tuple, where each element specifies the number of layers before the next bottleneck'
+        assert len(depths) > 1, 'there must be at least 1 bottleneck'
+
+        self.dim = dim
+        self.num_memory_tokens = num_memory_tokens
+        self.rel_pos = RelativePositionBias() if rel_pos_bias else None
+        self.bottlenecks = nn.ModuleList([])
+
+        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+        prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)
+
+        ff_kwargs, kwargs = group_by_key_prefix_and_trim('ff_', kwargs)
+        attn_kwargs, _ = group_by_key_prefix_and_trim('attn_', kwargs)
+
+        for depth in depths:
+            layers = nn.ModuleList([])
+            for _ in range(depth):
+                layers.append(nn.ModuleList([
+                    prenorm_fn(AttentionWithDownsample(dim, heads = heads, **attn_kwargs)),
+                    prenorm_fn(FeedForward(dim, **ff_kwargs))
+                ]))
+            self.bottlenecks.append(layers)
+
+    def forward(self, x, context = None, mask = None):
+        n = x.shape[1]
+        num_mem = self.num_memory_tokens
+        num_downsamples = len(self.bottlenecks)
+
+        for layer_ind, layers in enumerate(self.bottlenecks):
+            if layer_ind == 1:
+                res = x
+
+            for ind, (self_attn, ff) in enumerate(layers):
+                downsample = layer_ind != 0 and ind == 0
+                self_attn = residualize(self_attn) if not downsample else self_attn
+
+                x = self_attn(x, mask = mask, rel_pos = self.rel_pos, downsample = downsample, num_memory_tokens = num_mem)
+                x = ff(x) + x
+
+        mem, x = x[:, :num_mem], x[:, num_mem:]
+        # upsample by repeating tokens as specified in paper
+        x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))
+        # curtail any excessive tokens
+        x = x[:, :(n - num_mem)]
+        x = torch.cat((mem, x), dim = 1)
+        # add to residual before start of first downsample
+        return x + res
+
 class Decoder(nn.Module):
     def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):
         super().__init__()
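
Likewise, a minimal sketch of the upsample-and-residual step at the end of `FunnelEncoder.forward` above: each pooled token is repeated `2 ** (num_downsamples - 1)` times, the sequence is trimmed back to its original length, the memory tokens are re-attached, and the residual captured just before the first downsample is added. The shapes below assume the README example (10 memory tokens, sequence length 1024, three bottleneck stages) and are only illustrative.

```python
import torch
from einops import repeat

num_mem, num_downsamples = 10, 3
n = num_mem + 1024                            # length originally fed into forward, memory tokens included
res = torch.randn(1, n, 512)                  # residual captured just before the first downsample
x = torch.randn(1, num_mem + 256, 512)        # after two halvings of the 1024 regular tokens

mem, x = x[:, :num_mem], x[:, num_mem:]
x = repeat(x, 'b n d -> b (n m) d', m = 2 ** (num_downsamples - 1))   # 256 -> 1024
x = x[:, :(n - num_mem)]                      # curtail any excess from odd-length pooling
x = torch.cat((mem, x), dim = 1)

print((x + res).shape)                        # torch.Size([1, 1034, 512])
```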
@@ -361,6 +428,10 @@ def __init__(
         if num_memory_tokens > 0:
             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
 
+        # let funnel encoder know number of memory tokens, if specified
+        if isinstance(attn_layers, FunnelEncoder):
+            attn_layers.num_memory_tokens = num_memory_tokens
+
     def init_(self):
         nn.init.normal_(self.token_emb.weight, std = 0.02)
         nn.init.normal_(self.pos_emb.weight, std = 0.02)
