@@ -43,7 +43,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -814,11 +813,20 @@ def __init__(
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        self.qkv_proj = QKVCrossParallelLinear(
+        # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
+        # quantization
+        self.q_proj = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
+        )
+        self.kv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
-            self.num_heads,
-            self.num_key_value_heads,
+            total_num_heads=0,
+            total_num_kv_heads=self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
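
A quick size check (illustrative numbers only, not the model's actual config, and tensor-parallel size assumed 1): with `total_num_heads=0`, `QKVParallelLinear` degenerates into a fused K/V-only projection, so the two replacement layers together cover exactly the widths the old fused cross-QKV layer produced.

```python
# Illustrative sizes only; TP size assumed 1, so local == total head counts.
hidden_size, head_dim = 4096, 128
num_heads, num_kv_heads = 32, 8

q_out = num_heads * head_dim                 # ColumnParallelLinear output width
kv_out = (0 + 2 * num_kv_heads) * head_dim   # QKVParallelLinear with
                                             # total_num_heads=0: K and V only
assert (q_out, kv_out) == (4096, 2048)
```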
@@ -854,11 +862,15 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        q, _ = self.q_proj(hidden_states)
         if cross_attention_states is not None:
+            kv, _ = self.kv_proj(cross_attention_states)
+            k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+        else:
+            k = v = None

         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
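
For reviewers who want to sanity-check the shapes, here is a minimal single-process mirror of the new projection path; plain `nn.Linear` stands in for `ColumnParallelLinear`/`QKVParallelLinear`, the q/k norms and quantization are omitted, and all sizes are made up.

```python
import torch
import torch.nn as nn

# Stand-ins for the vLLM parallel layers (single process, no quantization).
hidden_size, head_dim, num_heads, num_kv_heads = 1024, 64, 16, 4
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
kv_proj = nn.Linear(hidden_size, 2 * num_kv_heads * head_dim, bias=False)
kv_local_size = num_kv_heads * head_dim  # equals self.kv_local_size when TP size is 1

def project_qkv(hidden_states, cross_attention_states):
    q = q_proj(hidden_states)  # queries always come from the decoder tokens
    if cross_attention_states is not None:
        # Prefill: project the vision states once into fused K/V, then split.
        kv = kv_proj(cross_attention_states)
        k, v = kv.split([kv_local_size, kv_local_size], dim=-1)
        k = k.view(-1, num_kv_heads, head_dim)
        v = v.view(-1, num_kv_heads, head_dim)
    else:
        # Decode: no cross states are passed; attention reuses cached K/V.
        k = v = None
    q = q.view(-1, num_heads, head_dim)
    return q, k, v

q, k, v = project_qkv(torch.randn(6, hidden_size), torch.randn(577, hidden_size))
q_dec, k_dec, v_dec = project_qkv(torch.randn(1, hidden_size), None)
print(q.shape, k.shape)  # torch.Size([6, 16, 64]) torch.Size([577, 4, 64])
print(k_dec is None)     # True: the decode path relies on the KV cache
```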
@@ -1149,8 +1161,13 @@ def forward(
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsV0Only):
     packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "self_attn.qkv_proj": [
+            "self_attn.q_proj",
+            "self_attn.k_proj",
+            "self_attn.v_proj",
+        ],
+        "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
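
The prefixed keys matter because the decoder now has two differently fused attention blocks: `self_attn` keeps the full Q/K/V fusion while `cross_attn` only fuses K/V and leaves `q_proj` standalone. A hypothetical lookup helper (not vLLM's actual code; module paths are just HF-style examples) illustrates the disambiguation:

```python
# Hypothetical helper, only to illustrate the mapping semantics.
packed_modules_mapping = {
    "self_attn.qkv_proj": [
        "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"
    ],
    "cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def fused_target(module_name: str) -> str:
    for packed, shards in packed_modules_mapping.items():
        for shard in shards:
            if module_name.endswith(shard):
                return module_name[:-len(shard)] + packed
    return module_name  # left as-is, e.g. the unfused cross_attn.q_proj

print(fused_target("language_model.model.layers.0.self_attn.k_proj"))
# -> language_model.model.layers.0.self_attn.qkv_proj
print(fused_target("language_model.model.layers.3.cross_attn.k_proj"))
# -> language_model.model.layers.3.cross_attn.kv_proj
print(fused_target("language_model.model.layers.3.cross_attn.q_proj"))
# -> unchanged: cross_attn.q_proj keeps its own ColumnParallelLinear
```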
@@ -1420,9 +1437,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
+            (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
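
A simplified sketch of how such a mapping is typically consumed during weight loading: the checkpoint shard name is rewritten to the fused parameter's name, and the shard id is later handed to that parameter's `weight_loader` so it fills the right output slice. Checkpoint names below are illustrative, and the real loop also handles weights that match no entry.

```python
# Simplified sketch of the name resolution step in load_weights.
stacked_params_mapping = [
    (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
    (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
    (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
    (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def resolve(ckpt_name: str):
    """Map a checkpoint weight name to (fused param name, shard_id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, param_name), shard_id
    return ckpt_name, None  # weight is not part of a fused parameter

print(resolve("language_model.model.layers.3.cross_attn.v_proj.weight"))
# -> ('language_model.model.layers.3.cross_attn.kv_proj.weight', 'v')
print(resolve("language_model.model.layers.0.self_attn.q_proj.weight"))
# -> ('language_model.model.layers.0.self_attn.qkv_proj.weight', 'q')
```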