A repository with optimized kernels for RWKV language models. Currently focused on RWKV-7.
The kernels were timed using tests/speed_test.py with model dimension 4096 and varying (batch size, head size, sequence length) as labeled in the tables below; a sketch of how such timings can be collected follows the tables.
Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM¹ | Typical error |
---|---|---|---|---|---|---|
Chunked bf16 | 8 ms | 11 ms | 54 ms | 224 ms | 5 - 8 GB | 5e-3 |
Backstepping fp32 longhead | 23 ms | 46 ms | 80 ms | 124 ms | 8 - 14 GB | 9e-5 |
Backstepping fp32 smallhead | 17 ms | 101 ms | 862 ms | 1802 ms | 7 - 13 GB | 9e-5 |
Triton bighead fp32 | 66 ms | 87 ms | 168 ms | 1175 ms | 6 - 12 GB | 5e-5 |
Triton bighead bf16 | ² | 29 ms | 59 ms | 358 ms | 6 - 12 GB | 5e-3 |
FLA chunk_rwkv7 | 64 ms | 62 ms | 89 ms | 93 ms | 12 - 13 GB | 4e-3 |
Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM¹ | Typical error |
---|---|---|---|---|---|---|
Backstepping fp32 longhead | 29 ms | 39 ms | 75 ms | 162 ms | 8 - 14 GB | 9e-5 |
Backstepping fp32 smallhead | 251 ms | 757 ms | 2706 ms | 15025 ms | 7 - 13 GB | 9e-5 |
Triton bighead fp32 | 67 ms | 100 ms | 287 ms | 2073 ms | 6 - 12 GB | 5e-5 |
Triton bighead bf16 | 42 ms | 72 ms | 198 ms | 1453 ms | 6 - 12 GB | 5e-3 |
FLA chunk_rwkv7 | 52 ms | 61 ms | 98 ms | 202 ms | 12 - 13 GB | 4e-3 |
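For reference, here is a minimal sketch of how timings like the ones above can be collected with CUDA events. This is an assumption about the methodology rather than a transcription of tests/speed_test.py; `fn` stands for any of the kernels and `args` for its inputs.

```python
import torch as th

def time_ms(fn, *args, iters=10):
    # Average wall-clock time of fn(*args) in milliseconds, measured with
    # CUDA events after a warmup call (which also triggers any JIT compilation).
    fn(*args)
    th.cuda.synchronize()
    start = th.cuda.Event(enable_timing=True)
    end = th.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    th.cuda.synchronize()
    return start.elapsed_time(end) / iters
```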
The RWKV-7 kernels all compute the following:
```python
import torch as th

def naive(r, w, k, v, a, b, s):
    y = th.empty_like(v)
    for t in range(w.shape[1]):
        # s <- s @ diag(exp(-exp(w_t))) + (s @ a_t) b_t^T + v_t k_t^T
        s = s * th.exp(-th.exp(w[:,t,:,None,:])) + s @ a[:,t,:,:,None] * b[:,t,:,None,:] + v[:,t,:,:,None] * k[:,t,:,None,:]
        y[:,t,:,:,None] = s @ r[:,t,:,:,None]   # y_t = s_t @ r_t
    return y, s
```
Here `r`, `w`, `k`, `v`, `a` and `b` have shape [batch size, sequence length, num heads, head size], while the initial state `s` has shape [batch size, num heads, head size, head size]. All inputs and outputs are bfloat16 precision.
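As a concrete usage sketch of the reference above, with sizes and input constraints chosen as assumptions (e.g. `a` as a negated unit vector and `b = -a * eta` with eta in [0,1], which roughly mimics RWKV-7 and keeps the recurrence stable; real inputs come from the model):

```python
import torch as th

B, T, H, N = 2, 32, 4, 64                       # batch, seq len, num heads, head size
def mk(): return th.randn(B, T, H, N, device='cuda', dtype=th.bfloat16)

r, w, k, v = mk(), mk(), mk(), mk()
a = -th.nn.functional.normalize(mk().float(), dim=-1).bfloat16()   # -kappa_hat (assumed)
b = -a * th.rand(B, T, H, N, device='cuda', dtype=th.bfloat16)     # kappa_hat * eta (assumed)
s = th.zeros(B, H, N, N, device='cuda', dtype=th.bfloat16)         # initial state

y, s_out = naive(r, w, k, v, a, b, s)
print(y.shape, s_out.shape)                      # [2, 32, 4, 64] and [2, 4, 64, 64]
```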
Chunked bf16 is the fastest kernel when applicable. It processes the sequence in chunks of length 16 (a chunked formulation) and uses Ampere (CUDA SM80+, i.e., A100 and later) instructions for fast bfloat16 matmuls.
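The chunked decomposition can be written out explicitly: with A_t = diag(exp(-exp(w_t))) + a_t b_t^T, the update is s_t = s_{t-1} @ A_t + v_t k_t^T, so the state at a chunk boundary is the incoming state pushed through the product of the chunk's transitions, plus each timestep's deposit pushed through the remaining transitions. The sketch below only illustrates that identity; it forms full N x N transition products purely for clarity and is not how the kernel is implemented.

```python
import torch as th

def chunk_carry(s0, w, k, v, a, b):
    # Carry the state across one chunk in closed form:
    #   s_T = s_0 (A_1 ... A_T) + sum_t v_t k_t^T (A_{t+1} ... A_T),
    # with A_t = diag(exp(-exp(w_t))) + a_t b_t^T.
    B, T, H, N = w.shape
    d = th.exp(-th.exp(w))
    A = th.diag_embed(d) + a[..., :, None] * b[..., None, :]   # [B,T,H,N,N]
    A = A.transpose(1, 2)                                      # [B,H,T,N,N]
    # Suffix products P[t] = A_{t+1} ... A_{T-1}, with P[T-1] = I.
    P = [th.eye(N, device=w.device, dtype=A.dtype).expand(B, H, N, N)]
    for t in range(T - 1, 0, -1):
        P.insert(0, A[:, :, t] @ P[0])
    vk = v.transpose(1, 2)[..., :, None] * k.transpose(1, 2)[..., None, :]
    sT = s0 @ (A[:, :, 0] @ P[0])                              # incoming state's part
    for t in range(T):
        sT = sT + vk[:, :, t] @ P[t]                           # intra-chunk deposits
    return sT
```

On a chunk of length 16 in float32, the returned state should match the final state returned by `naive` on the same inputs up to rounding error.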
Backstepping fp32 smallhead is essentially the official kernel which was used to train the RWKV-7 World models. It calculates gradients by iterating the state backwards in time (at most 15 steps). This keeps the code simple, but requires 32-bit floats and limits the decay to about 0.5.
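To see why the state can be iterated backwards at all: each update s_t = s_{t-1} @ A_t + v_t k_t^T with A_t = diag(d_t) + a_t b_t^T and d_t = exp(-exp(w_t)) is invertible, and inverting it divides by d_t, which is the source of both the fp32 requirement and the limit on how small the decay may get. The function below is a mathematical sketch of one backward step, not the kernel's actual code.

```python
import torch as th

def step_back(s1, w, k, v, a, b):
    # Undo one update  s1 = s0 @ A + v k^T,  A = diag(d) + a b^T,  d = exp(-exp(w)).
    # Sherman-Morrison: A^{-1} = D^{-1} - (D^{-1} a)(b^T D^{-1}) / (1 + b^T D^{-1} a),
    # so every term divides by d -- small decays amplify rounding error.
    # Per-step inputs: w, k, v, a, b of shape [B,H,N]; s1 of shape [B,H,N,N].
    d = th.exp(-th.exp(w))
    x = s1 - v[..., :, None] * k[..., None, :]              # remove the v k^T deposit
    xDinv = x / d[..., None, :]                             # x @ D^{-1}
    denom = 1 + (b * a / d).sum(-1, keepdim=True)           # 1 + b^T D^{-1} a
    corr = (xDinv @ a[..., :, None]) * (b / d)[..., None, :]
    return xDinv - corr / denom[..., None]                  # = (s1 - v k^T) @ A^{-1}
```

Re-applying the forward update to the recovered state reproduces `s1` in exact arithmetic; in practice the divisions by `d` amplify rounding error over the up to 15 backward steps, which is why fp32 is required and the decay must stay away from zero.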
Backstepping fp32 smallhead becomes very slow for large head sizes, since the full state is kept in registers, which overflow into global memory. To fix this, backstepping fp32 longhead uses the observation that the columns of the state are essentially updated independently, so it processes blocks of 64 or 32 columns independently. This increases parallelism and keeps less state in shared memory at a time, while retaining most of the simplicity of backstepping fp32 smallhead.
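This independence is easy to check against the naive reference above: in that code, each slice of the state along the dimension indexed by `v` evolves without reading the other slices (presumably the "columns" under the kernel's own layout), so the state can be split into blocks along that dimension and each block processed on its own.

```python
import torch as th

def naive_blocked(r, w, k, v, a, b, s, block=64):
    # Update the state in independent blocks along the dimension indexed by v;
    # the blocks never interact, so concatenating them reproduces naive().
    ys, ss = [], []
    for i in range(0, v.shape[-1], block):
        y_i, s_i = naive(r, w, k, v[..., i:i+block], a, b, s[..., i:i+block, :])
        ys.append(y_i)
        ss.append(s_i)
    return th.cat(ys, -1), th.cat(ss, -2)
```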
Triton bighead is a simple chunked kernel written in Triton. It stores intermediate states in global memory instead of shared memory, so it handles large head sizes (like 1024) without crashing. It takes a flag to choose fp32 or bf16 precision³, which affects all matmuls inside the Triton kernel.
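The precision flag trades the "Typical error" column against speed. A rough way to measure that kind of error with the naive reference above; the metric, input distribution, and sizes here are assumptions, not necessarily what the repo's tests use:

```python
import torch as th

def rel_rms_error(dtype, B=1, T=256, H=4, N=64, device='cuda'):
    # Relative RMS error of running the naive recurrence in `dtype`
    # versus a float64 reference on the same (roughly RWKV-7-like) inputs.
    th.manual_seed(0)
    mk = lambda: th.randn(B, T, H, N, device=device, dtype=th.float64)
    r, w, k, v = mk(), mk(), mk(), mk()
    a = -th.nn.functional.normalize(mk(), dim=-1)                  # -kappa_hat (assumed)
    b = -a * th.rand(B, T, H, N, device=device, dtype=th.float64)  # kappa_hat * eta (assumed)
    s = th.zeros(B, H, N, N, device=device, dtype=th.float64)
    y_ref, _ = naive(r, w, k, v, a, b, s)
    y_lp, _ = naive(*(x.to(dtype) for x in (r, w, k, v, a, b, s)))
    return ((y_lp.double() - y_ref).square().mean() / y_ref.square().mean()).sqrt()

print(rel_rms_error(th.float32), rel_rms_error(th.bfloat16))
```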
FLA chunk_rwkv7 is the RWKV-7 Triton kernel from Flash Linear Attention. It is a chunked implementation with partial sequence-length parallelization.