verl/utils/vllm/patch.py (45 changes: 32 additions & 13 deletions)
@@ -71,24 +71,43 @@ def patch_vllm_moe_model_weight_loader(model):
     # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader
     # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader
 
-    # Define MLP attribute mapping for different model types
-    MLP_ATTR_MAPPING = {
-        MixtralForCausalLM: "block_sparse_moe",
-    }
-    DEFAULT_MLP_ATTR = "mlp"
+    # Early return if no MOE models are supported
+    if not SUPPORTED_MOE_MODELS:
+        return
 
     if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)):
         return
 
-    model = getattr(model, "model", None) or getattr(model, "language_model", None)
-    if model is None:
+    original_model_type = type(model)
+
+    # Define MLP attribute mapping for different model types
+    MLP_ATTR_MAPPING = {}
+    try:
+        from vllm.model_executor.models.mixtral import MixtralForCausalLM
+
+        MLP_ATTR_MAPPING[MixtralForCausalLM] = "block_sparse_moe"
+    except ImportError:
+        pass
+
+    DEFAULT_MLP_ATTR = "mlp"
+
+    # Get inner model (either model.model or model.language_model)
+    inner_model = getattr(model, "model", None) or getattr(model, "language_model", None)
+    if inner_model is None:
         raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
 
-    for layer in model.layers:
-        mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR)
-        mlp = getattr(layer, mlp_attr)
+    for layer_idx, layer in enumerate(inner_model.layers):
+        mlp_attr = MLP_ATTR_MAPPING.get(original_model_type, DEFAULT_MLP_ATTR)
+
+        mlp = getattr(layer, mlp_attr, None)
+        if not mlp:
+            continue
+
+        experts = getattr(mlp, "experts", None)
+        if not experts or not hasattr(experts, "weight_loader"):
+            continue
 
-        param_dict = dict(mlp.named_parameters())
-        for name, param in param_dict.items():
+        # Patch the weight loaders
+        for name, param in mlp.named_parameters():
             if "w13_weight" in name or "w2_weight" in name:
-                param.weight_loader = mlp.experts.weight_loader
+                param.weight_loader = experts.weight_loader
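
For context, a minimal usage sketch of the patched helper is shown below. The import path follows the file touched by this diff; the `rollout_engine.model_runner.model` accessor is a hypothetical way to reach the underlying vLLM nn.Module and is not part of this change, so adjust it to your own integration.

```python
# Minimal usage sketch; `rollout_engine` and its `model_runner.model`
# accessor are assumptions for illustration, not part of this diff.
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

vllm_model = rollout_engine.model_runner.model  # assumed handle to the vLLM nn.Module
# Rebind the MoE expert weight loaders so that subsequent weight syncs
# route w13_weight / w2_weight through the experts' weight_loader.
patch_vllm_moe_model_weight_loader(vllm_model)
```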