4 changes: 3 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -70,7 +70,9 @@ def __init__(self):
             from vllm._ipex_ops import ipex_ops
             self.op = ipex_ops.silu_and_mul
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self,
+                       x: torch.Tensor,
+                       scale: Optional[torch.Tensor] = None) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
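Note that the native path only gains the optional `scale` argument here; it still computes unscaled SiLU-and-mul. A minimal sketch of how a scale-aware variant might apply an fp8 output scale (an assumption based on the `fp8_out_scale` plumbing elsewhere in this PR; `silu_and_mul_native` is an illustrative name, not the in-tree API):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def silu_and_mul_native(x: torch.Tensor,
                        scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Reference SiLU-and-mul: split the last dim in half and gate one half
    # with SiLU of the other (matches forward_native above).
    d = x.shape[-1] // 2
    out = F.silu(x[..., :d]) * x[..., d:]
    if scale is not None:
        # Hypothetical fp8 output scaling: divide by the scale, clamp to the
        # e4m3 representable range, then cast. The in-tree native path does
        # not do this; only a fused kernel would.
        out = (out / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return out
```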
1 change: 1 addition & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -39,6 +39,7 @@ def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
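As with the activation change, `forward_native` merely accepts the extra argument. For reference, a self-contained RMSNorm sketch showing where such an output scale could slot in (the scaling branch is hypothetical; the in-tree native path leaves `scale` unused):

```python
from typing import Optional

import torch


def rms_norm_native(x: torch.Tensor,
                    weight: torch.Tensor,
                    eps: float = 1e-6,
                    scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Standard RMSNorm computed in fp32, then cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = (x * torch.rsqrt(variance + eps) * weight.float()).to(orig_dtype)
    if scale is not None:
        out = out / scale  # hypothetical fp8 output scaling, not done in-tree
    return out
```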
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/tuned_gemm.py
@@ -8,12 +8,13 @@
 import torch.nn.functional as F
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_mi250, is_navi
 
 support_tuned_gemms = False
-if current_platform.is_rocm():
+if current_platform.is_rocm() and not envs.VLLM_USE_V1:
     import vllm._gradlib_C  # noqa: F401
     support_tuned_gemms = True
 
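The practical effect of the new guard is that tuned GEMMs stay disabled under the V1 engine, leaving callers on stock PyTorch matmuls. A hedged sketch of that fallback pattern (`mm_with_fallback` and the tuned-path comment are illustrative, not the tuned_gemm module's actual interface):

```python
import torch

# Illustrative stand-in for the module-level flag set above; it remains
# False on ROCm when the V1 engine is enabled.
support_tuned_gemms = False


def mm_with_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # When support_tuned_gemms is True, the real module dispatches to a
    # gradlib/hipBLASLt tuned kernel instead; that call is omitted here.
    if not support_tuned_gemms:
        return torch.matmul(a, b)  # plain PyTorch fallback
    raise NotImplementedError("tuned path not sketched")
```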
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -109,6 +109,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
+        fp8_out_scale: Optional[torch.Tensor],
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
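The backend forward now takes `fp8_out_scale` as a required positional argument, so callers must pass it explicitly (or `None`). A minimal sketch of how such a scale might be applied to the attention output when the kernel does not fuse the quantization; the helper name and dtype choice are assumptions, not the backend's actual behavior:

```python
from typing import Optional

import torch


def maybe_quantize_output(attn_output: torch.Tensor,
                          fp8_out_scale: Optional[torch.Tensor]) -> torch.Tensor:
    # No scale provided: return the fp16/bf16 output unchanged.
    if fp8_out_scale is None:
        return attn_output
    # Assumed semantics: divide by the output scale to fit the fp8 dynamic
    # range, then cast. ROCm builds may use float8_e4m3fnuz instead.
    return (attn_output / fp8_out_scale).to(torch.float8_e4m3fn)
```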