Merged
155 changes: 58 additions & 97 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -22,7 +22,7 @@
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
@@ -32,7 +32,6 @@
WeightsLoadingConfig)
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -53,86 +52,51 @@ def __init__(
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

self.use_rope = not nope_layer
self.use_qk_norm = use_qk_norm and not nope_layer
if self.use_rope and not self.use_qk_norm:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gptj,
rope=RopeParams.from_config(config),
is_neox=False,
)
else:
pos_embd_params = None

super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config)

if self.use_rope and self.use_qk_norm:
# here we must disable rope fusion regardless of attn_backend
self.enable_rope_fusion = False
self.rotary_emb = RotaryEmbedding(
RopeParams.from_config(config),
head_dim=self.head_dim,
is_neox=False,
)
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gptj,
rope=RopeParams.from_config(config),
is_neox=False,
) if self.use_rope else None

super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
qk_norm_type=QkNormType.post_rope
if use_qk_norm else QkNormType.none,
)

if self.use_qk_norm:
if self.use_rope and use_qk_norm:
self.head_dim = config.hidden_size // config.num_attention_heads
self.qk_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=False)
else:
self.qk_norm = None
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

self.attn_temperature_tuning = attn_temperature_tuning and nope_layer
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)

def _attn_qkv(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None):
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale
def apply_qk_norm(self, q, k):

q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)

attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
def q_l2norm():
return self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)

return attn_output
def k_l2norm():
return self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)

def _qk_norm(self, q, k):
# TODO: make this more efficient.
q_l2norm = lambda: self.qk_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
k_l2norm = lambda: self.qk_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
@@ -155,31 +119,6 @@ def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
q = (q * attn_scale).to(q.dtype)
return q

def _forward_rope(
self,
position_ids: Optional[torch.LongTensor],
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
):
if self.use_qk_norm:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
assert self.rotary_emb is not None and not self.enable_rope_fusion, "qk_norm requires attention rope fusion disabled"
q, k = self.rotary_emb(position_ids, [q, k])
q, k = self._qk_norm(q, k)
return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
mrope_config, all_reduce_params)
else:
# When qk_norm is disabled, use the classic attention path that handles RoPE fusion
return super().forward(position_ids, hidden_states, attn_metadata,
attention_mask, mrope_config,
all_reduce_params)

def _forward_nope(
self,
position_ids: Optional[torch.LongTensor],
@@ -191,11 +130,26 @@ def _forward_nope(
all_reduce_params: Optional[AllReduceParams] = None,
):
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = self.split_qkv(qkv)
if self.attn_temperature_tuning:
q = self._attention_scaling(q, position_ids)
return self._attn_qkv(q, k, v, attn_metadata, attention_mask,
mrope_config, all_reduce_params)
out_scale = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
out_scale = self.o_proj.inv_input_scale

q, k, v = self.convert_qkv(q, k, v)
attn_output = self.attn.forward(q,
k,
v,
attn_metadata,
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)

attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)

return attn_output

def forward(
self,
@@ -211,9 +165,16 @@ def forward(
) -> torch.Tensor:
assert lora_params is None, "LORA is not supported for Llama4Attention"
if self.use_rope:
return self._forward_rope(position_ids, hidden_states,
attn_metadata, attention_mask,
mrope_config, all_reduce_params)
return super().forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
**kwargs,
)
else:
return self._forward_nope(position_ids, hidden_states,
attn_metadata, attention_mask,
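Reviewer note (not part of the diff): this file now delegates QK normalization to the shared `Attention` base class via `qk_norm_type=QkNormType.post_rope` and only supplies `apply_qk_norm`; the old `_forward_rope` path with its explicit `RotaryEmbedding` and disabled rope fusion is removed. Below is a minimal sketch of the ordering the enum member names imply. The standalone helper, its signature, and the assumption that the base class applies RoPE between the two possible normalization points are illustrative only, not the actual TensorRT-LLM API.

```python
from enum import Enum
from typing import Callable, Tuple

import torch

TensorPair = Tuple[torch.Tensor, torch.Tensor]


class QkNormTypeSketch(Enum):
    # Mirrors the member names used in this PR (none / pre_rope / post_rope).
    none = "none"
    pre_rope = "pre_rope"
    post_rope = "post_rope"


def order_qk_norm_and_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    qk_norm_type: QkNormTypeSketch,
    apply_qk_norm: Callable[[torch.Tensor, torch.Tensor], TensorPair],
    apply_rope: Callable[[torch.Tensor, torch.Tensor], TensorPair],
) -> TensorPair:
    """Apply QK RMSNorm before or after RoPE, depending on the selected mode."""
    if qk_norm_type == QkNormTypeSketch.pre_rope:
        q, k = apply_qk_norm(q, k)   # Qwen3-style: normalize, then rotate
    q, k = apply_rope(q, k)
    if qk_norm_type == QkNormTypeSketch.post_rope:
        q, k = apply_qk_norm(q, k)   # Llama4-style: rotate, then normalize
    return q, k
```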
24 changes: 23 additions & 1 deletion tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -9,11 +9,12 @@
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.attention import Attention, QkNormType
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -51,6 +52,7 @@ def __init__(
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
qk_norm_type=QkNormType.pre_rope,
)

self.q_norm = RMSNorm(hidden_size=self.head_dim,
@@ -64,6 +66,26 @@ def __init__(
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

def apply_qk_norm(self, q, k):

def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)

def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)

q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

return q, k


class Qwen3DecoderLayer(DecoderLayer):

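Reviewer note (not part of the diff): the new `apply_qk_norm` hands its two closures to `maybe_execute_in_parallel` together with a pair of CUDA events and an auxiliary stream, which suggests the q-norm and k-norm are overlapped on separate streams. A rough sketch of that pattern under the usual record/wait event discipline follows; the real helper in `..modules.multi_stream_utils` may be implemented differently, and the sequential fallback is an assumption.

```python
from typing import Callable, Optional, Tuple, TypeVar

import torch

A = TypeVar("A")
B = TypeVar("B")


def maybe_execute_in_parallel_sketch(
    fn_a: Callable[[], A],
    fn_b: Callable[[], B],
    event_a: torch.cuda.Event,
    event_b: torch.cuda.Event,
    aux_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[A, B]:
    """Run fn_a on the current stream and fn_b on aux_stream, then resynchronize."""
    if aux_stream is None:
        # No auxiliary stream available: run both closures sequentially.
        return fn_a(), fn_b()

    event_a.record()              # inputs produced so far become visible to aux_stream
    result_a = fn_a()             # e.g. q-norm stays on the main stream
    with torch.cuda.stream(aux_stream):
        event_a.wait()            # aux stream waits for the shared inputs
        result_b = fn_b()         # e.g. k-norm runs concurrently on aux_stream
        event_b.record()
    event_b.wait()                # main stream waits before consuming result_b
    return result_a, result_b
```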
51 changes: 2 additions & 49 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -5,17 +5,14 @@
from tqdm import tqdm
from transformers import Qwen3MoeConfig

from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from .modeling_qwen3 import Qwen3Attention
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
duplicate_kv_weight, register_auto_model)

@@ -81,57 +78,13 @@ def forward(
return final_hidden_states.view(orig_shape)


class Qwen3MoEAttention(Attention):

def __init__(
self,
model_config: ModelConfig[Qwen3MoeConfig],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
if getattr(config, "rope_scaling", None) is not None:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.from_string(
config.rope_scaling["type"]),
rope=RopeParams.from_config(config),
)
else:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
)

self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]


class Qwen3MoEDecoderLayer(DecoderLayer):

def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
layer_idx: int, aux_stream: torch.cuda.Stream):
super().__init__()
config = model_config.pretrained_config
self.self_attn = Qwen3MoEAttention(
self.self_attn = Qwen3Attention(
model_config,
layer_idx=layer_idx,
)