43 changes: 43 additions & 0 deletions vllm/compilation/activation_quant_fusion.py
@@ -10,6 +10,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)
@@ -38,6 +39,32 @@ def silu_mul_replacement_static(result: torch.Tensor,
return at[1]


def silu_mul_mxfp4_gemm_pattern(result: torch.Tensor,
                                result_silu_mul: torch.Tensor,
                                input: torch.Tensor, weight: torch.Tensor,
                                scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(torch.ops.vllm.gemm_with_dynamic_quant.default,
result=result,
x=at1[1],
weight=weight,
weight_scale=scale,
x_scales=None)
return at2[1]


def silu_mul_mxfp4_gemm_replacement(result: torch.Tensor,
                                    result_silu_mul: torch.Tensor,
                                    input: torch.Tensor, weight: torch.Tensor,
                                    scale: torch.Tensor):
at = auto_functionalized(torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
result=result,
x=input,
weight=weight,
weight_scale=scale)
return at[1]
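
# Note: register_replacement (torch._inductor.pattern_matcher) traces both
# the pattern and the replacement with the dummy inputs registered below,
# then rewrites matched subgraphs in the compiled graph. The replacement must
# accept the same arguments as the pattern, even ones it does not use
# (result_silu_mul here).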


def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

@@ -51,6 +78,10 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


def empty_fp4(*args, **kwargs):
    # MXFP4 weights are stored packed in uint8 (two 4-bit values per byte).
    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")


class ActivationQuantFusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
@@ -61,6 +92,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""

@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)

@@ -76,6 +108,17 @@ def __init__(self, config: VllmConfig):
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)

        # Dummy inputs used to trace the pattern; dtypes drive matching and
        # the shapes are placeholders.
        inputs = [
empty_bf16(5, 4), # result
empty_bf16(5, 4), # result_silu_mul
empty_bf16(5, 4), # input
empty_fp4(5, 4), # weight
empty_fp4(1, 1), # scale
]
register_replacement(silu_mul_mxfp4_gemm_pattern,
silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,
self.patterns)

def __call__(self, graph: torch.fx.Graph):
self.begin()
20 changes: 20 additions & 0 deletions vllm/compilation/inductor_pass.py
@@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import hashlib
import inspect
import json
import types
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union

import torch
from torch import fx
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           unset_fake_temporarily)
@@ -114,3 +117,20 @@ def __call__(self, graph: torch.fx.Graph):

def uuid(self) -> Any:
return self._uuid


def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Runs the wrapped function under a FakeTensorMode context, so tensor
    constructors produce fake tensors instead of allocating real device
    memory. This is useful when you don't want to create or run things with
    real tensors, e.g. when registering patterns.
    """

@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
with torch._guards.tracing(
None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)

return result

return fn_new
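
A minimal usage sketch of the decorator (the function below is hypothetical,
not part of this diff): under @enable_fake_mode, tensor constructors return
FakeTensors, so no real device memory is touched.

@enable_fake_mode
def build_dummy_inputs() -> torch.Tensor:
    # Produces a fake CUDA tensor, even on a machine without a GPU.
    return torch.empty(5, 4, dtype=torch.bfloat16, device="cuda")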
5 changes: 5 additions & 0 deletions vllm/compilation/pass_manager.py
@@ -14,6 +14,9 @@
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass

if current_platform.is_rocm():
from .rocm_fusion import ROCmFusionPass

from .activation_quant_fusion import ActivationQuantFusionPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@@ -64,6 +67,8 @@ def configure(self, config: VllmConfig):
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)]
if current_platform.is_rocm():
self.passes += [ROCmFusionPass(config)]

if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
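
As a usage sketch (PassConfig and its import path are assumptions; this diff
only shows that the manager reads self.pass_config.enable_fusion):

# from vllm.config import CompilationConfig, PassConfig  # assumed path
# config = CompilationConfig(pass_config=PassConfig(enable_fusion=True))
# With fusion enabled, configure() appends ActivationQuantFusionPass and,
# on ROCm platforms, ROCmFusionPass to the post-grad pass list.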
258 changes: 258 additions & 0 deletions vllm/compilation/rocm_fusion.py
@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                             fwd_only, register_replacement)
from torch._ops import OpOverload

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass


logger = init_logger(__name__)


def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp8(*args, **kwargs):
fp8 = current_platform.fp8_dtype()
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")


def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


def empty_fp4(*args, **kwargs):
    # MXFP4 weights are stored packed in uint8 (two 4-bit values per byte).
    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")


class QuantMultiOutputMatch(MultiOutputMatch):

def __init__(self, match: Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op

    def insert_fused_node(
            self, fused_return_mapping: dict[int, tuple[torch.fx.Node, int]],
            **kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.

:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node

Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)

with
_, x1, y2, x2 = auto_fn(FUSED_OP)

we would call:
            insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)})

        Note that the 0th element is None for auto-functionalized in-place
        ops, so the real outputs are effectively 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)

# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)

# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]

# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]

# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]

# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)


ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default


class AddRMSNormMXFP4GemmPattern:
def __init__(self, epsilon: float):
self.epsilon = epsilon
self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default
self.QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default

    def register(self, pm_pass: PatternMatcherPass,
                 record_match: Callable[[MultiOutputMatch], bool]):

def pattern(
result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, residual_out: torch.Tensor,
residual: torch.Tensor, weight_rms: torch.Tensor,
weight_gemm: torch.Tensor, scale: torch.Tensor):
at1 = auto_functionalized(ADD_RMS_OP,
result=result_rms,
input=input,
residual_out=residual_out,
residual=residual,
weight=weight_rms,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_F4GEMM_OP,
result=result,
x=at1[1],
weight=weight_gemm,
weight_scale=scale,
x_scales=None)
return at2[1], at1[2]

def replacement(
result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, residual_out: torch.Tensor,
residual: torch.Tensor, weight_rms: torch.Tensor,
weight_gemm: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
residual_out=residual_out,
residual=residual,
weight_rms=weight_rms,
weight_gemm=weight_gemm,
scale=scale,
epsilon=self.epsilon)
return at[1], at[2]
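
        # This pattern has two live outputs (the GEMM result and the updated
        # residual), so the automatic replacement path cannot be used;
        # extra_check below records the match and returns False, and
        # Match.process() later rewrites it manually.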

        # Dummy inputs used to trace the pattern; dtypes drive matching and
        # the shapes are placeholders.
        inputs = [
empty_bf16(4, 4), # result
empty_bf16(4, 4), # result_rms
empty_bf16(4, 4), # input
empty_bf16(4, 4), # residual_out
empty_bf16(4, 4), # residual
empty_bf16(1, 4), # weight_rms
empty_fp4(4, 4), # weight_gemm
            empty_fp4(1, 1),  # scale
]

register_replacement(
pattern,
replacement,
inputs,
fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_F4GEMM_OP, self.FUSED_OP)))

class Match(QuantMultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
add_rms_node = self.find_auto_fn(ADD_RMS_OP)
quant_f4gemm_node = self.find_auto_fn(self.QUANT_OP)

assert len(add_rms_node.users) == 2
assert len(quant_f4gemm_node.users) == 1

# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
            # at = auto_functionalized(torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default, ...)  # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
                # epsilon is not in the match kwargs: scalars cannot be
                # pattern inputs, so it is recovered from the matched node.
                kwargs = self.match.kwargs.copy()
                del kwargs["result_rms"]  # not used in the fused op
                # Index 0 of the auto_functionalized output is always None
                fused_return_mapping = {
                    1: (quant_f4gemm_node, 1),
                    2: (add_rms_node, 2)
                }
self.insert_fused_node(fused_return_mapping,
**kwargs,
epsilon=add_rms_node.kwargs["epsilon"])


class ROCmFusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, registration state is
    effectively global. This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""

@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)

self.matches: list[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_fusion_pass")

for epsilon in [1e-5, 1e-6]:
AddRMSNormMXFP4GemmPattern(epsilon).register(
self.patterns, self.record_match)

# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()

def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)

# Return False to prevent automatic replacement.
return False

def process_matches(self, graph: torch.fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process()

# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)

def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_rocm_fusion")

count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_pattern_match")

# Manually process multi-output matches (and run DCE)
self.process_matches(graph)
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_rocm_fusion")
self.matches.clear()
self.end_and_log()
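
A minimal sketch of how the pass is driven (illustrative; in vLLM the
PostGradPassManager invokes it on the post-grad FX graph):

# fusion_pass = ROCmFusionPass(vllm_config)
# fusion_pass(graph_module.graph)
#
# Inside __call__, patterns.apply() performs any direct replacements, while
# multi-output matches are only recorded (extra_check returns False) and are
# then rewritten by process_matches(), followed by dead-code elimination.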