|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Inference-only GraniteMoeShared model. |
| 3 | +
|
| 4 | +The architecture is the same as granitemoe but with the addition of shared |
| 5 | +experts. |
| 6 | +""" |
| 7 | +from typing import Iterable, Optional, Set, Tuple |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch import nn |
| 11 | +from transformers.models.granitemoeshared import GraniteMoeSharedConfig |
| 12 | + |
| 13 | +from vllm.compilation.decorators import support_torch_compile |
| 14 | +from vllm.config import CacheConfig, VllmConfig |
| 15 | +from vllm.distributed import get_pp_group |
| 16 | +from vllm.model_executor.layers.activation import SiluAndMul |
| 17 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 18 | +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, |
| 19 | + RowParallelLinear) |
| 20 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 21 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 22 | + QuantizationConfig) |
| 23 | +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler |
| 24 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 25 | + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) |
| 26 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 27 | +from vllm.sequence import IntermediateTensors |
| 28 | + |
| 29 | +from . import mixtral |
| 30 | +from .granitemoe import GraniteMoeAttention, GraniteMoeMoE |
| 31 | +from .interfaces import SupportsLoRA, SupportsPP |
| 32 | +from .utils import make_layers, maybe_prefix |
| 33 | + |
| 34 | + |
| 35 | +class GraniteMoeSharedMLP(nn.Module): |
| 36 | + |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + config: GraniteMoeSharedConfig, |
| 40 | + quant_config: Optional[QuantizationConfig] = None, |
| 41 | + prefix: str = "", |
| 42 | + ): |
| 43 | + super().__init__() |
| 44 | + |
| 45 | + self.input_size = config.hidden_size |
| 46 | + self.hidden_size = config.shared_intermediate_size |
| 47 | + self.input_linear = MergedColumnParallelLinear( |
| 48 | + input_size=self.input_size, |
| 49 | + output_sizes=[self.hidden_size] * 2, |
| 50 | + bias=False, |
| 51 | + quant_config=quant_config, |
| 52 | + prefix=f"{prefix}.input_linear") |
| 53 | + self.output_linear = RowParallelLinear( |
| 54 | + self.hidden_size, |
| 55 | + self.input_size, |
| 56 | + bias=False, |
| 57 | + quant_config=quant_config, |
| 58 | + prefix=f"{prefix}.output_linear") |
| 59 | + if config.hidden_act != "silu": |
| 60 | + raise ValueError(f"Unsupported activation: {config.hidden_act}. " |
| 61 | + "Only silu is supported for now.") |
| 62 | + self.act_fn = SiluAndMul() |
| 63 | + |
| 64 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 65 | + hidden_states, _ = self.input_linear(hidden_states) |
| 66 | + hidden_states = self.act_fn(hidden_states) |
| 67 | + hidden_states, _ = self.output_linear(hidden_states) |
| 68 | + return hidden_states |
| 69 | + |
| 70 | + |
| 71 | +class GraniteMoeSharedDecoderLayer(nn.Module): |
| 72 | + |
| 73 | + def __init__( |
| 74 | + self, |
| 75 | + config: GraniteMoeSharedConfig, |
| 76 | + cache_config: Optional[CacheConfig] = None, |
| 77 | + quant_config: Optional[QuantizationConfig] = None, |
| 78 | + prefix: str = "", |
| 79 | + ) -> None: |
| 80 | + super().__init__() |
| 81 | + self.hidden_size = config.hidden_size |
| 82 | + # Requires transformers > 4.32.0 |
| 83 | + rope_theta = getattr(config, "rope_theta", 10000) |
| 84 | + self.self_attn = GraniteMoeAttention( |
| 85 | + hidden_size=self.hidden_size, |
| 86 | + num_heads=config.num_attention_heads, |
| 87 | + max_position=config.max_position_embeddings, |
| 88 | + num_kv_heads=config.num_key_value_heads, |
| 89 | + rope_theta=rope_theta, |
| 90 | + cache_config=cache_config, |
| 91 | + quant_config=quant_config, |
| 92 | + prefix=f"{prefix}.self_attn", |
| 93 | + attention_multiplier=config.attention_multiplier) |
| 94 | + self.block_sparse_moe = GraniteMoeMoE( |
| 95 | + num_experts=config.num_local_experts, |
| 96 | + top_k=config.num_experts_per_tok, |
| 97 | + hidden_size=config.hidden_size, |
| 98 | + intermediate_size=config.intermediate_size, |
| 99 | + quant_config=quant_config, |
| 100 | + prefix=f"{prefix}.block_sparse_moe") |
| 101 | + self.shared_mlp = None if \ |
| 102 | + getattr(config, 'shared_intermediate_size', 0) == 0 \ |
| 103 | + else GraniteMoeSharedMLP( |
| 104 | + config, |
| 105 | + quant_config=quant_config, |
| 106 | + prefix=f"{prefix}.shared_mlp" |
| 107 | + ) |
| 108 | + |
| 109 | + self.input_layernorm = RMSNorm(config.hidden_size, |
| 110 | + eps=config.rms_norm_eps) |
| 111 | + self.post_attention_layernorm = RMSNorm(config.hidden_size, |
| 112 | + eps=config.rms_norm_eps) |
| 113 | + |
| 114 | + self.residual_multiplier = config.residual_multiplier |
| 115 | + |
| 116 | + def forward( |
| 117 | + self, |
| 118 | + positions: torch.Tensor, |
| 119 | + hidden_states: torch.Tensor, |
| 120 | + ) -> torch.Tensor: |
| 121 | + # Self Attention |
| 122 | + residual = hidden_states |
| 123 | + hidden_states = self.input_layernorm(hidden_states) |
| 124 | + hidden_states = self.self_attn( |
| 125 | + positions=positions, |
| 126 | + hidden_states=hidden_states, |
| 127 | + ) |
| 128 | + hidden_states = residual + hidden_states * self.residual_multiplier |
| 129 | + residual = hidden_states |
| 130 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 131 | + if self.shared_mlp is None: |
| 132 | + hidden_states = self.block_sparse_moe(hidden_states) |
| 133 | + else: |
| 134 | + # create a copy since block_sparse_moe modifies in-place |
| 135 | + moe_hidden_states = hidden_states.clone() |
| 136 | + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) |
| 137 | + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) |
| 138 | + del moe_hidden_states |
| 139 | + hidden_states = residual + hidden_states * self.residual_multiplier |
| 140 | + |
| 141 | + return hidden_states |
| 142 | + |
| 143 | + |
| 144 | +@support_torch_compile |
| 145 | +class GraniteMoeSharedModel(nn.Module): |
| 146 | + |
| 147 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 148 | + super().__init__() |
| 149 | + |
| 150 | + config = vllm_config.model_config.hf_config |
| 151 | + cache_config = vllm_config.cache_config |
| 152 | + quant_config = vllm_config.quant_config |
| 153 | + lora_config = vllm_config.lora_config |
| 154 | + |
| 155 | + self.padding_idx = config.pad_token_id |
| 156 | + lora_vocab = (lora_config.lora_extra_vocab_size * |
| 157 | + (lora_config.max_loras or 1)) if lora_config else 0 |
| 158 | + self.vocab_size = config.vocab_size + lora_vocab |
| 159 | + self.org_vocab_size = config.vocab_size |
| 160 | + |
| 161 | + self.embed_tokens = VocabParallelEmbedding( |
| 162 | + self.vocab_size, |
| 163 | + config.hidden_size, |
| 164 | + org_num_embeddings=config.vocab_size, |
| 165 | + quant_config=quant_config, |
| 166 | + ) |
| 167 | + self.embedding_multiplier = config.embedding_multiplier |
| 168 | + |
| 169 | + self.start_layer, self.end_layer, self.layers = make_layers( |
| 170 | + config.num_hidden_layers, |
| 171 | + lambda prefix: GraniteMoeSharedDecoderLayer( |
| 172 | + config, cache_config, quant_config=quant_config, prefix=prefix |
| 173 | + ), |
| 174 | + prefix=f"{prefix}.layers") |
| 175 | + |
| 176 | + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 177 | + |
| 178 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 179 | + return self.embed_tokens(input_ids) |
| 180 | + |
| 181 | + def forward( |
| 182 | + self, |
| 183 | + input_ids: torch.Tensor, |
| 184 | + positions: torch.Tensor, |
| 185 | + intermediate_tensors: Optional[IntermediateTensors], |
| 186 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 187 | + ) -> torch.Tensor: |
| 188 | + if get_pp_group().is_first_rank: |
| 189 | + if inputs_embeds is not None: |
| 190 | + hidden_states = inputs_embeds |
| 191 | + else: |
| 192 | + hidden_states = self.get_input_embeddings(input_ids) |
| 193 | + hidden_states *= self.embedding_multiplier |
| 194 | + residual = None |
| 195 | + else: |
| 196 | + assert intermediate_tensors is not None |
| 197 | + hidden_states = intermediate_tensors["hidden_states"] |
| 198 | + residual = intermediate_tensors["residual"] |
| 199 | + for i in range(self.start_layer, self.end_layer): |
| 200 | + layer = self.layers[i] |
| 201 | + hidden_states = layer(positions, hidden_states) |
| 202 | + if not get_pp_group().is_last_rank: |
| 203 | + return IntermediateTensors({ |
| 204 | + "hidden_states": hidden_states, |
| 205 | + "residual": residual |
| 206 | + }) |
| 207 | + hidden_states = self.norm(hidden_states) |
| 208 | + return hidden_states |
| 209 | + |
| 210 | + |
| 211 | +class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
| 212 | + fall_back_to_pt_during_load = False |
| 213 | + |
| 214 | + packed_modules_mapping = { |
| 215 | + "qkv_proj": [ |
| 216 | + "q_proj", |
| 217 | + "k_proj", |
| 218 | + "v_proj", |
| 219 | + ], |
| 220 | + } |
| 221 | + |
| 222 | + # LoRA specific attributes |
| 223 | + embedding_modules = { |
| 224 | + "embed_tokens": "input_embeddings", |
| 225 | + "lm_head": "output_embeddings", |
| 226 | + } |
| 227 | + embedding_padding_modules = ["lm_head"] |
| 228 | + |
| 229 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 230 | + super().__init__() |
| 231 | + config = vllm_config.model_config.hf_config |
| 232 | + quant_config = vllm_config.quant_config |
| 233 | + lora_config = vllm_config.lora_config |
| 234 | + |
| 235 | + self.config = config |
| 236 | + self.lora_config = lora_config |
| 237 | + self.quant_config = quant_config |
| 238 | + |
| 239 | + self.model = GraniteMoeSharedModel(vllm_config=vllm_config, |
| 240 | + prefix=maybe_prefix( |
| 241 | + prefix, "model")) |
| 242 | + self.unpadded_vocab_size = config.vocab_size |
| 243 | + if lora_config: |
| 244 | + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size |
| 245 | + self.lm_head = ParallelLMHead( |
| 246 | + self.unpadded_vocab_size, |
| 247 | + config.hidden_size, |
| 248 | + org_num_embeddings=config.vocab_size, |
| 249 | + padding_size=DEFAULT_VOCAB_PADDING_SIZE |
| 250 | + # We need bigger padding if using lora for kernel |
| 251 | + # compatibility |
| 252 | + if not lora_config else lora_config.lora_vocab_padding_size, |
| 253 | + quant_config=quant_config, |
| 254 | + prefix=maybe_prefix(prefix, "lm_head")) |
| 255 | + if config.tie_word_embeddings: |
| 256 | + self.lm_head.weight = self.model.embed_tokens.weight |
| 257 | + |
| 258 | + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, |
| 259 | + config.vocab_size, |
| 260 | + scale=1 / |
| 261 | + self.config.logits_scaling) |
| 262 | + |
| 263 | + self.sampler = get_sampler() |
| 264 | + |
| 265 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 266 | + return self.model.get_input_embeddings(input_ids) |
| 267 | + |
| 268 | + def forward( |
| 269 | + self, |
| 270 | + input_ids: torch.Tensor, |
| 271 | + positions: torch.Tensor, |
| 272 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 273 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 274 | + ) -> torch.Tensor: |
| 275 | + hidden_states = self.model(input_ids, positions, intermediate_tensors, |
| 276 | + inputs_embeds) |
| 277 | + return hidden_states |
| 278 | + |
| 279 | + def compute_logits( |
| 280 | + self, hidden_states: torch.Tensor, |
| 281 | + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: |
| 282 | + logits = self.logits_processor(self.lm_head, hidden_states, |
| 283 | + sampling_metadata) |
| 284 | + return logits |
| 285 | + |
| 286 | + def make_empty_intermediate_tensors( |
| 287 | + self, batch_size: int, dtype: torch.dtype, |
| 288 | + device: torch.device) -> IntermediateTensors: |
| 289 | + return IntermediateTensors({ |
| 290 | + "hidden_states": |
| 291 | + torch.zeros((batch_size, self.config.hidden_size), |
| 292 | + dtype=dtype, |
| 293 | + device=device), |
| 294 | + "residual": |
| 295 | + torch.zeros((batch_size, self.config.hidden_size), |
| 296 | + dtype=dtype, |
| 297 | + device=device), |
| 298 | + }) |
| 299 | + |
| 300 | + def sample( |
| 301 | + self, |
| 302 | + logits: Optional[torch.Tensor], |
| 303 | + sampling_metadata: SamplingMetadata, |
| 304 | + ) -> Optional[SamplerOutput]: |
| 305 | + next_tokens = self.sampler(logits, sampling_metadata) |
| 306 | + return next_tokens |
| 307 | + |
| 308 | + def load_weights(self, weights: Iterable[Tuple[str, |
| 309 | + torch.Tensor]]) -> Set[str]: |
| 310 | + new_weights = {} |
| 311 | + for n, p in weights: |
| 312 | + if n.endswith('.block_sparse_moe.input_linear.weight'): |
| 313 | + for e in range(p.size(0)): |
| 314 | + w1_name = n.replace( |
| 315 | + '.block_sparse_moe.input_linear.weight', |
| 316 | + f".block_sparse_moe.experts.{e}.w1.weight") |
| 317 | + w3_name = n.replace( |
| 318 | + '.block_sparse_moe.input_linear.weight', |
| 319 | + f".block_sparse_moe.experts.{e}.w3.weight") |
| 320 | + w1_param, w3_param = p[e].chunk(2, dim=0) |
| 321 | + assert w1_name not in new_weights |
| 322 | + assert w3_name not in new_weights |
| 323 | + new_weights[w1_name] = w1_param |
| 324 | + new_weights[w3_name] = w3_param |
| 325 | + elif n.endswith('.block_sparse_moe.output_linear.weight'): |
| 326 | + for e in range(p.size(0)): |
| 327 | + w2_name = n.replace( |
| 328 | + '.block_sparse_moe.output_linear.weight', |
| 329 | + f".block_sparse_moe.experts.{e}.w2.weight") |
| 330 | + w2_param = p[e] |
| 331 | + assert w2_name not in new_weights |
| 332 | + new_weights[w2_name] = w2_param |
| 333 | + elif n.endswith('.block_sparse_moe.router.layer.weight'): |
| 334 | + gate_name = n.replace('.block_sparse_moe.router.layer.weight', |
| 335 | + ".block_sparse_moe.gate.weight") |
| 336 | + assert gate_name not in new_weights |
| 337 | + new_weights[gate_name] = p |
| 338 | + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: |
| 339 | + pass |
| 340 | + else: |
| 341 | + new_weights[n] = p |
| 342 | + return mixtral.MixtralForCausalLM.load_weights(self, |
| 343 | + new_weights.items()) |
0 commit comments