27 changes: 15 additions & 12 deletions vllm/model_executor/models/transformers.py
@@ -25,7 +25,6 @@
 from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
@@ -128,10 +127,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
 
         self.config = config
-        self.vocab_size = config.vocab_size
-        self.unpadded_vocab_size = config.vocab_size
+        self.vocab_size = model_config.get_vocab_size()
+        self.unpadded_vocab_size = model_config.get_vocab_size()
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
@@ -145,15 +146,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
-        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = model_config.get_num_attention_heads(parallel_config)
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.attention_instances = [
             Attention(
-                num_heads=divide(config.num_attention_heads, tp_size),
-                head_size=config.head_dim,
+                num_heads=num_heads,
+                head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
                 # Transformers, it's updated in vllm_flash_attention_forward
-                scale=config.head_dim**-0.5,
-                num_kv_heads=divide(config.num_key_value_heads, tp_size),
+                scale=head_size**-0.5,
+                num_kv_heads=num_kv_heads,
                 cache_config=cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
@@ -163,7 +166,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.replace_vocab_embed_class(self.model)
 
         # ForCausalLM modifications
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(self.vocab_size,
                                       config.hidden_size,
                                       quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
@@ -172,7 +175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                self.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
     def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
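On the vocab-size changes threaded through the last few hunks: reading the size from model_config.get_vocab_size() instead of config.vocab_size keeps the embedding, lm_head, and LogitsProcessor consistent with a single source of truth. A rough sketch of what that lookup is assumed to do (hypothetical simplified form, not the actual method):

# Assumption: get_vocab_size() resolves the vocabulary size from the text
# config, so models whose hf_config nests the language-model settings
# (e.g. under a text_config attribute) still report the right value.
def get_vocab_size(model_config) -> int:
    hf_config = model_config.hf_config
    text_config = getattr(hf_config, "text_config", hf_config)
    return text_config.vocab_size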
@@ -203,12 +206,12 @@ def replace_vocab_embed_class(self, module: nn.Module):
         new_module = VocabParallelEmbedding(
             self.vocab_size,
             self.config.hidden_size,
-            org_num_embeddings=self.config.vocab_size,
+            org_num_embeddings=self.vocab_size,
             quant_config=None,
         )
         log_replacement("input embedding", self.model.get_input_embeddings(),
                         new_module)
-        self.model.set_input_embeddings(new_module)
+        module.set_input_embeddings(new_module)
 
     def forward(
         self,
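Finally, the last hunk makes replace_vocab_embed_class install the new embedding on its module argument instead of the hard-coded self.model. A small, hypothetical illustration of why the target matters once the helper is reused on a nested sub-module; the wrapper class below is made up for the example and is not part of the PR:

import torch.nn as nn

class TextBackbone(nn.Module):
    """Stand-in for an HF-style module exposing the embeddings API."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self) -> nn.Module:
        return self.embed

    def set_input_embeddings(self, new_embed: nn.Module) -> None:
        self.embed = new_embed

def replace_embedding(module: nn.Module, new_embed: nn.Module) -> None:
    # Mirrors the fixed line: act on the module that was passed in,
    # so callers can target self.model or any nested text backbone.
    module.set_input_embeddings(new_embed)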