# Gated DeltaNet for Linear Attention

Recently, [Qwen3-Next](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) and [Kimi Linear](https://arxiv.org/abs/2510.26692) proposed hybrid transformers that combine full attention with attention alternatives that scale linearly, instead of quadratically, with the context length.

Both Qwen3-Next and Kimi Linear use a 3:1 ratio, meaning for every three transformer blocks employing the linear Gated DeltaNet variant, there's one block that uses full attention, as shown in the figure below.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/01.webp" alt="Qwen3-Next versus Kimi Linear" style="zoom:47%;" />



## Introduction and Overview

Gated DeltaNet is a linear attention variant that takes inspiration from recurrent neural networks and adds a gating mechanism introduced in the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper. In a sense, Gated DeltaNet is a DeltaNet with Mamba-style gating, and DeltaNet itself is a linear attention mechanism.

Kimi Linear replaces the linear attention mechanism of Qwen3-Next with the Kimi Delta Attention (KDA) mechanism, which is essentially a refinement of Gated DeltaNet. Whereas Qwen3-Next applies a scalar gate (one value per attention head) to control the memory decay rate, Kimi Linear uses channel-wise gating with one value per feature dimension. According to the authors, this gives finer-grained control over the memory, which in turn improves long-context reasoning.
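
To make the difference between the two gating granularities concrete, here is a minimal sketch (the tensor names and shapes are illustrative and not taken from either codebase): a per-head scalar decay broadcasts one value across all feature channels of a head, whereas a channel-wise decay assigns each feature channel its own value.

```python
import torch

batch_size, num_tokens, num_heads, head_dim = 2, 5, 4, 8
state_update = torch.randn(batch_size, num_tokens, num_heads, head_dim)

# Qwen3-Next-style scalar decay: one value per head (broadcast over head_dim)
scalar_decay = torch.sigmoid(torch.randn(batch_size, num_tokens, num_heads, 1))

# Kimi-Linear-style channel-wise decay: one value per feature dimension
channel_decay = torch.sigmoid(torch.randn(batch_size, num_tokens, num_heads, head_dim))

print((state_update * scalar_decay).shape)   # torch.Size([2, 5, 4, 8])
print((state_update * channel_decay).shape)  # torch.Size([2, 5, 4, 8])
```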

In addition, for the full attention layers, Kimi Linear replaces Qwen3-Next's gated attention layers (which are essentially standard multi-head attention layers with output gating) with Multi-Head Latent Attention (MLA). This is the same MLA mechanism we discussed earlier in the DeepSeek V3/R1 section, but with an additional gate. (To recap, MLA compresses the key/value space to reduce the KV cache size.)
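
As a rough reminder of the compression idea (the module names and sizes below are illustrative, and RoPE-related details are omitted; see [../05_mla](../05_mla) for the full implementation):

```python
import torch
from torch import nn

d_model, d_latent = 2048, 512  # illustrative sizes

W_down_kv = nn.Linear(d_model, d_latent, bias=False)  # compress into a small latent
W_up_k = nn.Linear(d_latent, d_model, bias=False)     # expand latent back to keys
W_up_v = nn.Linear(d_latent, d_model, bias=False)     # expand latent back to values

x = torch.randn(1, 10, d_model)
kv_latent = W_down_kv(x)   # only this (1, 10, 512) tensor needs to be cached
keys, values = W_up_k(kv_latent), W_up_v(kv_latent)
```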

The MLA in Kimi Linear does not use the gate; this was intentional so that the authors could compare the architecture more directly to standard MLA. However, they [stated](https://x.com/yzhang_cs/status/1984631714464088563) that they plan to add it in the future.

Since we already implemented MLA in [../05_mla](../05_mla), this bonus material focuses on the Gated DeltaNet aspect.



## Gated Attention

Before we get to the Gated DeltaNet itself, let's briefly talk about the gate. As you can see in the upper part of the Qwen3-Next architecture in the previous figure, Qwen3-Next uses "gated attention". This is essentially regular full attention with an additional sigmoid gate.

This gating is a simple modification that I added to the `MultiHeadAttention` code from chapter 3 below for illustration purposes:

```python
import torch
from torch import nn

class GatedMultiHeadAttention(nn.Module):
    def __init__(
        self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert d_out % num_heads == 0

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        ### NEW: Add gate
        self.W_gate = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        ####################################################
        ### NEW: Add gate
        gate = self.W_gate(x)
        ####################################################
        keys = self.W_key(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(
            mask_bool, torch.finfo(attn_scores.dtype).min
        )

        attn_weights = torch.softmax(
            attn_scores / (self.head_dim ** 0.5), dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context = (attn_weights @ values).transpose(1, 2)
        context = context.reshape(b, num_tokens, self.d_out)

        ####################################################
        ### NEW: Add gate
        context = context * torch.sigmoid(gate)
        ####################################################
        out = self.out_proj(context)
        return out
```


As we can see, after computing attention as usual, the model uses a separate gating signal derived from the same input, applies a sigmoid to keep it between 0 and 1, and multiplies it with the attention output. This allows the model to dynamically scale individual features up or down. The Qwen3-Next developers [state](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list) that this helps with training stability:

> [...] the attention output gating mechanism helps eliminate issues like Attention Sink and Massive Activation, ensuring numerical stability across the model.
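
As a quick sanity check, the class above can be used as a drop-in replacement for the chapter 3 `MultiHeadAttention` (the sizes below are arbitrary):

```python
torch.manual_seed(123)
gated_mha = GatedMultiHeadAttention(
    d_in=64, d_out=64, context_length=128, dropout=0.0, num_heads=4
)
x = torch.randn(2, 10, 64)   # (batch_size, num_tokens, d_in)
print(gated_mha(x).shape)    # torch.Size([2, 10, 64])
```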



## Gated DeltaNet

Now, what is Gated DeltaNet? Gated DeltaNet (short for *Gated Delta Network*) is Qwen3-Next's linear-attention layer, which is intended as an alternative to standard softmax attention. It was adopted from the [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) paper, as mentioned earlier.

Gated DeltaNet was originally proposed as an improved version of Mamba2, combining the gated decay mechanism of Mamba2 with a delta rule.

Mamba is a state-space model (an alternative to transformers), a big topic that deserves separate coverage in the future.

The delta rule part refers to computing the difference (delta, Δ) between new and predicted values to update a hidden state that serves as a memory (more on that later).

(Side note: Readers familiar with the classic machine learning literature can think of this as similar to biologically inspired Hebbian learning: "Cells that fire together wire together." It's basically a precursor of the perceptron update rule and gradient descent-based learning, but without supervision.)
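
To make the delta rule concrete, here is a minimal single-head sketch without any gating (the tensor names are illustrative); it follows the same convention as the recurrent update implemented later in this section:

```python
import torch

torch.manual_seed(123)
d = 4                                # head dimension
S = torch.zeros(d, d)                # memory state (a "fast weight" matrix)
keys = torch.randn(10, d)
values = torch.randn(10, d)

for k_t, v_t in zip(keys, values):
    v_pred = k_t @ S                 # what the memory currently predicts for this key
    delta = v_t - v_pred             # prediction error (the "delta")
    S = S + torch.outer(k_t, delta)  # write the correction back into the memory
```

Gated DeltaNet adds the decay gate α and the update gate β on top of this basic update, as shown in the code further below.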

Gated DeltaNet has a gate similar to the gate in gated attention discussed earlier, except that it uses a SiLU instead of a logistic sigmoid activation, as illustrated below. (The SiLU choice is likely to improve gradient flow and stability over the standard sigmoid.)

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/02.webp" alt="Gated DeltaNet" style="zoom:47%;" />

However, as shown in the figure above, the "gated" in the Gated DeltaNet also refers to several additional gates:

- `α` (decay gate) controls how fast the memory decays or resets over time,
- `β` (update gate) controls how strongly new inputs modify the state.

In code, a simplified version of the Gated DeltaNet depicted above (without the convolutional mixing) can be implemented as follows (the code is inspired by the [official implementation](https://github.com/huggingface/transformers/blob/0ed6d51ae8ed3f4fafca67a983b8d75bc76cd51b/src/transformers/models/qwen3_next/modular_qwen3_next.py#L835) by the Qwen3 team):

```python
import torch
from torch import nn
import torch.nn.functional as F

def l2norm(x, dim=-1, eps=1e-6):
    return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)

class GatedDeltaNet(nn.Module):
    def __init__(
        self, d_in, d_out, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert d_out % num_heads == 0

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        ####################################################
        ### NEW: Gates for delta rule and output gating
        self.W_gate = nn.Linear(d_in, d_out, bias=False)
        self.W_beta = nn.Linear(d_in, d_out, bias=False)

        # Note: The decay gate alpha corresponds to
        # A_log + W_alpha(x) + dt_bias
        self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
        self.dt_bias = nn.Parameter(torch.ones(num_heads))
        self.A_log = nn.Parameter(torch.zeros(num_heads))
        # We could implement this as
        # W_alpha = nn.Linear(d_in, num_heads, bias=True)
        # but the bias is separate for interpretability and
        # to mimic the official implementation

        self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)
        ####################################################

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        ####################################################
        ### NEW: Compute delta rule gates
        beta = torch.sigmoid(self.W_beta(x))
        alpha = -self.A_log.exp().view(1, 1, -1) * F.softplus(
            self.W_alpha(x) + self.dt_bias
        )
        gate = self.W_gate(x)
        ####################################################

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        beta = beta.view(b, num_tokens, self.num_heads, self.head_dim)
        gate = gate.view(b, num_tokens, self.num_heads, self.head_dim)  # NEW

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        beta = beta.transpose(1, 2)
        gate = gate.transpose(1, 2)  # NEW

        ####################################################
        ### NEW: QKNorm-like normalization for delta rule
        queries = l2norm(queries, dim=-1) / (self.head_dim ** 0.5)
        keys = l2norm(keys, dim=-1)
        ####################################################

        S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)

        outs = []
        ####################################################
        ### NEW: Gated delta rule update
        for t in range(num_tokens):
            k_t = keys[:, :, t]
            q_t = queries[:, :, t]
            v_t = values[:, :, t]
            b_t = beta[:, :, t]
            a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

            S = S * a_t.exp()
            kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
            delta = (v_t - kv_mem) * b_t
            S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
            y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
            ####################################################
            outs.append(y_t)

        context = torch.stack(outs, dim=2).transpose(1, 2).contiguous()
        context = context.view(b, num_tokens, self.num_heads, self.head_dim)

        ####################################################
        ### NEW: Apply RMSNorm and SiLU gate
        context = self.norm(context)
        context = context * F.silu(gate)
        ####################################################

        context = context.view(b, num_tokens, self.d_out)
        context = self.dropout(context)
        out = self.out_proj(context)
        return out
```

(Note that, for simplicity, I omitted the convolutional mixing that Qwen3-Next and Kimi Linear use; this keeps the code more readable and focused on the recurrent aspects.)
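
For completeness, a short causal depthwise convolution along the token dimension is roughly what such a mixing step looks like. The sketch below is a minimal stand-in, not the official implementation; the kernel size, activation, and placement are illustrative assumptions:

```python
import torch
from torch import nn
import torch.nn.functional as F

class CausalDepthwiseConv1d(nn.Module):
    """Mixes each channel with a short window of its own past values."""
    def __init__(self, channels, kernel_size=4):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            channels, channels, kernel_size, groups=channels, bias=False
        )

    def forward(self, x):                        # x: (batch, num_tokens, channels)
        x = x.transpose(1, 2)                    # -> (batch, channels, num_tokens)
        x = F.pad(x, (self.kernel_size - 1, 0))  # left-pad so the conv stays causal
        x = self.conv(x)
        return F.silu(x).transpose(1, 2)         # -> (batch, num_tokens, channels)
```

In the full architectures, the query, key, and value projections would pass through such a convolution before the recurrent update, which gives each token a bit of local context even inside the linear-attention path.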

So, as we can see above, there are many differences from standard (or gated) attention.

In gated attention, the model computes regular attention between all tokens (every token attends to, or looks at, every other token). Then, after getting the attention output, a gate (a sigmoid) decides how much of that output to keep. The takeaway is that it's still the regular scaled dot-product attention that scales quadratically with the context length.

As a refresher, scaled dot-product attention is computed as softmax(QKᵀ)V, where Q and K are *n*-by-*d* matrices, with *n* being the number of input tokens and *d* the embedding dimension. So QKᵀ results in an *n*-by-*n* attention matrix, which is then multiplied by an *n*-by-*d* value matrix V:

```python
attn_scores = queries @ keys.transpose(2, 3)

mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(
    mask_bool, torch.finfo(attn_scores.dtype).min
)

attn_weights = torch.softmax(
    attn_scores / (self.head_dim ** 0.5), dim=-1
)

context = (attn_weights @ values).transpose(1, 2)
context = context.reshape(b, num_tokens, self.d_out)
```


<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/03.webp" alt="Quadratic attention" style="zoom:67%;" />

In Gated DeltaNet, there's no *n*-by-*n* attention matrix. Instead, the model processes tokens one by one. It keeps a running memory (a state) that gets updated as each new token comes in. This is implemented in the loop below, where `S` is the state that gets updated recurrently at each time step *t*:

```python
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
outs = []

for t in range(num_tokens):
    k_t = keys[:, :, t]
    q_t = queries[:, :, t]
    v_t = values[:, :, t]
    b_t = beta[:, :, t]
    a_t = alpha[:, t].unsqueeze(-1).unsqueeze(-1)

    S = S * a_t.exp()
    kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)
    delta = (v_t - kv_mem) * b_t
    S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
    y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)
```
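
If the broadcast-and-sum formulation is hard to parse, the body of the for-loop can be written equivalently with `torch.einsum` (this reformulation is mine and not part of the original code; it operates on the same `S`, `k_t`, `q_t`, `v_t`, `b_t`, and `a_t` tensors as above):

```python
S = S * a_t.exp()                                  # decay the old memory
kv_mem = torch.einsum("bhi,bhij->bhj", k_t, S)     # value the memory predicts for key k_t
delta = (v_t - kv_mem) * b_t                       # gated prediction error
S = S + torch.einsum("bhi,bhj->bhij", k_t, delta)  # rank-1 (outer product) memory update
y_t = torch.einsum("bhi,bhij->bhj", q_t, S)        # read the memory out with the query
```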

And the gates control how that memory changes:

- α (`alpha`) regulates how much of the old memory to forget (decay).

- β (`beta`) regulates how much the current token at time step *t* updates the memory.

(And the final output gate, not shown in the snippet above, is similar to the one in gated attention; it controls how much of the output is kept.)
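
A quick numeric sketch helps build intuition for the parameterization used in the code above (the input values here are arbitrary): β comes from a sigmoid and therefore lies in (0, 1), while α is non-positive, so the per-step retention factor `exp(α)` also lies in (0, 1]:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 5)

beta = torch.sigmoid(x)        # update strength, in (0, 1)
alpha = -1.0 * F.softplus(x)   # A_log.exp() set to 1 here for illustration
retention = alpha.exp()        # how much of the old memory survives one step, in (0, 1]

print(beta)       # tensor([0.0474, 0.1824, 0.5000, 0.8176, 0.9526])
print(retention)  # tensor([0.9526, 0.8176, 0.5000, 0.1824, 0.0474])
```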

So, in a sense, this state update in Gated DeltaNet is similar to how recurrent neural networks (RNNs) work. The advantage is that it scales linearly (via the for-loop) instead of quadratically with the context length.

The downside of this recurrent state update is that, compared to regular (or gated) attention, it sacrifices the global context modeling ability that comes from full pairwise attention.

Gated DeltaNet can, to some extent, still capture context, but it has to squeeze it through the memory (*S*) bottleneck. That memory has a fixed size and is thus more efficient, but it compresses past context into a single hidden state, similar to RNNs.

That's why the Qwen3-Next and Kimi Linear architectures don't replace all attention layers with DeltaNet layers but use the 3:1 ratio mentioned earlier.


## DeltaNet Memory Savings

In the previous section, we discussed the advantage of DeltaNet over full attention in terms of linear instead of quadratic compute complexity with respect to the context length.

Besides the linear compute complexity, another big advantage of DeltaNet is the memory savings, as DeltaNet modules don't grow the KV cache. (For more information about KV caching, see [../03_kv-cache](../03_kv-cache).) Instead, as mentioned earlier, they keep a fixed-size recurrent state, so memory stays constant with the context length.

For a regular multi-head attention (MHA) layer, we can compute the KV cache size as follows:

```
KV_cache_MHA ≈ batch_size × n_tokens × n_heads × d_head × 2 × bytes
```

(The 2 multiplier is there because we store both keys and values in the cache.)
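
To put illustrative numbers on this formula, the small helper below evaluates it for a single layer, assuming bf16 (2 bytes per element); the head count and head dimension match the script arguments used further below, and the 32k-token context length is an arbitrary example:

```python
def kv_cache_mha_bytes(batch_size, n_tokens, n_heads, d_head, bytes_per_elem=2):
    # factor 2 because both keys and values are cached
    return batch_size * n_tokens * n_heads * d_head * 2 * bytes_per_elem

print(kv_cache_mha_bytes(1, 32_768, 16, 128) / 1024**2)  # 256.0 (MiB per layer)
```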

For the simplified DeltaNet version implemented above, we have:

```
KV_cache_DeltaNet = batch_size × n_heads × d_head × d_head × bytes
```

Note that the `KV_cache_DeltaNet` memory size doesn't depend on the context length (`n_tokens`). Also, we only store the memory state S instead of separate keys and values, hence `2 × bytes` becomes just `bytes`. However, note that we now have a quadratic `d_head × d_head` term. This comes from the state:

```
S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)
```

But that's typically nothing to worry about, as the head dimension is relatively small. For instance, it's 128 in Qwen3-Next.
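
Continuing the illustrative calculation from above (16 heads, head dimension 128, bf16), the fixed-size DeltaNet state is tiny in comparison:

```python
def state_deltanet_bytes(batch_size, n_heads, d_head, bytes_per_elem=2):
    # fixed-size state S per head; no dependency on the context length
    return batch_size * n_heads * d_head * d_head * bytes_per_elem

print(state_deltanet_bytes(1, 16, 128) / 1024**2)  # 0.5 (MiB per layer)
```

That is roughly 512× less than the ~256 MiB per layer that the MHA KV cache needs at a 32k-token context in the earlier example, and the gap keeps growing as the context gets longer.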

The full version with the convolutional mixing is a bit more complex (it additionally depends on the convolution kernel size and so on), but the formulas above should illustrate the main trend and the motivation behind Gated DeltaNet.

We can visualize the memory estimates and savings for different context lengths via the following helper script:

```bash
uv run plot_memory_estimates_gated_deltanet.py \
    --emb_dim 2048 \
    --n_heads 16 \
    --n_layers 48 \
    --dtype "bf16"
```

Note that the script above computes the `head_dim` as `emb_dim / n_heads`, i.e., 2048 / 16 = 128.

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gated_deltanet/plot.webp" alt="Gated DeltaNet scaling" style="zoom:47%;" />