Skip to content

[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name #18358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/models/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,73 @@ def weight_generator():
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1


def test_module_skip_prefix():
    """Ensure the auto weight loader can skip prefix."""
    src_mod = ModuleWithNestedBatchNorm()
    # Push one batch through so the BatchNorm running stats change.
    src_mod(torch.Tensor([[1, 2], [3, 4]]))

    # Weight stream for loading into a fresh instance.
    def make_weights():
        # Extra entries the loader is expected to filter out by prefix.
        extra = {
            "prefix.bn.weight": torch.Tensor([1, 2]),
            "prefix.bn.bias": torch.Tensor([3, 4]),
        }
        yield from (src_mod.state_dict() | extra).items()

    dst_mod = ModuleWithNestedBatchNorm()

    src_bn = src_mod.nested_mod.bn
    dst_bn = dst_mod.nested_mod.bn
    # Before loading, the fresh module's stats differ from the source's.
    assert not torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert not torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(dst_mod, skip_prefixes=["prefix."])
    loader.load_weights(make_weights())

    # Ensure the stats are updated
    assert torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 1


def test_module_skip_substr():
    """Ensure the auto weight loader can skip weights by name substring."""
    mod = ModuleWithNestedBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        # Weights that must be filtered out: "substr." appears at different
        # (non-prefix) positions, so prefix matching alone would not catch
        # all of them.
        redundant_weights = {
            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
            "nested_mod.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.substr.bias": torch.Tensor([3, 4]),
        }
        yield from (mod.state_dict() | redundant_weights).items()

    new_mod = ModuleWithNestedBatchNorm()

    # Fresh instance must start with different (default) running stats.
    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
    assert torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
12 changes: 8 additions & 4 deletions vllm/model_executor/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def make_empty_intermediate_tensors(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = [
skip_substrs = [
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
Expand All @@ -488,8 +488,12 @@ def load_weights(self, weights: Iterable[tuple[str,
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings:
skip_prefixes.append("lm_head.weight")
skip_prefixes = (["lm_head."]
if self.config.tie_word_embeddings else None)

loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
skip_substrs=skip_substrs,
)
return loader.load_weights(weights)
12 changes: 8 additions & 4 deletions vllm/model_executor/models/grok1.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,14 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model.

    Returns the set of weight names that were loaded.
    """
    # inv_freq is recomputed at runtime; it may appear under arbitrary
    # layer prefixes in checkpoints, so match it as a substring.
    skip_substrs = ["rotary_emb.inv_freq"]
    # Skip lm_head when tie_word_embeddings is True
    skip_prefixes = (["lm_head"]
                     if self.config.tie_word_embeddings else None)

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=skip_prefixes,
        skip_substrs=skip_substrs,
    )
    return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,5 +482,5 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model.

    Returns the set of weight names that were loaded. The rotary
    inv_freq buffer is recomputed at runtime, so any checkpoint entry
    containing that name is skipped.
    """
    loader = AutoWeightsLoader(self, skip_substrs=["rotary_emb.inv_freq"])
    return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/mixtral_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
skip_substrs=(["rotary_emb.inv_freq"]),
)
return loader.load_weights(weights)
4 changes: 2 additions & 2 deletions vllm/model_executor/models/nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,12 +504,12 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
skip_substrs=[
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
],
)
return loader.load_weights(weights)
13 changes: 4 additions & 9 deletions vllm/model_executor/models/olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,14 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached",
"lm_head.weight"
] if self.config.tie_word_embeddings else [
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
skip_substrs=[
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
"rotary_emb.cos_cached",
"rotary_emb.sin_cached"
]),
],
)
return loader.load_weights(weights)
13 changes: 4 additions & 9 deletions vllm/model_executor/models/olmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,19 +403,14 @@ def compute_logits(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this model via AutoWeightsLoader."""
    loader = AutoWeightsLoader(
        self,
        # With tied embeddings, lm_head.weight duplicates the input
        # embedding and can be dropped from the checkpoint.
        skip_prefixes=(["lm_head.weight"]
                       if self.config.tie_word_embeddings else None),
        skip_substrs=[
            "rotary_emb.inv_freq",
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        ],
    )
    return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["rotary_emb.inv_freq"],
skip_substrs=["rotary_emb.inv_freq"],
)
return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/orion.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
skip_substrs=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
Expand Down
4 changes: 1 addition & 3 deletions vllm/model_executor/models/phi4mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,9 +1228,7 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> None:
    """Load checkpoint weights, dropping any tensor whose name contains
    "lora" — those are handled outside this loader."""
    loader = AutoWeightsLoader(self, skip_substrs=["lora"])
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

def get_mm_mapping(self) -> MultiModelKeys:
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
skip_substrs=(["rotary_emb.inv_freq"]),
)
return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
skip_substrs=(["rotary_emb.inv_freq"]),
)
return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,6 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
skip_substrs=(["rotary_emb.inv_freq"]),
)
return loader.load_weights(weights)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
skip_substrs=([
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def load_weights(self, weights: Iterable[tuple[str,
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=[
skip_substrs=[
"rotary_emb.inv_freq", "rotary_emb.cos_cached",
"rotary_emb.sin_cached"
],
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ def load_weights(self, weights: Iterable[tuple[str,
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=([
"rotary_emb.inv_freq", "lm_head.weight"
] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
skip_substrs=["rotary_emb.inv_freq"],
)
return loader.load_weights(weights)
8 changes: 7 additions & 1 deletion vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ def __init__(
module: nn.Module,
*,
skip_prefixes: Optional[list[str]] = None,
skip_substrs: Optional[list[str]] = None,
ignore_unexpected_prefixes: Optional[list[str]] = None,
) -> None:
super().__init__()

self.module = module
self.skip_prefixes = skip_prefixes or []
self.skip_substrs = skip_substrs or []
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

def _groupby_prefix(
Expand Down Expand Up @@ -119,7 +121,8 @@ def _get_qualname(self, prefix: str, rest: str) -> str:
return ".".join((prefix, rest))

def _can_skip(self, qualname: str) -> bool:
return any(qualname.startswith(p) for p in self.skip_prefixes)
return (any(qualname.startswith(p) for p in self.skip_prefixes)
or any(substr in qualname for substr in self.skip_substrs))

def _can_ignore_unexpected(self, qualname: str) -> bool:
return any(
Expand Down Expand Up @@ -257,6 +260,9 @@ def load_weights(
) -> set[str]:
if mapper is not None:
weights = mapper.apply(weights)
# filter out weights with first-prefix/substr to skip in name
weights = ((name, weight) for name, weight in weights
if not self._can_skip(name))

autoloaded_weights = set(self._load_module("", self.module, weights))
return autoloaded_weights
Expand Down