1414
1515import vllm ._custom_ops as ops
1616from vllm .model_executor .layers .layernorm import RMSNorm
17+ from vllm .model_executor .layers .quantization .utils .fp8_utils import (
18+ per_token_group_quant_fp8 ,
19+ )
1720
1821
1922@dataclass
@@ -22,13 +25,15 @@ class bench_params_t:
2225 hidden_size : int
2326 add_residual : bool
2427 dtype : torch .dtype
28+ group_size : list [int ]
2529
def description(self):
    """Return a one-line, human-readable summary of this parameter set.

    The text lists the token count, hidden size, residual flag, dtype and
    group size, each introduced by a short tag (N, D, R, DT, GS).
    """
    # The pieces are concatenated as-is; the trailing space after the last
    # field is part of the emitted string.
    shape_part = f"N {self.num_tokens} x D {self.hidden_size} "
    option_part = (
        f"x R {self.add_residual} x DT {self.dtype} x GS {self.group_size} "
    )
    return shape_part + option_part
3338
3439
@@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
3843 HIDDEN_SIZES = list (range (1024 , 8129 , 1024 ))
3944 ADD_RESIDUAL = [True , False ]
4045 DTYPES = [torch .bfloat16 , torch .float ]
46+ GROUP_SIZES = [[1 , 64 ], [1 , 128 ]]
4147
42- combinations = product (NUM_TOKENS , HIDDEN_SIZES , ADD_RESIDUAL , DTYPES )
48+ combinations = product (NUM_TOKENS , HIDDEN_SIZES , ADD_RESIDUAL , DTYPES , GROUP_SIZES )
4349 bench_params = list (
44- map (lambda x : bench_params_t (x [0 ], x [1 ], x [2 ], x [3 ]), combinations )
50+ map (lambda x : bench_params_t (x [0 ], x [1 ], x [2 ], x [3 ], x [ 4 ] ), combinations )
4551 )
4652 return bench_params
4753
@@ -52,6 +58,7 @@ def unfused_int8_impl(
5258 x : torch .Tensor ,
5359 residual : torch .Tensor | None ,
5460 quant_dtype : torch .dtype ,
61+ group_size : list [int ],
5562):
5663 # Norm
5764 torch_out = None
@@ -69,6 +76,7 @@ def unfused_fp8_impl(
6976 x : torch .Tensor ,
7077 residual : torch .Tensor | None ,
7178 quant_dtype : torch .dtype ,
79+ group_size : list [int ],
7280):
7381 # Norm
7482 torch_out = None
@@ -81,23 +89,63 @@ def unfused_fp8_impl(
8189 torch_out , _ = ops .scaled_fp8_quant (torch_out )
8290
8391
def unfused_groupwise_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Unfused baseline: RMSNorm first, then group-wise FP8 quantization.

    Args:
        rms_norm_layer: Layer whose ``forward_cuda`` performs the norm.
        x: Input tensor handed to the norm.
        residual: Optional residual tensor; forwarded to the norm unchanged.
        quant_dtype: Not read here; accepted so every benchmarked
            implementation can be called with the same argument list.
        group_size: Two-element list; only ``group_size[1]`` is passed on as
            the quantization group size.

    Returns:
        None. The function is run for timing only and its results are
        discarded.
    """
    # Norm: with a residual, forward_cuda hands back a pair and only its
    # first element is kept; without one, the result is used directly.
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        normed, _ = normed

    # Quant: the call yields a pair (presumably quantized values and scales
    # -- confirm against per_token_group_quant_fp8); both are dropped.
    _quantized, _scales = per_token_group_quant_fp8(
        normed, group_size=group_size[1], use_ue8m0=False
    )
110+
111+
def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused path: a single ``ops.rms_norm_dynamic_per_token_quant`` call.

    Args:
        rms_norm_layer: Only its ``weight`` attribute is read.
        x: Input tensor.
        residual: Optional residual tensor, passed by keyword to the op.
        quant_dtype: Target dtype forwarded to the op.
        group_size: Not read here; accepted so every benchmarked
            implementation can be called with the same argument list.

    Returns:
        None. The op's two-element result is unpacked and discarded.
    """
    norm_weight = rms_norm_layer.weight
    # Third positional argument; presumably the RMSNorm epsilon -- confirm
    # against the op's signature.
    eps = 1e-6
    fused_result = ops.rms_norm_dynamic_per_token_quant(
        x, norm_weight, eps, quant_dtype, residual=residual
    )
    out, _ = fused_result
93122
94123
def fused_groupwise_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused path: a single ``ops.rms_norm_per_block_quant`` call.

    Args:
        rms_norm_layer: Only its ``weight`` attribute is read.
        x: Input tensor.
        residual: Optional residual tensor, passed by keyword to the op.
        quant_dtype: Target dtype forwarded to the op.
        group_size: Forwarded to the op as-is (the whole list, not one
            element of it).

    Returns:
        None. The op's two-element result is unpacked and discarded.
    """
    norm_weight = rms_norm_layer.weight
    # Third positional argument; presumably the RMSNorm epsilon -- confirm
    # against the op's signature.
    eps = 1e-6
    fused_result = ops.rms_norm_per_block_quant(
        x,
        norm_weight,
        eps,
        quant_dtype,
        group_size,
        residual=residual,
        is_scale_transposed=True,
    )
    out, _ = fused_result
140+
141+
95142# Bench functions
96143def bench_fn (
97144 rms_norm_layer : RMSNorm ,
98145 x : torch .Tensor ,
99146 residual : torch .Tensor ,
100147 quant_dtype : torch .dtype ,
148+ group_size : list [int ],
101149 label : str ,
102150 sub_label : str ,
103151 fn : Callable ,
@@ -110,10 +158,11 @@ def bench_fn(
110158 "x" : x ,
111159 "residual" : residual ,
112160 "quant_dtype" : quant_dtype ,
161+ "group_size" : group_size ,
113162 "fn" : fn ,
114163 }
115164 return TBenchmark .Timer (
116- stmt = "fn(rms_norm_layer, x, residual, quant_dtype)" ,
165+ stmt = "fn(rms_norm_layer, x, residual, quant_dtype, group_size )" ,
117166 globals = globals ,
118167 label = label ,
119168 sub_label = sub_label ,
@@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
147196 x ,
148197 residual ,
149198 torch .int8 ,
199+ params .group_size ,
150200 label ,
151201 sub_label ,
152202 unfused_int8_impl ,
@@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
161211 x ,
162212 residual ,
163213 torch .float8_e4m3fn ,
214+ params .group_size ,
164215 label ,
165216 sub_label ,
166217 unfused_fp8_impl ,
@@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
175226 x ,
176227 residual ,
177228 torch .int8 ,
229+ params .group_size ,
178230 label ,
179231 sub_label ,
180232 fused_impl ,
@@ -189,13 +241,44 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
189241 x ,
190242 residual ,
191243 torch .float8_e4m3fn ,
244+ params .group_size ,
192245 label ,
193246 sub_label ,
194247 fused_impl ,
195248 "fused_fp8_impl" ,
196249 )
197250 )
198251
252+ # unfused groupwise fp8 impl.
253+ timers .append (
254+ bench_fn (
255+ layer ,
256+ x ,
257+ residual ,
258+ torch .float8_e4m3fn ,
259+ params .group_size ,
260+ label ,
261+ sub_label ,
262+ unfused_groupwise_fp8_impl ,
263+ "unfused_groupwise_fp8_impl" ,
264+ )
265+ )
266+
267+ # fused groupwise fp8 impl.
268+ timers .append (
269+ bench_fn (
270+ layer ,
271+ x ,
272+ residual ,
273+ torch .float8_e4m3fn ,
274+ params .group_size ,
275+ label ,
276+ sub_label ,
277+ fused_groupwise_impl ,
278+ "fused_groupwise_fp8_impl" ,
279+ )
280+ )
281+
199282 print_timers (timers )
200283
201284 return timers
0 commit comments