
Commit f8e7d01

vanilla linear attention
1 parent 6c27e28

File tree

6 files changed: +1225 −0 lines changed

kernels/linear_attention/Makefile

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# GPU Selection: 4090, A100, H100
GPU_TARGET=H100

# Compiler
NVCC=nvcc

NVCCFLAGS=-DNDEBUG -Xcompiler=-fPIE -Xcompiler -fopenmp --expt-extended-lambda --expt-relaxed-constexpr -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing --use_fast_math -forward-unknown-to-host-compiler -O3 -Xnvlink=--verbose -Xptxas=--verbose -Xptxas=--warn-on-spills -std=c++20 -MD -MT -MF -x cu -lrt -lpthread -ldl -DKITTENS_HOPPER -arch=sm_90a -lcuda -lcudadevrt -lcudart_static -lcublas -lgomp -I${THUNDERKITTENS_ROOT}/include -I${THUNDERKITTENS_ROOT}/prototype # H100
TARGET=linear_attn
SRC=la.cu

# Default target
all: $(TARGET)

$(TARGET): $(SRC)
	$(NVCC) $(SRC) $(NVCCFLAGS) -o $(TARGET)

# Clean target
clean:
	rm -f $(TARGET)
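Build note: the include flags above reference ${THUNDERKITTENS_ROOT}, so the Makefile assumes that variable is set in the environment and points at a ThunderKittens checkout. A typical invocation (the path is a placeholder, not taken from this commit) would be `THUNDERKITTENS_ROOT=/path/to/ThunderKittens make`, which compiles la.cu into the linear_attn binary with the H100 (sm_90a, KITTENS_HOPPER) flags.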
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import torch
import numpy as np
from lightning_attn2 import lightning_attn2

def benchmark_attention(configurations):
    for B, H, N, D in configurations:
        print("=" * 60)
        print(f"Timing forward pass for B={B}, H={H}, N={N}, D={D}")

        # Initialize input tensors
        q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        s = torch.rand(H, dtype=torch.float32, device='cuda').contiguous()

        # Prepare timing events
        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(10)]
        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(10)]

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Warmup
        print("Warming up...")
        for _ in range(10):
            _ = lightning_attn2(q, k, v, s)

        # Benchmark runs
        print("Running benchmarks...")
        for i in range(10):
            start_events[i].record()
            _ = lightning_attn2(q, k, v, s)
            end_events[i].record()

        torch.cuda.synchronize()

        # Calculate timing statistics
        times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
        time_us = np.mean(times) * 1000  # convert ms to microseconds
        time_std = np.std(times) * 1000

        print(f"Average latency: {time_us:.2f}us (std: {time_std:.2f}us)")
        print("-" * 60)

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

if __name__ == "__main__":
    configurations = [
        (1, 8, 1024, 128),
        (1, 8, 2048, 128),
        (1, 8, 4096, 128),
        (1, 8, 8192, 128),
        (1, 8, 16384, 128)
    ]

    print("Linear Attention Benchmark")
    print("=" * 60)

    try:
        benchmark_attention(configurations)
        print("\nBenchmark complete!")
    except RuntimeError as e:
        if "out of memory" in str(e):
            print(f"\nOut of memory error. Try reducing batch size or sequence length.")
        else:
            print(f"\nError during benchmark: {str(e)}")
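The benchmark above only measures latency. For a rough sanity check of what the fused kernel computes, a naive O(N²) reference can be compared against lightning_attn2 on small shapes. The sketch below is an assumption-laden reference, not part of this commit: it assumes lightning_attn2 applies a causal, per-head exponential decay exp(-s[h]·(t−i)) to the q·k scores, and the helper name naive_linear_attention is hypothetical.

import torch

def naive_linear_attention(q, k, v, s):
    # Naive "vanilla" linear attention with per-head decay, materializing the
    # full (N, N) score matrix; only practical for small N.
    # Assumption: o_t = q_t @ sum_{i<=t} exp(-s[h] * (t - i)) * k_i^T v_i.
    B, H, N, D = q.shape
    q, k, v = q.float(), k.float(), v.float()
    idx = torch.arange(N, device=q.device)
    rel = (idx[None, :, None] - idx[None, None, :]).clamp(min=0).float()  # (1, N, N), t - i
    causal = idx[None, :, None] >= idx[None, None, :]                     # (1, N, N) mask
    decay = torch.exp(-s.view(H, 1, 1) * rel) * causal                    # (H, N, N)
    scores = torch.einsum('bhnd,bhmd->bhnm', q, k) * decay                # (B, H, N, N)
    return torch.einsum('bhnm,bhmd->bhnd', scores, v).to(torch.bfloat16)

# Example check on small shapes (hypothetical usage):
#   out_ref = naive_linear_attention(q, k, v, s)
#   out_ker = lightning_attn2(q, k, v, s)
#   print((out_ref.float() - out_ker.float()).abs().max())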
