Commit af0444b

[Performance] Fused blockwise quant RMS norm (vllm-project#27883)
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
1 parent 0044c40 · commit af0444b

File tree: 14 files changed, +946 −154 lines
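Before the per-file diffs: this commit adds a fused kernel, exposed in Python as ops.rms_norm_per_block_quant, that performs the RMS norm and blockwise (per-group) FP8 quantization in a single launch, replacing one kernel for the norm followed by a second for per_token_group_quant_fp8. Below is a minimal sketch of the two paths using only the call signatures visible in the benchmark diff; the tensor shapes, the RMSNorm setup, and the CUDA device are illustrative assumptions, not part of the commit.

import torch

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)

num_tokens, hidden_size = 512, 4096  # assumed sizes
layer = RMSNorm(hidden_size).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# Unfused path: one kernel for the norm, a second one for groupwise quant.
y = layer.forward_cuda(x, None)
q_ref, scales_ref = per_token_group_quant_fp8(y, group_size=128, use_ue8m0=False)

# Fused path: norm + blockwise FP8 quant in a single kernel.
q, scales = ops.rms_norm_per_block_quant(
    x,
    layer.weight,
    1e-6,  # epsilon
    torch.float8_e4m3fn,
    [1, 128],  # quant block: 1 token x 128 hidden elements
    residual=None,
    is_scale_transposed=True,
)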

benchmarks/fused_kernels/layernorm_rms_benchmarks.py

Lines changed: 86 additions & 3 deletions
@@ -14,6 +14,9 @@
 
 import vllm._custom_ops as ops
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8,
+)
 
 
 @dataclass
@@ -22,13 +25,15 @@ class bench_params_t:
     hidden_size: int
     add_residual: bool
     dtype: torch.dtype
+    group_size: list[int]
 
     def description(self):
         return (
             f"N {self.num_tokens} "
             f"x D {self.hidden_size} "
             f"x R {self.add_residual} "
             f"x DT {self.dtype}"
+            f"x GS {self.group_size}"
         )
 
 
@@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
     HIDDEN_SIZES = list(range(1024, 8129, 1024))
     ADD_RESIDUAL = [True, False]
     DTYPES = [torch.bfloat16, torch.float]
+    GROUP_SIZES = [[1, 64], [1, 128]]
 
-    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
+    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
     bench_params = list(
-        map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
+        map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
     )
     return bench_params
 
@@ -52,6 +58,7 @@ def unfused_int8_impl(
     x: torch.Tensor,
     residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
+    group_size: list[int],
 ):
     # Norm
     torch_out = None
@@ -69,6 +76,7 @@ def unfused_fp8_impl(
     x: torch.Tensor,
     residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
+    group_size: list[int],
 ):
     # Norm
     torch_out = None
@@ -81,23 +89,63 @@
         torch_out, _ = ops.scaled_fp8_quant(torch_out)
 
 
+def unfused_groupwise_fp8_impl(
+    rms_norm_layer: RMSNorm,
+    x: torch.Tensor,
+    residual: torch.Tensor | None,
+    quant_dtype: torch.dtype,
+    group_size: list[int],
+):
+    # Norm
+    torch_out = None
+    if residual is None:
+        torch_out = rms_norm_layer.forward_cuda(x, residual)
+    else:
+        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
+
+    # Quant
+    torch_out, _ = per_token_group_quant_fp8(
+        torch_out, group_size=group_size[1], use_ue8m0=False
+    )
+
+
 def fused_impl(
     rms_norm_layer: RMSNorm,  # this stores the weights
     x: torch.Tensor,
     residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
+    group_size: list[int],
 ):
     out, _ = ops.rms_norm_dynamic_per_token_quant(
         x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
     )
 
 
+def fused_groupwise_impl(
+    rms_norm_layer: RMSNorm,  # this stores the weights
+    x: torch.Tensor,
+    residual: torch.Tensor | None,
+    quant_dtype: torch.dtype,
+    group_size: list[int],
+):
+    out, _ = ops.rms_norm_per_block_quant(
+        x,
+        rms_norm_layer.weight,
+        1e-6,
+        quant_dtype,
+        group_size,
+        residual=residual,
+        is_scale_transposed=True,
+    )
+
+
 # Bench functions
 def bench_fn(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
     residual: torch.Tensor,
     quant_dtype: torch.dtype,
+    group_size: list[int],
     label: str,
     sub_label: str,
     fn: Callable,
@@ -110,10 +158,11 @@ def bench_fn(
         "x": x,
         "residual": residual,
         "quant_dtype": quant_dtype,
+        "group_size": group_size,
        "fn": fn,
    }
    return TBenchmark.Timer(
-        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
+        stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
        globals=globals,
        label=label,
        sub_label=sub_label,
@@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
             x,
             residual,
             torch.int8,
+            params.group_size,
             label,
             sub_label,
             unfused_int8_impl,
@@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
             x,
             residual,
             torch.float8_e4m3fn,
+            params.group_size,
             label,
             sub_label,
             unfused_fp8_impl,
@@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
             x,
             residual,
             torch.int8,
+            params.group_size,
             label,
             sub_label,
             fused_impl,
@@ -189,13 +241,44 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
             x,
             residual,
             torch.float8_e4m3fn,
+            params.group_size,
             label,
             sub_label,
             fused_impl,
             "fused_fp8_impl",
         )
     )
 
+    # unfused groupwise fp8 impl.
+    timers.append(
+        bench_fn(
+            layer,
+            x,
+            residual,
+            torch.float8_e4m3fn,
+            params.group_size,
+            label,
+            sub_label,
+            unfused_groupwise_fp8_impl,
+            "unfused_groupwise_fp8_impl",
+        )
+    )
+
+    # fused groupwise fp8 impl.
+    timers.append(
+        bench_fn(
+            layer,
+            x,
+            residual,
+            torch.float8_e4m3fn,
+            params.group_size,
+            label,
+            sub_label,
+            fused_groupwise_impl,
+            "fused_groupwise_fp8_impl",
+        )
+    )
+
     print_timers(timers)
 
     return timers
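One inference worth spelling out (the diff does not state it explicitly): a group_size of [1, 128] reads as a quantization block spanning 1 token and 128 hidden elements, which is why the unfused reference passes only group_size[1] to per_token_group_quant_fp8. Under that reading, and assuming is_scale_transposed=True swaps the two scale dimensions, the expected scale shape works out as in this hypothetical helper:

def expected_scale_shape(
    num_tokens: int, hidden_size: int, group_size: list[int], transposed: bool
) -> tuple[int, int]:
    # One FP8 scale per (group_size[0] x group_size[1]) block of the input;
    # the transposed layout stores scales as (col_blocks, row_blocks).
    rows = num_tokens // group_size[0]
    cols = hidden_size // group_size[1]
    return (cols, rows) if transposed else (rows, cols)

print(expected_scale_shape(512, 4096, [1, 128], transposed=True))  # (32, 512)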

csrc/dispatch_utils.h

Lines changed: 18 additions & 0 deletions
@@ -118,6 +118,24 @@
   } \
   }
 
+#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
+  if (expr) {                                     \
+    constexpr bool const_expr = true;             \
+    __VA_ARGS__();                                \
+  } else {                                        \
+    constexpr bool const_expr = false;            \
+    __VA_ARGS__();                                \
+  }
+
+#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
+  if (group_size == 128) {                                          \
+    constexpr int const_group_size = 128;                           \
+    __VA_ARGS__();                                                  \
+  } else if (group_size == 64) {                                    \
+    constexpr int const_group_size = 64;                            \
+    __VA_ARGS__();                                                  \
+  }
+
 #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
   switch (NUM_DIMS) { \
     case 2: { \
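A side note on VLLM_DISPATCH_GROUP_SIZE as written: it has branches only for 128 and 64 and no fallback, so for any other group size neither branch runs and the kernel body is silently never launched. A hypothetical Python-side guard (not part of this commit) that surfaces the constraint early:

SUPPORTED_GROUP_SIZES = (64, 128)  # mirrors the VLLM_DISPATCH_GROUP_SIZE branches

def check_group_size(group_size: int) -> None:
    # Fail loudly in Python rather than falling through the C++ dispatch.
    if group_size not in SUPPORTED_GROUP_SIZES:
        raise ValueError(
            f"rms_norm_per_block_quant supports group sizes "
            f"{SUPPORTED_GROUP_SIZES}, got {group_size}"
        )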

csrc/ops.h

Lines changed: 7 additions & 0 deletions
@@ -128,6 +128,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                       std::optional<torch::Tensor> scale_ub,
                                       std::optional<torch::Tensor> residual);
 
+void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& weight,
+                              torch::Tensor& scales, double const epsilon,
+                              std::optional<torch::Tensor> scale_ub,
+                              std::optional<torch::Tensor> residual,
+                              int64_t group_size, bool is_scale_transposed);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
