Commit c6b8332

Gated DeltaNet write-up (#901)
* Gated DeltaNet write-up
* Add copyright and source information to script (added copyright notice and source information)
* Remove unused import of Path in plot_memory_estimates
* Fix url
1 parent d6c3990 commit c6b8332

File tree

5 files changed: +460 -0 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ ch04/04_gqa/kv_bytes_vs_context_length.pdf
ch04/05_mla/kv_bytes_vs_context_length.pdf
ch04/06_swa/kv_bytes_vs_context_length.pdf
ch04/07_moe/ffn_vs_moe.pdf
+ ch04/08_deltanet/deltanet_memory_plot.pdf

ch05/01_main-chapter-code/loss-plot.pdf
ch05/01_main-chapter-code/temperature-plot.pdf

@@ -29,6 +30,7 @@ ch07/01_main-chapter-code/loss-plot-baseline.pdf
ch07/01_main-chapter-code/loss-plot-mask-instructions.pdf
ch07/01_main-chapter-code/loss-plot-phi3-prompt.pdf
ch07/01_main-chapter-code/loss-plot-alpaca52k.pdf
+ ch07/04_preference-tuning-with-dpo/reward margins-plot.pdf

# Checkpoint files
appendix-A/01_main-chapter-code/model.pth

README.md

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Grouped-Query Attention](ch04/04_gqa)
- [Multi-Head Latent Attention](ch04/05_mla)
- [Sliding Window Attention](ch04/06_swa)
+ - [Gated DeltaNet](ch04/08_deltanet)
- [Mixture-of-Experts (MoE)](ch04/07_moe)
- **Chapter 5: Pretraining on unlabeled data:**
- [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)

ch04/08_deltanet/README.md

Lines changed: 356 additions & 0 deletions
@@ -0,0 +1,356 @@
# Gated DeltaNet for Linear Attention

Recently, [Qwen3-Next](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) and [Kimi Linear](https://arxiv.org/abs/2510.26692) proposed hybrid transformers that implement alternatives to the attention mechanism that scale linearly instead of quadratically with respect to the context length.

Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning that for every three transformer blocks employing the linear Gated DeltaNet variant, there is one block that uses full attention, as shown in the figure below.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/01.webp" alt="Qwen3-Next versus Kimi Linear" style="zoom:47%;" />
&nbsp;
## Introduction and Overview

Gated DeltaNet is a linear attention variant that draws inspiration from recurrent neural networks, including a gating mechanism from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet is a linear attention mechanism.

Kimi Linear modifies the linear attention mechanism of Qwen3-Next by replacing it with the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear uses a channel-wise gate with one value per feature dimension. According to the authors, this gives more control over the memory, which, in turn, improves long-context reasoning.
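To make that difference concrete, here is a minimal shape sketch (this is not the actual Qwen3-Next or Kimi Linear code; the layer names and dimensions below are made up purely for illustration):

```python
import torch
from torch import nn

d_in, num_heads, head_dim = 64, 4, 16

# Scalar decay gate in the style of Qwen3-Next: one value per attention head
W_alpha_scalar = nn.Linear(d_in, num_heads)

# Channel-wise decay gate in the style of Kimi Linear: one value per head and feature dimension
W_alpha_channelwise = nn.Linear(d_in, num_heads * head_dim)

x = torch.randn(2, 10, d_in)            # (batch_size, num_tokens, d_in)
print(W_alpha_scalar(x).shape)          # torch.Size([2, 10, 4])
print(W_alpha_channelwise(x).shape)     # torch.Size([2, 10, 64])
```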
In addition, for the full attention layers, Kimi Linear replaces Qwen3-Next's gated attention layers (which are essentially standard multi-head attention layers with output gating) with Multi-Head Latent Attention (MLA). This is the same MLA mechanism we discussed earlier in the DeepSeek V3/R1 section, but with an additional gate. (To recap, MLA compresses the key/value space to reduce the KV cache size.)

The MLA in Kimi Linear does not use the gate, which was intentional so that the authors could compare the architecture more directly to standard MLA; however, they [stated](https://x.com/yzhang_cs/status/1984631714464088563) that they plan to add it in the future.

Since we already implemented MLA in [../05_mla](../05_mla), this bonus material focuses on the Gated DeltaNet aspect.

&nbsp;
## Gated Attention

Before we get to the Gated DeltaNet itself, let's briefly talk about the gate. As you can see in the upper part of the Qwen3-Next architecture in the previous figure, Qwen3-Next uses "gated attention". This is essentially regular full attention with an additional sigmoid gate.

This gating is a simple modification that I added to the `MultiHeadAttention` code from chapter 3 below for illustration purposes:
```python
import torch
from torch import nn


class GatedMultiHeadAttention(nn.Module):
    def __init__(
        self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert d_out % num_heads == 0

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        ### NEW: Add gate
        self.W_gate = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        ####################################################
        ### NEW: Add gate
        gate = self.W_gate(x)
        ####################################################
        keys = self.W_key(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(
            mask_bool, torch.finfo(attn_scores.dtype).min
        )

        attn_weights = torch.softmax(
            attn_scores / (self.head_dim ** 0.5), dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context = (attn_weights @ values).transpose(1, 2)
        context = context.reshape(b, num_tokens, self.d_out)

        ####################################################
        ### NEW: Add gate
        context = context * torch.sigmoid(gate)
        ####################################################
        out = self.out_proj(context)
        return out
```
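For a quick sanity check, we can run the layer above on a small random input (the dimensions below are arbitrary and chosen only for illustration):

```python
torch.manual_seed(123)

gated_mha = GatedMultiHeadAttention(
    d_in=64, d_out=64, context_length=128, dropout=0.0, num_heads=4
)

x = torch.randn(2, 10, 64)  # (batch_size, num_tokens, d_in)
out = gated_mha(x)
print(out.shape)            # torch.Size([2, 10, 64])
```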
As we can see, after computing attention as usual, the model uses a separate gating signal from the same input, applies a sigmoid to keep it between 0 and 1, and multiplies it with the attention output. This allows the model to scale up or down certain features dynamically. The Qwen3-Next developers [state](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) that this helps with training stability:

> [...] the attention output gating mechanism helps eliminate issues like Attention Sink and Massive Activation, ensuring numerical stability across the model.
&nbsp;
## Gated DeltaNet

Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper, as mentioned earlier.

Gated DeltaNet was originally proposed as an improved version of Mamba2, combining the gated decay mechanism of Mamba2 with a delta rule.

Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.

The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that is used as a memory state (more on that later).

(Side note: Readers familiar with the classic machine learning literature can think of this as similar to biology-inspired Hebbian learning: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)
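To make the delta rule part more concrete, below is a tiny, self-contained sketch of a single (ungated) delta-rule update for one head. It uses the same conventions as the full implementation further below; the dimensions and values are made up:

```python
import torch

torch.manual_seed(123)

d = 4                            # head dimension
S = torch.zeros(d, d)            # memory state (one head)
k = torch.randn(d)               # key of the current token
v = torch.randn(d)               # value of the current token
beta = 0.5                       # update-gate strength

v_pred = S.T @ k                 # what the memory currently predicts for this key
delta = beta * (v - v_pred)      # difference between new and predicted value
S = S + torch.outer(k, delta)    # write the correction back into the memory
print(S)
```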
Gated DeltaNet has a gate similar to the gate in gated attention discussed earlier, except that it uses a SiLU instead of a logistic sigmoid activation, as illustrated below. (The SiLU choice is likely to improve gradient flow and stability over the standard sigmoid.)

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/02.webp" alt="Gated DeltaNet" style="zoom:47%;" />
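As a quick reminder of the difference, SiLU is simply `x * sigmoid(x)`, so, unlike the sigmoid gate in gated attention, the gate values are not confined to the range (0, 1):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.0, 2.0])
print(torch.sigmoid(x))        # tensor([0.1192, 0.5000, 0.8808])
print(F.silu(x))               # tensor([-0.2384, 0.0000, 1.7616])
print(x * torch.sigmoid(x))    # identical to F.silu(x)
```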
However, as shown in the figure above, the "gated" in the Gated DeltaNet also refers to several additional gates:

- `α` (decay gate) controls how fast the memory decays or resets over time,
- `β` (update gate) controls how strongly new inputs modify the state.

In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):
```python
import torch
from torch import nn
import torch.nn.functional as F


def l2norm(x, dim=-1, eps=1e-6):
    return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)


class GatedDeltaNet(nn.Module):
    def __init__(
        self, d_in, d_out, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert d_out % num_heads == 0

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        ### NEW: Gates for delta rule and output gating
        self.W_gate = nn.Linear(d_in, d_out, bias=False)
        self.W_beta = nn.Linear(d_in, d_out, bias=False)

        # Note: The decay gate alpha corresponds to
        # A_log + W_alpha(x) + dt_bias
        self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
        self.dt_bias = nn.Parameter(torch.ones(num_heads))
        self.A_log = nn.Parameter(torch.zeros(num_heads))
        # We could implement this as
        # W_alpha = nn.Linear(d_in, num_heads, bias=True)
        # but the bias is separate for interpretability and
        # to mimic the official implementation

        self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
        ####################################################

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        ####################################################
        ### NEW: Compute delta rule gates
        beta = torch.sigmoid(self.W_beta(x))
        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
            self.W_alpha(x) + self.dt_bias
        )
        gate = self.W_gate(x)
        ####################################################

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        beta = beta.view(b, num_tokens, self.num_heads, self.head_dim)
        gate = gate.view(b, num_tokens, self.num_heads, self.head_dim)  # NEW

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        beta = beta.transpose(1, 2)
        gate = gate.transpose(1, 2)  # NEW

        ####################################################
        ### NEW: QKNorm-like normalization for delta rule
        queries = l2norm(queries, dim=-1) / (self.head_dim ** 0.5)
        keys = l2norm(keys, dim=-1)
        ####################################################

        S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)

        outs = []
        ####################################################
        ### NEW: Gated delta rule update
        for t in range(num_tokens):
            k_t = keys[:, :, t]
            q_t = queries[:, :, t]
            v_t = values[:, :, t]
            b_t = beta[:, :, t]
            a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

            S = S * a_t.exp()
            kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
            delta = (v_t - kv_mem) * b_t
            S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
            y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
            ####################################################
            outs.append(y_t)

        context = torch.stack(outs, dim=2).transpose(1, 2).contiguous()
        context = context.view(b, num_tokens, self.num_heads, self.head_dim)

        ####################################################
        ### NEW: Apply RMSNorm and SiLU gate
        context = self.norm(context)
        # Transpose the gate back to (b, num_tokens, num_heads, head_dim)
        # so that its shape matches the context tensor
        context = context * F.silu(gate.transpose(1, 2))
        ####################################################

        context = context.view(b, num_tokens, self.d_out)
        context = self.dropout(context)
        out = self.out_proj(context)
        return out
```
(Note that, to keep the code more readable and to focus on the recurrent aspects, I omitted the convolutional mixing that Qwen3-Next and Kimi Linear use.)
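As before, we can do a quick sanity check by running the module above on a small random input (the dimensions are again arbitrary and only for illustration):

```python
torch.manual_seed(123)

deltanet = GatedDeltaNet(d_in=64, d_out=64, dropout=0.0, num_heads=4)

x = torch.randn(2, 10, 64)  # (batch_size, num_tokens, d_in)
out = deltanet(x)
print(out.shape)            # torch.Size([2, 10, 64])
```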
So, as we can see above, there are many differences from standard (or gated) attention.

In gated attention, the model computes normal attention between all tokens (every token attends to, or looks at, every other token). Then, after getting the attention output, a gate (a sigmoid) decides how much of that output to keep. The takeaway is that it's still the regular scaled dot-product attention that scales quadratically with the context length.

As a refresher, scaled dot-product attention is computed as softmax(QKᵀ/√d)V, where Q and K are *n*-by-*d* matrices, *n* is the number of input tokens, and *d* is the embedding dimension. So QKᵀ results in an *n*-by-*n* attention matrix, which is multiplied by an *n*-by-*d* value matrix V:
```python
attn_scores = queries @ keys.transpose(2, 3)

mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(
    mask_bool, torch.finfo(attn_scores.dtype).min
)

attn_weights = torch.softmax(
    attn_scores / (self.head_dim ** 0.5), dim=-1
)

context = (attn_weights @ values).transpose(1, 2)
context = context.reshape(b, num_tokens, self.d_out)
```
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/03.webp" alt="Quadratic attention" style="zoom:67%;" />

In Gated DeltaNet, there's no *n*-by-*n* attention matrix. Instead, the model processes tokens one by one. It keeps a running memory (a state) that gets updated as each new token comes in. This is implemented as follows, where `S` is the state that gets updated recurrently at each time step *t*:
```python
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
outs = []

for t in range(num_tokens):
    k_t = keys[:, :, t]
    q_t = queries[:, :, t]
    v_t = values[:, :, t]
    b_t = beta[:, :, t]
    a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

    S = S * a_t.exp()
    kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
    delta = (v_t - kv_mem) * b_t
    S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
    y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
```
And the gates control how that memory changes:

- α (`alpha`) regulates how much of the old memory to forget (decay).
- β (`beta`) regulates how much the current token at time step *t* updates the memory.

(And the final output gate, not shown in the snippet above, is similar to the one in gated attention; it controls how much of the output is kept.)
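One compact way to summarize the loop body per head (this is just a reformulation of the code above, not a formula taken verbatim from the paper) is as follows, where ⊙ denotes element-wise multiplication, the decayed state is written as S̃, and α_t ≤ 0 so that e^{α_t} lies in (0, 1]:

$$
\begin{aligned}
\tilde{S} &= e^{\alpha_t}\, S_{t-1} \\
S_t &= \tilde{S} + k_t \big(\beta_t \odot (v_t - \tilde{S}^\top k_t)\big)^\top \\
y_t &= S_t^\top q_t
\end{aligned}
$$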
So, in a sense, this state update in Gated DeltaNet is similar to how recurrent neural networks (RNNs) work. The advantage is that it scales linearly (via the for-loop) instead of quadratically with the context length.

The downside of this recurrent state update is that, compared to regular (or gated) attention, it sacrifices the global context modeling ability that comes from full pairwise attention.

Gated DeltaNet can, to some extent, still capture context, but it has to go through the memory (*S*) bottleneck. That memory has a fixed size and is thus more efficient, but it compresses past context into a single hidden state, similar to RNNs.

That's why the Qwen3-Next and Kimi Linear architectures don't replace all attention layers with DeltaNet layers but use the 3:1 ratio mentioned earlier.
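For illustration only (this is not how either model actually configures its blocks), such a 3:1 layout could be expressed as a simple list of layer types:

```python
n_layers = 12

# Every fourth block uses full attention; the rest use the linear Gated DeltaNet variant
layer_types = [
    "full_attention" if (i + 1) % 4 == 0 else "gated_deltanet"
    for i in range(n_layers)
]
print(layer_types)
```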
&nbsp;
## DeltaNet Memory Savings

In the previous section, we discussed the advantage of DeltaNet over full attention in terms of linear instead of quadratic compute complexity with respect to the context length.

Besides the linear compute complexity, another big advantage of DeltaNet is the memory savings: DeltaNet modules don't grow the KV cache. (For more information about KV caching, see [../03_kv-cache](../03_kv-cache).) Instead, as mentioned earlier, they keep a fixed-size recurrent state, so memory stays constant with the context length.

For a regular multi-head attention (MHA) layer, we can compute the KV cache size as follows:
```
KV_cache_MHA ≈ batch_size × n_tokens × n_heads × d_head × 2 × bytes
```
(The 2 multiplier is there because we store both keys and values in the cache.)

For the simplified DeltaNet version implemented above, we have:

```
KV_cache_DeltaNet = batch_size × n_heads × d_head × d_head × bytes
```
Note that the `KV_cache_DeltaNet` memory size doesn't have a context length (`n_tokens`) dependency. Also, we only store the memory state S instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `d_head × d_head` term in here. This comes from the state:

```
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
```

But that's usually nothing to worry about, as the head dimension is relatively small. For instance, it's 128 in Qwen3-Next.
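To put some illustrative numbers on the two formulas above, here is a quick back-of-the-envelope calculation per layer, assuming bf16 (2 bytes per element), 16 heads, a head dimension of 128, and an example context length of 131,072 tokens (these values are only chosen for illustration):

```python
batch_size, n_tokens, n_heads, d_head, bytes_per_elem = 1, 131_072, 16, 128, 2

kv_cache_mha = batch_size * n_tokens * n_heads * d_head * 2 * bytes_per_elem
state_deltanet = batch_size * n_heads * d_head * d_head * bytes_per_elem

print(f"MHA KV cache per layer:   {kv_cache_mha / 1024**2:,.1f} MB")    # 1,024.0 MB
print(f"DeltaNet state per layer: {state_deltanet / 1024**2:,.1f} MB")  # 0.5 MB
```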
The full version with the convolutional mixing is a bit more complex (it also involves the kernel size and so on), but the formulas above should illustrate the main trend and motivation behind the Gated DeltaNet.

We can visualize the memory estimates and savings for different context lengths via the following helper script:
```bash
uv run plot_memory_estimates_gated_deltanet.py \
    --emb_dim 2048 \
    --n_heads 16 \
    --n_layers 48 \
    --dtype "bf16"
```
Note that the script above computes the `head_dim` as `emb_dim / n_heads`, i.e., 2048 / 16 = 128.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/plot.webp" alt="Gated DeltaNet scaling" style="zoom:47%;" />
