18 changes: 17 additions & 1 deletion vllm/envs.py
@@ -110,7 +110,10 @@
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256

VLLM_AITER_TRITON_FP8_BMM: bool = False
VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS: bool = False
VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT: bool = False
VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT: bool = False

def get_default_cache_root():
return os.getenv(
@@ -728,6 +731,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),

"VLLM_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FP8_BMM", "0"))),

"VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS":
lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS", "0"))),

"VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT":
lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT", "0"))),

"VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT":
lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT", "0"))),

}

# end-env-vars-definition
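
All four flags default to off and are read lazily through os.getenv, so they can be toggled per process. A minimal sketch of enabling them before vLLM reads them (an `export VLLM_AITER_TRITON_FP8_BMM=1` in the shell before launching vLLM works the same way); this is an editor's illustration, not part of the diff:

import os

# Must be set before vLLM queries the flags (simplest: before importing vLLM).
os.environ["VLLM_AITER_TRITON_FP8_BMM"] = "1"
os.environ["VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT"] = "1"

import vllm.envs as envs

assert envs.VLLM_AITER_TRITON_FP8_BMM
assert envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT
assert not envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS  # still off by default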
192 changes: 176 additions & 16 deletions vllm/v1/attention/backends/mla/common.py
@@ -220,6 +220,37 @@
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat_and_cache_mla

if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant
import aiter as rocm_aiter
rocm_aiter_fp8 = rocm_aiter.dtypes.fp8

if envs.VLLM_AITER_TRITON_FP8_BMM:
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
# @torch.compiler.disable
def aiter_triton_fp8_bmm_wrapper(x, w, w_s, group_size = 128, y = None, transpose_bm = False):
if y is not None:
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, YQ=y, transpose_bm=transpose_bm)
else:
y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, transpose_bm = transpose_bm)
return y

if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
from aiter.ops.triton.fused_concat_zeros import fused_concat_zeros

logger = init_logger(__name__)
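
# Editor's illustration (not part of this change): round-trip behaviour of the
# dynamic_per_batched_tensor_quant helper above. It returns the fp8 tensor plus
# the *reciprocal* of the scale it applied, so dequantization is one multiply.
def _fp8_roundtrip_sketch():
    # Assumes VLLM_AITER_TRITON_FP8_BMM=1 so the helper above is defined.
    w = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_q, w_s = dynamic_per_batched_tensor_quant(w, dtype=torch.float8_e4m3fnuz)
    w_deq = w_q.to(torch.float32) * w_s  # w_s == 1 / scale
    # e4m3 carries ~3 mantissa bits, so the round-trip error stays at a few
    # percent of the tensor's max magnitude.
    return (w_deq - w.float()).abs().max() / w.float().abs().max()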


@@ -636,10 +667,14 @@ def __init__(

if self.use_rocm_aiter:
self.rotary_emb = rotary_emb.forward_hip
self.cos_cache, self.sin_cache = rotary_emb.cos_cache, rotary_emb.sin_cache
self.rotary_emb_is_neox_style = rotary_emb.is_neox_style
else:
self.rotary_emb = rotary_emb.forward_native
if current_platform.is_cuda():
self.rotary_emb = rotary_emb.forward_cuda
self.cos_cache, self.sin_cache = rotary_emb.cos_sin_cache.chunk(2, dim = -1)
self.rotary_emb_is_neox_style = rotary_emb.is_neox_style

self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
@@ -703,10 +738,17 @@ def _flash_attn_varlen_diff_headdims(self,
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if envs.VLLM_AITER_TRITON_FP8_BMM:
# Multiply + Transpose (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V)
x = aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
else:
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]

# Return `ql_nope`, `q_pe`
@@ -717,10 +759,15 @@ def _q_proj_and_k_up_proj(self, x):

# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
if envs.VLLM_AITER_TRITON_FP8_BMM:
# Multiply + Transpose (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
ql_nope = aiter_triton_fp8_bmm_wrapper(q_nope, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)
else:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
return ql_nope, q_pe

def process_weights_after_loading(self, act_dtype: torch.dtype):

@@ -751,6 +798,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T

assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -767,11 +815,89 @@

W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)

if (envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT):
kv_cache_size = 8192
max_position_embedding = self.cos_cache.shape[0]
for prefill_decode_size in [1, 256, 2048]:
for decode_batch_size in [0, 1, 256]:
if decode_batch_size > prefill_decode_size:
continue

k_scale = torch.ones([1,], dtype=torch.float32, device=W_UK.device)[0]

q = torch.empty((decode_batch_size, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
decode_ql_nope = q[..., :self.kv_lora_rank]
decode_q_pe = q[..., self.kv_lora_rank:]

k = torch.empty((prefill_decode_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
k_c_normed = k[..., :self.kv_lora_rank].squeeze(1)
k_pe = k[..., self.kv_lora_rank:]

input_positions = torch.randint(0, max_position_embedding, (decode_batch_size, ), device=W_UK.device)
slot_mapping = torch.randperm(kv_cache_size, device=W_UK.device)[:prefill_decode_size]
kv_cache = torch.empty((kv_cache_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)

if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
logger.info(f"[Triton] compiling fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})")
fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant(
decode_ql_nope,
decode_q_pe,
k_c_normed.unsqueeze(1),
k_pe,
kv_cache,
slot_mapping,
input_positions,
self.cos_cache,
self.sin_cache,
k_scale,
self.rotary_emb_is_neox_style,
dtype_quant=rocm_aiter_fp8
)
if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
logger.info(f"[Triton] compiling fused_qk_rope_cat_and_cache_mla with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})")
fused_qk_rope_cat_and_cache_mla(
decode_ql_nope,
decode_q_pe,
k_c_normed.unsqueeze(1),
k_pe,
kv_cache,
slot_mapping,
input_positions,
self.cos_cache,
self.sin_cache,
k_scale,
self.rotary_emb_is_neox_style,
)

if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
max_batch_size = 256
logger.info(f"[Triton] compiling fused_concat_zeros with shape = [1~{max_batch_size}] {self.num_heads} [{self.kv_lora_rank} : {self.qk_rope_head_dim}]")
for m in range(1, max_batch_size+1):
x1 = torch.empty((m, self.num_heads, self.kv_lora_rank), dtype=torch.bfloat16, device=W_UK.device)
x2 = torch.empty((m, self.num_heads, self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
fused_concat_zeros(x1, x2)

if envs.VLLM_AITER_TRITON_FP8_BMM:
max_batch_size = 256
W_K = W_UK.transpose(0, 1)   # (L, N, P) -> (N, L, P), e.g. (16, 512, 128)
W_V = W_UV.permute(1, 2, 0)  # (L, N, V) -> (N, V, L), e.g. (16, 128, 512)
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(W_K, dtype=torch.float8_e4m3fnuz)
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(W_V, dtype=torch.float8_e4m3fnuz)
logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_K.shape[0]} [1~{max_batch_size}] {self.W_K.shape[1]} {self.W_K.shape[2]}")
logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_V.shape[0]} [1~{max_batch_size}] {self.W_V.shape[1]} {self.W_V.shape[2]}")
for m in range(1, max_batch_size+1):
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), dtype=torch.bfloat16, device=self.W_K.device)
aiter_triton_fp8_bmm_wrapper(x, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)

x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), dtype=torch.bfloat16, device=self.W_V.device)
aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)

else:
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)

def _compute_prefill_context(
self,
@@ -951,7 +1077,10 @@ def forward(
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)

if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
# The decode-path RoPE is fused into the KV-cache write via
# fused_qk_rope_cat_and_cache_mla below, so there is nothing to do here.
pass
elif self.use_rocm_aiter:
self.rotary_emb(attn_metadata.decode.input_positions,
decode_q_pe, decode_k_pe)
else:
@@ -974,7 +1103,38 @@ def forward(
prefill_q_pe.contiguous(), prefill_k_pe)

# write the latent and rope to kv cache
q_nope_pe, q_scale = None, None
if (envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT) and has_decode and kv_cache.numel() > 0:
if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
q_nope_pe, q_scale = fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant(
decode_ql_nope,
decode_q_pe,
k_c_normed.unsqueeze(1),
k_pe,
kv_cache,
attn_metadata.slot_mapping.flatten(),
attn_metadata.decode.input_positions,
self.cos_cache,
self.sin_cache,
layer._k_scale,
self.rotary_emb_is_neox_style,
dtype_quant=rocm_aiter_fp8
)
else:
q_nope_pe = fused_qk_rope_cat_and_cache_mla(
decode_ql_nope,
decode_q_pe,
k_c_normed.unsqueeze(1),
k_pe,
kv_cache,
attn_metadata.slot_mapping.flatten(),
attn_metadata.decode.input_positions,
self.cos_cache,
self.sin_cache,
layer._k_scale,
self.rotary_emb_is_neox_style,
)
elif kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
Expand All @@ -991,6 +1151,6 @@ def forward(

if has_decode:
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, q_nope_pe=q_nope_pe, q_scale=q_scale)

return output_padded
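
For reference, a minimal sketch of what the FP8 BMM path above computes, checked against a plain torch.bmm on the dequantized weight. This is an editor's illustration under assumptions, not part of the change: it presumes aiter is installed, that the kernel accepts the argument pattern used by aiter_triton_fp8_bmm_wrapper above (per-token-group activation quantization with group_size=128, a per-batched-tensor weight scale, and transpose_bm=True returning a (B, N, L) result), and that small fp8 rounding differences are acceptable.

import torch
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
    batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as fp8_bmm,
)

N, B, P, L = 16, 32, 128, 512  # heads, decode tokens, qk_nope_head_dim, kv_lora_rank

# Per-batched-tensor fp8 weight quantization, mirroring dynamic_per_batched_tensor_quant.
w = torch.randn(N, P, L, dtype=torch.bfloat16, device="cuda")
dtype_max = torch.finfo(torch.float8_e4m3fnuz).max
scale = dtype_max / w.abs().amax().clamp(min=1e-10)
w_q = (w.float() * scale).clamp(-dtype_max, dtype_max).to(torch.float8_e4m3fnuz).contiguous()
w_s = scale.float().reciprocal()

x = torch.randn(N, B, P, dtype=torch.bfloat16, device="cuda")

# Fused path: quantizes x per 128-wide group, runs the fp8 BMM, and returns
# the result already transposed to (B, N, L).
y = fp8_bmm(x, w_q, w_s, group_size=128, transpose_bm=True)

# Unfused reference: plain bmm against the dequantized weight.
y_ref = torch.bmm(x.float(), w_q.float() * w_s).transpose(0, 1)
print((y.float() - y_ref).abs().max())  # small fp8 rounding error expected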
43 changes: 35 additions & 8 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -15,6 +15,8 @@
MLACommonMetadata,
MLACommonMetadataBuilder)

if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
from aiter.ops.triton.fused_concat_zeros import fused_concat_zeros
# yapf: enable


@@ -180,18 +182,43 @@ def _forward_decode(
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
q_nope_pe: Optional[torch.Tensor] = None,
q_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT and q_nope_pe is not None and q_scale is not None:
# q_nope_pe.dtype == torch.float8_e4m3fnuz
# q_scale.dtype == torch.float32
# Upcast back to bf16 for the current implementation. This section can be
# removed once aiter_mla_decode_fwd supports fp8 inputs and no longer needs a
# zero-initialized output tensor.
q = (q_nope_pe.to(torch.float32) * q_scale).to(q_nope.dtype)
B = q_nope.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
elif envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT and q_nope_pe is not None:
# q_nope_pe.dtype == torch.bfloat16
q = q_nope_pe
B = q_nope.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
elif envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
q, o = fused_concat_zeros(q_nope, q_pe)
else:
B = q_nope.shape[0]

q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

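
Finally, for orientation: fused_qk_rope_cat_and_cache_mla (used in common.py above) collapses three decode-path steps (RoPE on decode_q_pe and k_pe, concatenation of ql_nope with the roped q_pe, and the concat_and_cache_mla write into kv_cache) into a single kernel, and its output is the q_nope_pe consumed by _forward_decode here. Below is a rough unfused reference, an editor's sketch rather than the kernel's actual implementation: it assumes neox-style rotation, a bf16 KV cache (so k_scale is ignored), and a pure-decode batch; the real kernel signature and its fp8-quantizing variant differ in detail.

import torch

def unfused_rope_cat_and_cache_reference(
    ql_nope,       # (B, N, kv_lora_rank)
    q_pe,          # (B, N, rope_dim), pre-RoPE
    k_c_normed,    # (B, kv_lora_rank)
    k_pe,          # (B, 1, rope_dim), pre-RoPE
    kv_cache,      # (num_slots, 1, kv_lora_rank + rope_dim), bf16
    slot_mapping,  # (B,) slot indices into kv_cache
    positions,     # (B,) token positions
    cos_cache,     # (max_pos, rope_dim // 2)
    sin_cache,     # (max_pos, rope_dim // 2)
):
    def rope_neox(x, cos, sin):
        # GPT-NeoX convention: rotate the two contiguous halves of the rope dims.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    cos = cos_cache[positions].unsqueeze(1)  # (B, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)
    q_pe = rope_neox(q_pe, cos, sin)         # broadcasts over the head dim
    k_pe = rope_neox(k_pe, cos, sin)

    # concat_and_cache_mla: one latent+rope row per token at its slot.
    kv_cache[slot_mapping] = torch.cat([k_c_normed.unsqueeze(1), k_pe], dim=-1)

    # The fused kernel also hands back the assembled decode query.
    return torch.cat([ql_nope, q_pe], dim=-1)  # (B, N, kv_lora_rank + rope_dim)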