1+ import  torch 
2+ import  numpy  as  np 
3+ from  lightning_attn2  import  lightning_attn2 
4+ 
def benchmark_attention(configurations, num_warmup=10, num_iters=10):
    """Time the forward pass of ``lightning_attn2`` for each configuration.

    For every ``(B, H, N, D)`` tuple, random bfloat16 ``q``/``k``/``v``
    tensors of shape ``(B, H, N, D)`` and a random float32 tensor ``s`` of
    shape ``(H,)`` are created on the CUDA device. The kernel is warmed up,
    then timed with CUDA events, and the mean and standard deviation of the
    per-call latency are printed in microseconds.

    Args:
        configurations: Iterable of ``(B, H, N, D)`` integer tuples.
            NOTE(review): presumably batch, heads, sequence length and head
            dimension -- confirm against ``lightning_attn2``.
        num_warmup: Number of untimed warmup calls per configuration.
        num_iters: Number of timed calls per configuration.

    Raises:
        ValueError: If ``num_warmup`` is negative or ``num_iters`` is less
            than 1.
    """
    if num_warmup < 0:
        raise ValueError(f"num_warmup must be >= 0, got {num_warmup}")
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    for B, H, N, D in configurations:
        print("=" * 60)
        print(f"Timing forward pass for B={B}, H={H}, N={N}, D={D}")

        # Initialize input tensors
        q = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device='cuda').contiguous()
        s = torch.rand(H, dtype=torch.float32, device='cuda').contiguous()

        # Prepare one (start, end) event pair per timed iteration
        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Warmup
        print("Warming up...")
        for _ in range(num_warmup):
            lightning_attn2(q, k, v, s)
        # Drain the warmup work so it cannot overlap with the first timed call.
        torch.cuda.synchronize()

        # Benchmark runs
        print("Running benchmarks...")
        for start_event, end_event in zip(start_events, end_events):
            start_event.record()
            lightning_attn2(q, k, v, s)
            end_event.record()

        # Events must have completed before elapsed_time() can be queried.
        torch.cuda.synchronize()

        # Calculate timing statistics (elapsed_time() reports milliseconds)
        times_ms = [
            start_event.elapsed_time(end_event)
            for start_event, end_event in zip(start_events, end_events)
        ]
        time_us = np.mean(times_ms) * 1000  # convert to microseconds
        time_std = np.std(times_ms) * 1000

        print(f"Average latency: {time_us:.2f} +/- {time_std:.2f} us")
        print("-" * 60)

        # Release this configuration's cached blocks before the next, larger one.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
if __name__ == "__main__":
    # (B, H, N, D) tuples; only N varies, doubling from 1024 to 16384.
    configurations = [
        (1, 8, 1024, 128),
        (1, 8, 2048, 128),
        (1, 8, 4096, 128),
        (1, 8, 8192, 128),
        (1, 8, 16384, 128),
    ]

    print("Linear Attention Benchmark")
    print("=" * 60)

    try:
        benchmark_attention(configurations)
        print("\n Benchmark complete!")
    except RuntimeError as e:
        # Give an actionable hint when the error message reports an
        # out-of-memory condition; report any other runtime error verbatim.
        if "out of memory" in str(e):
            print("\n Out of memory error. Try reducing batch size or sequence length.")
        else:
            print(f"\n Error during benchmark: {e}")
# 0 commit comments