
Commit cf4035d

create a cross attention only attention layer (CrossAttender)
1 parent 7b80b96

File tree

4 files changed: +9 −11 lines

README.md
setup.py
x_transformers/__init__.py
x_transformers/x_transformers.py

README.md

Lines changed: 3 additions & 9 deletions

````diff
@@ -519,22 +519,16 @@ Cross Attention
 
 ```python
 import torch
-from x_transformers import Encoder
+from x_transformers import Encoder, CrossAttender
 
 enc = Encoder(dim = 512, depth = 6)
-
-cross_attn = Encoder(
-    dim = 512,
-    depth = 6,
-    cross_attend = True,
-    only_cross = True
-)
+model = CrossAttender(dim = 512, depth = 6)
 
 nodes = torch.randn(1, 1, 512)
 neighbors = torch.randn(1, 5, 512)
 
 encoded_neighbors = enc(neighbors)
-cross_attn(nodes, context = encoded_neighbors) # (1, 1, 512)
+model(nodes, context = encoded_neighbors) # (1, 1, 512)
 ```
 
 ## Citations
````

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.3',
+  version = '0.3.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
```

x_transformers/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,3 +1,3 @@
-from x_transformers.x_transformers import XTransformer, Encoder, Decoder, TransformerWrapper, ViTransformerWrapper
+from x_transformers.x_transformers import XTransformer, Encoder, Decoder, CrossAttender, TransformerWrapper, ViTransformerWrapper
 from x_transformers.funnel import FunnelEncoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```

x_transformers/x_transformers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ def __init__(self, **kwargs):
353353
assert 'causal' not in kwargs, 'cannot set causality on decoder'
354354
super().__init__(causal = True, **kwargs)
355355

356+
class CrossAttender(AttentionLayers):
357+
def __init__(self, **kwargs):
358+
super().__init__(cross_attend = True, only_cross = True, **kwargs)
359+
356360
class ViTransformerWrapper(nn.Module):
357361
def __init__(
358362
self,
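
The new class is thin sugar over AttentionLayers: it pins cross_attend = True and only_cross = True and forwards every other keyword argument. A minimal sketch of the equivalence with the pre-commit Encoder-based spelling (both forms appear in the README diff above; shapes mirror that example):

```python
import torch
from x_transformers import Encoder, CrossAttender

# pre-commit spelling: an Encoder whose layers only cross-attend
old_style = Encoder(dim = 512, depth = 6, cross_attend = True, only_cross = True)

# post-commit spelling: the dedicated subclass with the same flags baked in
new_style = CrossAttender(dim = 512, depth = 6)

nodes = torch.randn(1, 1, 512)      # query tokens
neighbors = torch.randn(1, 5, 512)  # context the queries attend over

# both stacks map (1, 1, 512) queries to (1, 1, 512) outputs;
# the values differ because the two modules have independent weights
print(old_style(nodes, context = neighbors).shape)  # torch.Size([1, 1, 512])
print(new_style(nodes, context = neighbors).shape)  # torch.Size([1, 1, 512])
```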
