from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class Gemma2EmbeddingModel(nn.Module):
    """A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = Gemma2Model(**kwargs)
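        # Pool the hidden state of the last token in each sequence and
        # L2-normalize it to produce the final embedding vector.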
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
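        # There is no language-model head here: the raw hidden states from the
        # underlying Gemma2Model are returned and later pooled into embeddings.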
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
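        # vLLM fuses the attention q/k/v projections and the MLP gate/up
        # projections into single parameters, so checkpoint weight names must
        # be remapped onto the fused parameter name plus a shard id.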
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
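                # No stacked mapping matched: the weight maps one-to-one onto
                # a model parameter and can be loaded directly.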
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
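

# A minimal usage sketch (not part of the original file), assuming this class
# is registered in vLLM's model registry as the embedding backend for Gemma-2
# checkpoints. The model name below is illustrative, and attribute names may
# differ slightly across vLLM versions:
#
#     from vllm import LLM
#
#     llm = LLM(model="google/gemma-2-2b", enforce_eager=True)
#     outputs = llm.encode(["What is the capital of France?"])
#     embedding = outputs[0].outputs.embedding  # a list of floats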