
Commit 3c8ff80

mobicham authored and jimpang committed
Fix TorchAOConfig skip layers (vllm-project#19265)
Signed-off-by: mobicham <[email protected]>
1 parent 05cd0fb commit 3c8ff80

File tree: 2 files changed, +72 -7 lines changed

tests/quantization/test_torchao.py

Lines changed: 15 additions & 0 deletions
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+        print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 57 additions & 7 deletions
@@ -17,11 +17,30 @@
 logger = init_logger(__name__)
 
 
+def should_skip(prefix: str, skip_modules: list[str]) -> bool:
+    """
+    Robust skipping logic:
+    should_skip("model.model.layers.1.q_proj",
+                ["model.model.layers.1.q_proj"])  # True
+    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
+    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
+    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
+    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
+    """
+    for s in skip_modules:
+        if prefix == s:
+            return True
+        if f".{s}." in f".{prefix}.":
+            return True
+    return False
+
+
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""
 
-    def __init__(self, torchao_config) -> None:
-        self.torchao_config = torchao_config
+    def __init__(self,
+                 torchao_config,
+                 skip_modules: Optional[list[str]] = None) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
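The dot-wrapping check above is what lets a short skip entry such as "o_proj" or "layers.1" match only whole dot-separated components of a module path, so "layers.1" cannot accidentally match "layers.11". A minimal standalone sketch of the same idea (the helper name and inputs are illustrative, not part of the change):

def matches_component(prefix: str, pattern: str) -> bool:
    # Wrapping both strings in dots means a pattern can only match
    # complete dot-separated components of the module path.
    return prefix == pattern or f".{pattern}." in f".{prefix}."

assert matches_component("model.model.layers.10.o_proj", "o_proj")
assert matches_component("model.model.layers.1.q_proj", "layers.1")
assert not matches_component("model.model.layers.11.q_proj", "layers.1")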
@@ -36,6 +55,8 @@ def __init__(self, torchao_config) -> None:
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        self.torchao_config = torchao_config
+        self.skip_modules = skip_modules or []
 
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
 
         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
-        assert (len(hf_config) == 1 and "default" in hf_config
-                ), "Expected only one key 'default' in quant_type dictionary"
+        assert len(hf_config) == 1 and "default" in hf_config, (
+            "Expected only one key 'default' in quant_type dictionary")
         quant_type = hf_config["default"]
         ao_config = config_from_dict(quant_type)
-        return cls(ao_config)
+
+        # Adds skipped modules defined in "modules_to_not_convert"
+        skip_modules = config.get("modules_to_not_convert", []) or []
+
+        # Adds skipped modules defined in "module_fqn_to_config"
+        _data = quant_type.get("_data", {})
+        if not isinstance(_data, dict):
+            _data = {}
+
+        module_fqn = _data.get("module_fqn_to_config", {})
+        if not isinstance(module_fqn, dict):
+            module_fqn = {}
+
+        for layer, layer_cfg in module_fqn.items():
+            if layer_cfg is None:
+                skip_modules.append(layer)
+
+        return cls(ao_config, skip_modules)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
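from_config now gathers skip entries from two places in the Hugging Face quantization config: the explicit "modules_to_not_convert" list, and any entry in the serialized ModuleFqnToConfig mapping whose per-module config is None. A hedged sketch of what such a config dict might look like (field names and values are illustrative; the exact serialized layout is whatever torchao's config serialization produces):

# Illustrative quantization_config as from_config might receive it.
config = {
    "quant_method": "torchao",
    "modules_to_not_convert": ["visual"],  # explicit skip list
    "quant_type": {
        "default": {
            "_type": "ModuleFqnToConfig",
            "_data": {
                "module_fqn_to_config": {
                    "_default": {"_type": "Int8WeightOnlyConfig", "_data": {}},
                    "lm_head": None,  # None -> left unquantized
                },
            },
        },
    },
}
# With the logic above, skip_modules ends up as ["visual", "lm_head"].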
@@ -80,13 +118,16 @@ def get_quant_method(self, layer: torch.nn.Module,
 
         from torchao.quantization import ModuleFqnToConfig
 
+        if should_skip(prefix, self.skip_modules):
+            return UnquantizedLinearMethod()
+
         module_fqn = prefix
         if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c)
+                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
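With this check in place, any prefix matched by skip_modules is routed to UnquantizedLinearMethod before the per-module ModuleFqnToConfig lookup runs, so modules such as the vision tower in the Qwen2.5-VL test above are loaded as regular unquantized weights even when a "_default" quantization config is present.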
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
+
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    """
+    Avoid real weight allocation for faster load, since we will
+    end up setting it to param.
+    """
+    with torch.device("meta"):
+        dummy_linear = torch.nn.Linear(param.shape[1],
+                                       param.shape[0],
+                                       bias=False)
+
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
     return dummy_linear.weight
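The torch.device("meta") context keeps nn.Linear from allocating a real weight tensor that would be discarded immediately when it is replaced with param. A small standalone illustration of that behaviour (independent of torchao; the sizes are arbitrary):

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096, bias=False)

# The weight carries shape/dtype metadata only; no storage is allocated.
print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096])

# A real parameter can then be swapped in as usual.
layer.weight = torch.nn.Parameter(torch.empty(4096, 4096))
print(layer.weight.device)  # cpu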
