Commit 9e4f870

add another solution, shared by gpt-oss, to the attention-able-to-noop issue
1 parent bb68a8c commit 9e4f870

File tree

5 files changed, +51 -5 lines changed

README.md

Lines changed: 10 additions & 0 deletions

````diff
@@ -2459,4 +2459,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{openai_gpt_oss,
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
````

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.3"
+version = "2.6.4"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
```

tests/test_x_transformers.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -1235,3 +1235,20 @@ def test_external_key_values():
     additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
 
     logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
+
+def test_learned_head_attn_sink():
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_head_learned_sink = True
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    logits = model(seq)
```

x_transformers/attend.py

Lines changed: 21 additions & 4 deletions

```diff
@@ -4,8 +4,8 @@
 from typing import Tuple, Callable
 
 import torch
-from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch.nn import Module, Parameter
+from torch import cat, nn, einsum, Tensor
 import torch.nn.functional as F
 
 from collections import namedtuple
@@ -176,6 +176,7 @@ def __init__(
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -254,6 +255,13 @@ def __init__(
 
         self.add_zero_kv = add_zero_kv
 
+        # learned sink concatted pre-softmax, working solution from gpt-oss
+
+        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
+
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+
         # soft clamp attention logit value
 
         if softclamp_logits:
@@ -315,10 +323,10 @@ def flash_attn(
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
             k = F.pad(k, (0, 1), value = -1.)
-            k = torch.cat((k, k_norm_sq), dim = -1)
+            k = cat((k, k_norm_sq), dim = -1)
 
             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = cat((2 * q, q_norm_sq), dim = -1)
             q = F.pad(q, (0, 1), value = -1.)
 
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
@@ -509,6 +517,11 @@ def forward(
         if self.selective:
             sim = selective_attn(sim)
 
+        if self.head_learned_sink:
+            # add learned attention sink
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            sim = cat((attn_sink, sim), dim = -1)
+
         pre_softmax_attn = sim
 
         attn = self.attn_fn(sim)
@@ -517,6 +530,10 @@ def forward(
 
         post_softmax_attn = attn
 
+        if self.head_learned_sink:
+            # remove attention sink
+            attn = attn[..., 1:]
+
         attn = self.attn_dropout(attn)
 
         if exists(self.post_softmax_talking_heads):
```
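The substance of the commit is the `attend.py` change above: a learned per-head logit is concatenated as an extra column of the pre-softmax attention matrix, and that column is sliced off again after the softmax, so each query's weights over the real tokens can sum to less than one and a head can effectively attend to nothing. Below is a minimal standalone sketch of that pattern, assuming plain (non-flash) scaled dot-product attention; the tensor sizes and the surrounding setup are illustrative, only the concat / softmax / slice steps mirror the diff.

```python
import torch
from torch import nn, einsum
from einops import repeat

batch, heads, seq_len, dim_head = 2, 8, 16, 64

q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)
v = torch.randn(batch, heads, seq_len, dim_head)

# one learned sink logit per head, initialized at zero as in the commit
head_attn_sink = nn.Parameter(torch.zeros(heads))

scale = dim_head ** -0.5
sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

# prepend the sink as an extra "key" column, so every query can dump
# probability mass into a slot that is not tied to any token
attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = batch, i = seq_len)
sim = torch.cat((attn_sink, sim), dim = -1)

attn = sim.softmax(dim = -1)

# drop the sink column after the softmax; each row now sums to <= 1 over
# real tokens, which is what lets a head no-op
attn = attn[..., 1:]

out = einsum('b h i j, b h j d -> b h i d', attn, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Since the sink column is removed before the weights multiply the values, no extra value vector is needed; the only new parameter is a single scalar per head.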

x_transformers/x_transformers.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -1319,6 +1319,7 @@ def __init__(
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
+        head_learned_sink = False,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1515,6 +1516,7 @@ def __init__(
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
+            head_learned_sink = head_learned_sink,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
```
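End to end, the new flag is exposed on `Attention.__init__` as `head_learned_sink` and simply forwarded into `Attend`; the new test above reaches it from the public API through `Decoder`'s `attn_`-prefixed kwargs. A usage sketch mirroring that test (per the assert added in `attend.py`, the flag cannot yet be combined with `flash = True`):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# enable the learned per-head attention sink on every self-attention layer
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_head_learned_sink = True  # forwarded to Attention(head_learned_sink = True)
    )
)

seq = torch.randint(0, 20000, (3, 1024))
logits = model(seq)  # (3, 1024, 20000)
```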

0 commit comments