Skip to content

Commit b508167

Browse files
wenscarlyeqcharlotte
authored and committed
Update fp4 quantize API (vllm-project#21327)
Signed-off-by: Shu Wang <[email protected]>
1 parent a8ace96 commit b508167

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ def apply(
181181
g2_alphas,
182182
]
183183
_ = flashinfer_cutlass_fused_moe(
184-
hidden_states,
185-
topk_ids.to(torch.int),
186-
topk_weights,
184+
input=hidden_states,
185+
token_selected_experts=topk_ids.to(torch.int),
186+
token_final_scales=topk_weights,
187187
# FlashInfer API requires weight to be long for nvfp4
188-
w1.view(torch.long),
189-
w2.view(torch.long),
188+
fc1_expert_weights=w1.view(torch.long),
189+
fc2_expert_weights=w2.view(torch.long),
190190
output_dtype=out_dtype,
191191
quant_scales=quant_scales,
192192
input_sf=a1q_scale,

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
1212
from vllm.model_executor.layers.fused_moe.utils import (
1313
extract_required_args, moe_kernel_quantize_input)
14-
from vllm.utils.flashinfer import fp4_swizzle_blockscale
14+
from vllm.utils.flashinfer import block_scale_interleave
1515

1616

1717
def get_local_sizes(local_tokens):
@@ -92,7 +92,7 @@ def prepare(
9292
dim=0,
9393
sizes=get_local_sizes(local_tokens))
9494
a1_m, a1_n = a1q.shape
95-
a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)
95+
a1q_scale = block_scale_interleave(a1q_scale)
9696

9797
return a1q, a1q_scale, None, topk_ids, topk_weights
9898

vllm/utils/flashinfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def wrapper(*args, **kwargs):
6969
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
7070
"cutlass_fused_moe")
7171
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
72-
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
73-
"fp4_swizzle_blockscale")
72+
block_scale_interleave = _lazy_import_wrapper("flashinfer",
73+
"block_scale_interleave")
7474

7575
# Special case for autotune since it returns a context manager
7676
autotune = _lazy_import_wrapper(
@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
9595
required_functions = [
9696
("flashinfer.fused_moe", "cutlass_fused_moe"),
9797
("flashinfer", "fp4_quantize"),
98-
("flashinfer", "fp4_swizzle_blockscale"),
98+
("flashinfer", "block_scale_interleave"),
9999
]
100100

101101
for module_name, attr_name in required_functions:
@@ -110,7 +110,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
110110
"flashinfer_trtllm_fp8_block_scale_moe",
111111
"flashinfer_cutlass_fused_moe",
112112
"fp4_quantize",
113-
"fp4_swizzle_blockscale",
113+
"block_scale_interleave",
114114
"autotune",
115115
"has_flashinfer_moe",
116116
"has_flashinfer_cutlass_fused_moe",

0 commit comments

Comments (0)