Commit c060b71

[Model] Add support for GraniteMoeShared models (#13313)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

1 parent 79e4937 commit c060b71

File tree: 4 files changed, +351 -0 lines
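
For context, and not part of the commit itself, here is a minimal sketch of how the newly supported architecture could be exercised through vLLM's offline Python API. It assumes a vLLM build that includes this change and access to the `ibm-research/moe-7b-1b-active-shared-experts` test checkpoint referenced in the docs change below.

    # Hypothetical usage sketch (not part of this commit). Assumes the test
    # checkpoint listed in supported_models.md is available on the HF Hub.
    from vllm import LLM, SamplingParams

    llm = LLM(model="ibm-research/moe-7b-1b-active-shared-experts")
    params = SamplingParams(temperature=0.0, max_tokens=64)

    outputs = llm.generate(["Granite MoE Shared adds shared experts to"], params)
    for output in outputs:
        print(output.outputs[0].text)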

docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions

@@ -298,6 +298,11 @@ See [this page](#generative-models) for more information on how to use generative models.
    * `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc.
    * ✅︎
    * ✅︎
 +- * `GraniteMoeSharedForCausalLM`
 +  * Granite MoE Shared
 +  * `ibm-research/moe-7b-1b-active-shared-experts` (test model)
 +  * ✅︎
 +  * ✅︎
  - * `GritLM`
    * GritLM
    * `parasail-ai/GritLM-7B-vllm`.

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -131,6 +131,8 @@ def check_available_online(
     "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
+    "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts",  # noqa: E501
+                                                   min_transformers_version="4.49"),  # noqa: E501
     "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
                                              trust_remote_code=True),
     "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
vllm/model_executor/models/granitemoeshared.py

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@

# SPDX-License-Identifier: Apache-2.0
"""Inference-only GraniteMoeShared model.

The architecture is the same as granitemoe but with the addition of shared
experts.
"""
from typing import Iterable, Optional, Set, Tuple

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from . import mixtral
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix


class GraniteMoeSharedMLP(nn.Module):

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.input_linear = MergedColumnParallelLinear(
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear")
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear")
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier)
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.shared_mlp = None if \
            getattr(config, 'shared_intermediate_size', 0) == 0 \
            else GraniteMoeSharedMLP(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # create a copy since block_sparse_moe modifies in-place
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            hidden_states *= self.embedding_multiplier
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = GraniteMoeSharedModel(vllm_config=vllm_config,
                                           prefix=maybe_prefix(
                                               prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=1 /
                                                self.config.logits_scaling)

        self.sampler = get_sampler()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        new_weights = {}
        for n, p in weights:
            if n.endswith('.block_sparse_moe.input_linear.weight'):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{e}.w1.weight")
                    w3_name = n.replace(
                        '.block_sparse_moe.input_linear.weight',
                        f".block_sparse_moe.experts.{e}.w3.weight")
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith('.block_sparse_moe.output_linear.weight'):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        '.block_sparse_moe.output_linear.weight',
                        f".block_sparse_moe.experts.{e}.w2.weight")
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith('.block_sparse_moe.router.layer.weight'):
                gate_name = n.replace('.block_sparse_moe.router.layer.weight',
                                      ".block_sparse_moe.gate.weight")
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
                pass
            else:
                new_weights[n] = p
        return mixtral.MixtralForCausalLM.load_weights(self,
                                                       new_weights.items())
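
For readers unfamiliar with the shared-experts variant, here is a minimal standalone sketch (not part of the commit) of the combination that `GraniteMoeSharedDecoderLayer.forward` implements: the routed MoE output and a dense shared MLP both consume the same post-attention hidden states, and their outputs are summed before the scaled residual add. The sizes, the naive top-k router, and the plain nn.Linear experts below are illustrative assumptions and do not reflect vLLM's fused MoE kernels.

    # Toy illustration of the shared-experts path; all dimensions and the
    # simple per-token routing loop are assumptions made for clarity.
    import torch
    from torch import nn


    class ToySharedExpertsBlock(nn.Module):

        def __init__(self, hidden: int = 64, num_experts: int = 4, top_k: int = 2):
            super().__init__()
            self.top_k = top_k
            self.router = nn.Linear(hidden, num_experts, bias=False)
            self.experts = nn.ModuleList([
                nn.Sequential(nn.Linear(hidden, hidden * 2), nn.SiLU(),
                              nn.Linear(hidden * 2, hidden))
                for _ in range(num_experts)
            ])
            # The "shared expert": a dense MLP applied to every token.
            self.shared_mlp = nn.Sequential(nn.Linear(hidden, hidden * 2),
                                            nn.SiLU(),
                                            nn.Linear(hidden * 2, hidden))

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Route each token to its top-k experts and mix their outputs.
            logits = self.router(hidden_states)
            weights, ids = torch.topk(logits.softmax(dim=-1), self.top_k, dim=-1)
            moe_out = torch.zeros_like(hidden_states)
            for k in range(self.top_k):
                for e, expert in enumerate(self.experts):
                    mask = ids[..., k] == e
                    if mask.any():
                        moe_out[mask] += weights[..., k][mask].unsqueeze(-1) \
                            * expert(hidden_states[mask])
            # The shared expert sees every token; its output is added to the
            # routed path, mirroring `moe_hidden_states + self.shared_mlp(...)`.
            return moe_out + self.shared_mlp(hidden_states)


    block = ToySharedExpertsBlock()
    x = torch.randn(3, 5, 64)  # (batch, seq, hidden)
    print(block(x).shape)      # torch.Size([3, 5, 64])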

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
     "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
+    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
     "GritLM": ("gritlm", "GritLM"),
     "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
