 logger = init_logger(__name__)


+def should_skip(prefix: str, skip_modules: list[str]) -> bool:
+    """
+    Robust skipping logic:
+    should_skip("model.model.layers.1.q_proj",
+                ["model.model.layers.1.q_proj"]) -> True
+    should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
+    should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
+    should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
+    should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
+    """
+    for s in skip_modules:
+        if prefix == s:
+            return True
+        if f".{s}." in f".{prefix}.":
+            return True
+    return False
+
+
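For quick reference, a standalone sanity check of the dotted-boundary match above (a runnable sketch mirroring `should_skip`, independent of vLLM): padding both strings with dots makes a skip entry match only whole dotted components, so `layers.1` does not accidentally skip `layers.11`.

```python
# Standalone copy of the matching logic in should_skip, for experimentation.
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    for s in skip_modules:
        # Exact match, or s appears as a run of whole dotted components.
        if prefix == s or f".{s}." in f".{prefix}.":
            return True
    return False

assert should_skip("model.model.layers.1.q_proj", ["layers.1"])
assert not should_skip("model.model.layers.11.q_proj", ["layers.1"])  # no false positive
assert should_skip("model.model.layers.10.o_proj", ["o_proj"])
```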
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""

-    def __init__(self, torchao_config) -> None:
-        self.torchao_config = torchao_config
+    def __init__(self,
+                 torchao_config,
+                 skip_modules: Optional[list[str]] = None) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order
         # to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ def __init__(self, torchao_config) -> None:
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        self.torchao_config = torchao_config
+        self.skip_modules = skip_modules or []

     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":

         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
-        assert (len(hf_config) == 1 and "default" in hf_config
-                ), "Expected only one key 'default' in quant_type dictionary"
+        assert len(hf_config) == 1 and "default" in hf_config, (
+            "Expected only one key 'default' in quant_type dictionary")
         quant_type = hf_config["default"]
         ao_config = config_from_dict(quant_type)
-        return cls(ao_config)
+
+        # Adds skipped modules defined in "modules_to_not_convert"
+        skip_modules = config.get("modules_to_not_convert", []) or []
+
+        # Adds skipped modules defined in "module_fqn_to_config"
+        _data = quant_type.get("_data", {})
+        if not isinstance(_data, dict):
+            _data = {}
+
+        module_fqn = _data.get("module_fqn_to_config", {})
+        if not isinstance(module_fqn, dict):
+            module_fqn = {}
+
+        for layer, layer_cfg in module_fqn.items():
+            if layer_cfg is None:
+                skip_modules.append(layer)
+
+        return cls(ao_config, skip_modules)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
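To make the two skip sources concrete, here is a hypothetical HF-style `quantization_config` payload (the field values are illustrative, not taken from the PR) and the `skip_modules` list that the logic above would derive from it:

```python
# Hypothetical serialized config; only the keys read by from_config matter here.
config = {
    "quant_type": {
        "default": {
            "_data": {
                "module_fqn_to_config": {
                    "_default": {"some": "AOBaseConfig dict"},  # quantized
                    "model.layers.0.mlp.gate_proj": None,       # None => skip
                },
            },
        },
    },
    "modules_to_not_convert": ["lm_head"],
}

quant_type = config["quant_type"]["default"]
skip_modules = config.get("modules_to_not_convert", []) or []
for layer, layer_cfg in quant_type["_data"]["module_fqn_to_config"].items():
    if layer_cfg is None:
        skip_modules.append(layer)

print(skip_modules)  # ['lm_head', 'model.layers.0.mlp.gate_proj']
```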
@@ -80,13 +118,16 @@ def get_quant_method(self, layer: torch.nn.Module,

         from torchao.quantization import ModuleFqnToConfig

+        if should_skip(prefix, self.skip_modules):
+            return UnquantizedLinearMethod()
+
         module_fqn = prefix
         if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c)
+                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()
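A small sketch of the lookup-and-fallback in the `ModuleFqnToConfig` branch above: an exact per-module entry wins, otherwise the `_default` entry applies (placeholder strings stand in for real `AOBaseConfig` objects):

```python
# Placeholder configs; real entries would be torchao AOBaseConfig instances.
module_fqn_to_config = {
    "_default": "int8-default-config",
    "model.layers.0.q_proj": "per-module-override",
}

prefix = "model.layers.1.q_proj"  # no exact entry for this module
c = module_fqn_to_config.get(prefix) or module_fqn_to_config.get("_default", None)
assert c == "int8-default-config"  # falls back to the default config
```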
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_
+
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    # Avoid real weight allocation for faster load, since we will
+    # end up setting it to param.
+    with torch.device("meta"):
+        dummy_linear = torch.nn.Linear(param.shape[1],
+                                       param.shape[0],
+                                       bias=False)
+
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
     return dummy_linear.weight
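For context, a standalone demonstration of the meta-device trick above: constructing the module under `torch.device("meta")` allocates no real weight storage, and the already-loaded parameter is swapped in afterwards (the shapes here are made up):

```python
import torch

# A stand-in for an already-loaded model weight.
param = torch.nn.Parameter(torch.randn(128, 256), requires_grad=False)

# Meta-device construction records shapes/dtypes but allocates no data.
with torch.device("meta"):
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
assert dummy_linear.weight.is_meta

# Swap in the real weight; the meta tensor is simply replaced.
dummy_linear.weight = param
assert not dummy_linear.weight.is_meta
```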