
[Quantization] add BNB for MixtralForCausalLM #20893

Merged (2 commits, Jul 14, 2025)
7 changes: 6 additions & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -227,7 +227,12 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
"fp8",
"compressed-tensors",
"gptq_marlin",
"awq_marlin",
"quark",
"bitsandbytes",
]
Comment on lines +235 to 236 (Contributor, severity: medium):

Consider adding bnb to mixtral_supported for consistency and clarity.

        "quark",
        "bnb",


vllm_supported_archs = ModelRegistry.get_supported_archs()
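With "bitsandbytes" in mixtral_supported, a BNB-quantized Mixtral keeps the standard MixtralForCausalLM path instead of falling back to the quantization-specific variant. A minimal usage sketch, assuming a vLLM build containing this change with bitsandbytes installed; the model id, prompt, and sampling settings are illustrative, not taken from this PR:

# Minimal sketch: in-flight bitsandbytes quantization of Mixtral.
# Assumes bitsandbytes is installed; model id and prompt are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    quantization="bitsandbytes",
)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)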
105 changes: 102 additions & 3 deletions vllm/model_executor/models/granitemoe.py
@@ -45,12 +45,14 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers,
maybe_prefix)


class GraniteMoeMoE(nn.Module):
@@ -307,6 +309,103 @@ def forward(
hidden_states = self.norm(hidden_states)
return hidden_states

def _load_weights(self,
weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""
This function is copied from `MixtralModel.load_weights`, mainly to
decouple from mixtral, avoiding impact on support like BNB
quantization.
"""
Comment on lines +312 to +318 (Contributor, severity: medium):

This function is copied from MixtralModel.load_weights to decouple from mixtral, but this introduces code duplication. Consider refactoring the weight loading logic into a shared utility function or a mixin to reduce duplication and improve maintainability.
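For context, a rough sketch of the shared-mixin refactor this comment proposes; it is not part of the PR, and the class and method names below are hypothetical:

# Hypothetical sketch of the suggested refactor; not part of this PR.
# MixtralStyleWeightLoaderMixin and _load_weights_shared are made-up names.
from collections.abc import Iterable

import torch


class MixtralStyleWeightLoaderMixin:
    """Holds the stacked-param / expert-param loading loop once, so that
    MixtralModel and GraniteMoeModel could share it instead of keeping
    duplicated copies of the load_weights body."""

    def _load_weights_shared(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # The full stacked-param / expert-param / kv-scale dispatch from
            # MixtralModel.load_weights would live here, exactly once.
            if name in params_dict:
                params_dict[name].data.copy_(loaded_weight)
                loaded_params.add(name)
        return loaded_params

# Both models would then inherit the mixin, e.g.:
# class MixtralModel(nn.Module, MixtralStyleWeightLoaderMixin): ...
# class GraniteMoeModel(nn.Module, MixtralStyleWeightLoaderMixin): ...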

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
new_weights = {}
Expand Down Expand Up @@ -339,7 +438,7 @@ def load_weights(self, weights: Iterable[tuple[str,
new_weights[gate_name] = p
else:
new_weights[n] = p
return mixtral.MixtralModel.load_weights(self, new_weights.items())
return self._load_weights(new_weights.items())
Contributor (severity: medium):

Consider calling self._load_weights instead of mixtral.MixtralModel.load_weights to avoid direct dependency on the mixtral module.



class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
5 changes: 2 additions & 3 deletions vllm/model_executor/models/granitemoeshared.py
@@ -27,8 +27,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from . import mixtral
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
from .interfaces import SupportsLoRA, SupportsPP
from .utils import AutoWeightsLoader, make_layers, maybe_prefix

@@ -242,7 +241,7 @@ def load_weights(self, weights: Iterable[tuple[str,
new_weights[gate_name] = p
else:
new_weights[n] = p
return mixtral.MixtralModel.load_weights(self, new_weights.items())
return GraniteMoeModel._load_weights(self, new_weights.items())
Contributor (severity: medium):

Consider calling self._load_weights instead of GraniteMoeModel._load_weights to avoid a direct dependency on the GraniteMoeModel class.



class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
21 changes: 13 additions & 8 deletions vllm/model_executor/models/mixtral.py
@@ -317,6 +317,15 @@ def forward(
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@@ -326,16 +335,9 @@ def load_weights(self, weights: Iterable[tuple[str,
("qkv_proj", "v_proj", "v"),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
@@ -486,3 +488,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
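For reference, get_expert_mapping() returns the (param_name, weight_name, expert_id, shard_id) tuples built by FusedMoE.make_expert_params_mapping, which load_weights, and presumably the bitsandbytes loading path this PR enables, iterate over. The sketch below shows the rough shape for a two-expert model; the exact parameter and checkpoint name prefixes are an assumption, not copied from vLLM:

# Illustrative only: approximate shape of the tuples returned by
# get_expert_mapping() for num_local_experts=2. The exact name prefixes
# are assumptions, not taken from vLLM.
expert_params_mapping = [
    # (param_name, weight_name, expert_id, shard_id)
    ("experts.w13_weight", "experts.0.w1.", 0, "w1"),
    ("experts.w2_weight", "experts.0.w2.", 0, "w2"),
    ("experts.w13_weight", "experts.0.w3.", 0, "w3"),
    ("experts.w13_weight", "experts.1.w1.", 1, "w1"),
    ("experts.w2_weight", "experts.1.w2.", 1, "w2"),
    ("experts.w13_weight", "experts.1.w3.", 1, "w3"),
]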
3 changes: 2 additions & 1 deletion vllm/model_executor/models/olmoe.py
@@ -352,6 +352,7 @@ def load_weights(self, weights: Iterable[tuple[str,

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -380,7 +381,7 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in self.get_expert_mapping():
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
3 changes: 2 additions & 1 deletion vllm/model_executor/models/qwen2_moe.py
@@ -413,6 +413,7 @@ def load_weights(self, weights: Iterable[tuple[str,

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
@@ -442,7 +443,7 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in self.get_expert_mapping():
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
4 changes: 1 addition & 3 deletions vllm/model_executor/models/qwen3_moe.py
@@ -400,11 +400,9 @@ def load_weights(self, weights: Iterable[tuple[str,
".v_scale", "_v_scale", ".weight_scale",
"_weight_scale", ".input_scale", "_input_scale")

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
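The olmoe.py, qwen2_moe.py, and qwen3_moe.py changes all apply the same small cleanup: call get_expert_mapping() once before the weight loop instead of rebuilding the mapping for every checkpoint tensor. A generic sketch of that hoisting pattern; the function name and the simplified loop body are illustrative, not vLLM code:

# Illustrative sketch of the hoisting pattern used in the three files above.
from collections.abc import Iterable

import torch


def load_weights_hoisted(model, weights: Iterable[tuple[str, torch.Tensor]]):
    expert_params_mapping = model.get_expert_mapping()  # built once, up front
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # ... dispatch to the expert weight loader for this shard ...
            break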