Commit d07b3a1

Isotr0py authored and DamonFool committed
[Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization work (vllm-project#14498)
Signed-off-by: Isotr0py <[email protected]>
1 parent 796f871 · commit d07b3a1

File tree: 1 file changed (+29, -10 lines)


vllm/model_executor/models/mllama.py

Lines changed: 29 additions & 10 deletions
@@ -43,7 +43,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                                QKVCrossParallelLinear,
                                                 QKVParallelLinear,
                                                 RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -814,11 +813,20 @@ def __init__(
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        self.qkv_proj = QKVCrossParallelLinear(
+        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
+        # quantization
+        self.q_proj = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.kv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
-            self.num_heads,
-            self.num_key_value_heads,
+            total_num_heads=0,
+            total_num_kv_heads=self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
@@ -854,11 +862,15 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        q, _ = self.q_proj(hidden_states)
         if cross_attention_states is not None:
+            kv, _ = self.kv_proj(cross_attention_states)
+            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+        else:
+            k = v = None

         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
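
The patched forward() projects queries from hidden_states and, only when cross_attention_states are present, runs the fused KV projection and splits its output into two equal halves of kv_local_size = num_local_key_value_heads * head_dim before reshaping into heads. A small sketch of that split with toy shapes (assuming tensor-parallel size 1, so local sizes equal the totals):

import torch

# Toy local sizes, purely illustrative.
num_local_kv_heads, head_dim = 2, 16
kv_local_size = num_local_kv_heads * head_dim            # 32

# Output of the fused KV projection for 5 cross-attention tokens.
kv = torch.randn(5, 2 * kv_local_size)                   # (5, 64)

# Same split/view pattern as the patched forward():
k, v = kv.split([kv_local_size, kv_local_size], dim=-1)
k = k.view(-1, num_local_kv_heads, head_dim)             # (5, 2, 16)
v = v.view(-1, num_local_kv_heads, head_dim)             # (5, 2, 16)
print(k.shape, v.shape)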
@@ -1149,8 +1161,13 @@ def forward(
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
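
packed_modules_mapping declares which original checkpoint sub-modules are folded into each fused vLLM module, which quantization backends such as BitsAndBytes rely on when matching checkpoint weights; the patch scopes the qkv entry to self_attn and adds the new cross_attn.kv_proj fusion. Below is a simplified, hypothetical illustration of the name expansion such a mapping enables (not vLLM's actual loader code):

# Hypothetical helper: expand a fused module name back into the checkpoint
# sub-module names that a quantizer must treat as one packed weight.
packed_modules_mapping = {
    "self_attn.qkv_proj": [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
    ],
    "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def checkpoint_submodules(fused_name: str) -> list:
    for packed, originals in packed_modules_mapping.items():
        if fused_name.endswith(packed):
            base = fused_name[:-len(packed)]
            return [base + orig for orig in originals]
    return [fused_name]   # not a fused module

print(checkpoint_submodules("layers.3.cross_attn.kv_proj"))
# ['layers.3.cross_attn.k_proj', 'layers.3.cross_attn.v_proj']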
@@ -1420,9 +1437,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
+            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
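
stacked_params_mapping follows the usual vLLM load_weights convention: a checkpoint weight named after one of the original projections is renamed to the fused parameter and loaded into the shard identified by its shard id. A simplified sketch of that routing step, reusing the mapping from the hunk above (illustrative only; the actual method does more than this):

# Simplified, illustrative routing of checkpoint weight names onto fused
# parameters; `route_weight` is a hypothetical helper, not vLLM code.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
    (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
    (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
    (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def route_weight(name: str):
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # Load into the fused parameter at the shard slot `shard_id`.
            return name.replace(shard_name, param_name), shard_id
    return name, None   # unfused weight, loaded as-is

print(route_weight("model.layers.8.cross_attn.k_proj.weight"))
# -> ('model.layers.8.cross_attn.kv_proj.weight', 'k')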
