
# Wind RWKV

A repository with optimized kernels for RWKV language models. Currently focused on RWKV-7.

## Kernel benchmarks for RWKV-7

The kernels were timed using tests/speed_test.py with model dimension 4096 and varying (batch size, head size, sequence length), as labeled in the table columns.
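For context, here is a minimal sketch of how such timings are typically collected; this is not the actual tests/speed_test.py, just standard CUDA-event timing with warmup (the helper name is made up for illustration):

```python
import torch as th

def time_kernel(fn, *args, warmup=3, iters=10):
    # Average milliseconds per call, measured with CUDA events after warmup.
    for _ in range(warmup):
        fn(*args)
    th.cuda.synchronize()
    start = th.cuda.Event(enable_timing=True)
    end = th.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    th.cuda.synchronize()
    return start.elapsed_time(end) / iters
```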

### H100

| Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM¹ | Typical error |
|---|---|---|---|---|---|---|
| Chunked bf16 | 8 ms | 11 ms | 54 ms | 224 ms | 5-8 GB | 5e-3 |
| Backstepping fp32 longhead | 23 ms | 46 ms | 80 ms | 124 ms | 8-14 GB | 9e-5 |
| Backstepping fp32 smallhead | 17 ms | 101 ms | 862 ms | 1802 ms | 7-13 GB | 9e-5 |
| Triton bighead fp32 | 66 ms | 87 ms | 168 ms | 1175 ms | 6-12 GB | 5e-5 |
| Triton bighead bf16 | ² | 29 ms | 59 ms | 358 ms | 6-12 GB | 5e-3 |
| FLA chunk_rwkv7 | 64 ms | 62 ms | 89 ms | 93 ms | 12-13 GB | 4e-3 |

### MI300X

| Kernel | (8,64,4096) | (8,128,4096) | (8,256,4096) | (1,256,32768) | Peak VRAM¹ | Typical error |
|---|---|---|---|---|---|---|
| Backstepping fp32 longhead | 29 ms | 39 ms | 75 ms | 162 ms | 8-14 GB | 9e-5 |
| Backstepping fp32 smallhead | 251 ms | 757 ms | 2706 ms | 15025 ms | 7-13 GB | 9e-5 |
| Triton bighead fp32 | 67 ms | 100 ms | 287 ms | 2073 ms | 6-12 GB | 5e-5 |
| Triton bighead bf16 | 42 ms | 72 ms | 198 ms | 1453 ms | 6-12 GB | 5e-3 |
| FLA chunk_rwkv7 | 52 ms | 61 ms | 98 ms | 202 ms | 12-13 GB | 4e-3 |

## Kernel descriptions

The RWKV-7 kernels all compute the following:

```python
import torch as th

def naive(r, w, k, v, a, b, s):
    # Reference recurrence, one step per token t.
    y = th.empty_like(v)
    for t in range(w.shape[1]):
        # s_t = s_{t-1} * exp(-exp(w_t)) + (s_{t-1} @ a_t) b_t^T + v_t k_t^T
        s = s * th.exp(-th.exp(w[:,t,:,None,:])) + s @ a[:,t,:,:,None] * b[:,t,:,None,:] + v[:,t,:,:,None] * k[:,t,:,None,:]
        y[:,t,:,:,None] = s @ r[:,t,:,:,None]
    return y, s
```

Here r, w, k, v, a, and b have shape [batch size, sequence length, num heads, head size], while the initial state s has shape [batch size, num heads, head size, head size]. All inputs and outputs are in bfloat16 precision.
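For example, the reference can be exercised on random inputs as follows (fp32 and arbitrary sizes for simplicity; the optimized kernels expect bfloat16 CUDA tensors, and the real model constrains a and b, so random values only serve as a shape check):

```python
import torch as th

B, T, H, N = 2, 32, 4, 64                      # batch, sequence length, heads, head size
dev = "cuda" if th.cuda.is_available() else "cpu"

r, w, k, v, a, b = (0.1 * th.randn(B, T, H, N, device=dev) for _ in range(6))
s0 = th.zeros(B, H, N, N, device=dev)          # initial state

y, s1 = naive(r, w, k, v, a, b, s0)
print(y.shape, s1.shape)                       # [B, T, H, N] and [B, H, N, N]
```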

### Chunked bf16

This is the fastest kernel when applicable. It processes the sequence in chunks of length 16 (a chunked formulation) and uses Ampere (CUDA SM80+, i.e., A100 and later) instructions for fast bfloat16 matmuls.
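To illustrate just the chunking, the recurrence can be evaluated chunk by chunk while carrying the state across chunk boundaries; the actual kernel additionally rewrites the work inside each chunk as dense bf16 matmuls. A sketch reusing the naive reference above (the helper name is made up):

```python
def chunked_reference(r, w, k, v, a, b, s, chunk_len=16):
    # Split the sequence into chunks of `chunk_len` tokens and carry the
    # state across chunk boundaries; the result matches the naive loop.
    ys = []
    for t0 in range(0, w.shape[1], chunk_len):
        sl = slice(t0, t0 + chunk_len)
        y, s = naive(r[:, sl], w[:, sl], k[:, sl], v[:, sl], a[:, sl], b[:, sl], s)
        ys.append(y)
    return th.cat(ys, dim=1), s
```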

### Backstepping fp32 smallhead

This is essentially the official kernel that was used to train the RWKV-7 World models. It calculates gradients by iterating the state backwards in time (at most 15 steps). This keeps the code simple, but requires 32-bit floats and limits the decay to ca. 0.5.
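A conceptual sketch of why this works (not the kernel's actual in-register recomputation, and the helper names are made up): one step of the recurrence is s_t = s_{t-1} @ (diag(d_t) + a_t b_tᵀ) + v_t k_tᵀ with d_t = exp(-exp(w_t)), and that step can be inverted, so states needed for the backward pass can be recovered from later ones instead of being stored. Here d, a, b, v, k are per-step slices of shape [batch, heads, head size]:

```python
def step_forward(s, d, a, b, v, k):
    # One recurrence step: s_t = s_{t-1} @ (diag(d) + a b^T) + v k^T
    return s * d[..., None, :] + (s @ a[..., :, None]) * b[..., None, :] \
        + v[..., :, None] * k[..., None, :]

def step_backward(s_next, d, a, b, v, k):
    # Invert one step: s_{t-1} = (s_t - v k^T) @ (diag(d) + a b^T)^{-1}.
    # The explicit solve is only for illustration; the actual kernel steps
    # backwards in fp32 registers, and bounding the decay presumably keeps
    # the inverted step well conditioned over the (at most 15) steps.
    M = th.diag_embed(d) + a[..., :, None] * b[..., None, :]
    rhs = s_next - v[..., :, None] * k[..., None, :]
    # X @ M = rhs  <=>  M^T @ X^T = rhs^T
    return th.linalg.solve(M.transpose(-1, -2), rhs.transpose(-1, -2)).transpose(-1, -2)
```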

### Backstepping fp32 longhead

Backstepping fp32 smallhead becomes very slow for large head sizes, since the full state is kept in registers, which overflow into global memory. To fix this, backstepping fp32 longhead uses the observation that the columns of the state are updated essentially independently, so it processes blocks of 64 or 32 columns independently. This increases parallelization and keeps less state in shared memory at a time, while keeping most of the simplicity of backstepping fp32 smallhead.
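This independence is easy to check against the naive reference above: the state splits into independent blocks along its value dimension (in the naive code's layout these are slices s[..., i0:i1, :]; the kernel presumably stores the state transposed, hence "columns"), each block needing only the matching slice of v and producing the matching slice of y. A minimal sketch, with a made-up helper name:

```python
def blocked_reference(r, w, k, v, a, b, s, block=32):
    # Process independent state blocks (and the matching slice of v);
    # r, w, k, a and b are shared by every block.
    N = v.shape[-1]
    ys, ss = [], []
    for i0 in range(0, N, block):
        y_blk, s_blk = naive(r, w, k, v[..., i0:i0+block], a, b, s[..., i0:i0+block, :])
        ys.append(y_blk)
        ss.append(s_blk)
    return th.cat(ys, dim=-1), th.cat(ss, dim=-2)
```

The concatenated result matches naive(r, w, k, v, a, b, s) up to floating point rounding.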

### Triton bighead

A simple chunked kernel written in Triton. The kernel stores intermediate states in global memory instead of shared memory, so it handles large head sizes (like 1024) without crashing. It takes a flag to choose fp32 or bf16 precision³, which affects all matmuls inside the Triton kernel.

### FLA chunk_rwkv7

The RWKV-7 Triton kernel from Flash Linear Attention. A chunked implementation with partial sequence-length parallelization.

## Footnotes

1. Smallest peak VRAM was typically for (8,64,4096) and largest for (8,256,4096).

2. Triton fails to compile the kernel; only seen on H100.

3. The kernel also supports tf32 precision for matmuls, but tf32 seems to run into bugs in the Triton language, so I didn't expose it.
