 # limitations under the License.
 #
 
-import copy
-from typing import Any
+from typing import Any, List
 
 import torch
 
@@ -36,10 +35,10 @@
 class TrainableEquivalentTransformation:
     """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""
 
-    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"]
     _PREPARE_ATTRS_PREFIX = "_prepare_"
 
-    def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
+    def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None):
         """
         :param model: the model for quantization
         :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
@@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
         self.absorb_to_layer = absorb_to_layer
         self._post_initialized = False
 
+    def _detect_absorb_to_layer(self, model, folding, example_inputs):
+        # If the user does not provide the layers to absorb the quantization, detect them automatically.
+        supported_layers = ["Linear"]
+        detected_absorb_layers = {}
+        # Detect the layers that can be absorbed automatically.
+        if folding:
+            from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace
+
+            tg = GraphTrace()
+            detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers)
+        else:  # pragma: no cover
+            for name, module in model.named_modules():
+                if module.__class__.__name__ in supported_layers:
+                    detected_absorb_layers[name] = [name]
+        logger.info("Detected **absorb layer**: **absorbed layers**")
+        logger.info(detected_absorb_layers)
+        return detected_absorb_layers
+
     def _post_init(self):
         self.dtype = self._get_dtype()
         self.model.to(self.device)
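
For reference, `_detect_absorb_to_layer` returns a mapping from each absorbing module to the list of layers whose quantization scale it absorbs. Below is a minimal standalone sketch of the non-folding fallback above (the folding path delegates to `GraphTrace.get_absorb_to_layer` instead), using a hypothetical toy model:

```python
# Sketch of the non-folding fallback: every Linear "absorbs" its own scale.
# ToyBlock is hypothetical and only used to show the shape of the returned mapping.
import torch


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(16)
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(self.norm(x))))


model = ToyBlock()
supported_layers = ["Linear"]
detected_absorb_layers = {}
for name, module in model.named_modules():
    if module.__class__.__name__ in supported_layers:
        detected_absorb_layers[name] = [name]

print(detected_absorb_layers)  # {'fc1': ['fc1'], 'fc2': ['fc2']}
```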
@@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
         to the paper for more details
         :param sqrt_w_init: use sqrt weight to init."""
 
+        if not self.absorb_to_layer:
+            self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs)
         if not self._post_initialized:
             self._post_init()
         # freeze model.
@@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
 
         self.trained_alphas[layer_norm] = alpha
         for layer_name in self.absorb_to_layer[layer_norm]:
-            if self.weight_config.get(layer_name) is None:  # pragma: no cover
+            if not self.weight_config.get(layer_name):  # pragma: no cover
                 logger.info(f"layer {layer_name} not in weight config, skip.")
                 continue
             num_bits = self.weight_config[layer_name]["bits"]
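
The per-layer entries read here (and again in the loop below) have the shape sketched in the following snippet; the layer name is hypothetical, and only the keys this method actually reads are shown:

```python
# Hypothetical weight_config entry; the layer name is made up for illustration, and
# only the keys consumed in add_tuning_scale (bits, group_size, scheme) are listed.
weight_config = {
    "model.decoder.layers.0.fc1": {
        "bits": 4,         # quantization bit width
        "group_size": 32,  # input channels that share one quantization scale
        "scheme": "sym",   # symmetric quantization scheme
    }
}
```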
@@ -117,10 +136,10 @@ def add_tuning_scale(self, sqrt_w_init=False):
                 )
                 set_module(self.model, layer_name, wrapper_module)
 
-        for n, m in self.model.named_modules():
-            if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
-                if self.weight_config.get(n) is None:  # pragma: no cover
-                    logger.info(f"out of absorbed layer {n} not in weight config, skip.")
+        for layer_name, m in self.model.named_modules():
+            if isinstance(m, torch.nn.Linear) and "orig_layer" not in layer_name:
+                if not self.weight_config.get(layer_name):  # pragma: no cover
+                    logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.")
                     continue
                 num_bits = self.weight_config[layer_name]["bits"]
                 group_size = self.weight_config[layer_name]["group_size"]
@@ -131,7 +150,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
                 wrapper_module = TEQLinearFakeQuant(
                     orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                 )
-                set_module(self.model, n, wrapper_module)
+                set_module(self.model, layer_name, wrapper_module)
         # Attach the weight config captured at prepare stage to the model
         self.model._weight_config = self.weight_config
         self.model._trained_alphas = self.trained_alphas
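
As background for the `bits`/`group_size`/`scheme` settings consumed above, the snippet below is a generic sketch of per-group symmetric weight fake quantization. It is not the library's `TEQLinearFakeQuant` implementation, only the standard round-trip such a wrapper approximates (it assumes `in_features` is divisible by `group_size`):

```python
# Generic per-group symmetric fake quantization of a weight matrix (a sketch,
# not the TEQLinearFakeQuant code). Assumes in_features % group_size == 0.
import torch


def fake_quant_weight(weight: torch.Tensor, num_bits: int = 4, group_size: int = 32) -> torch.Tensor:
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    max_q = 2 ** (num_bits - 1) - 1                      # e.g. 7 for 4-bit symmetric
    scale = w.abs().amax(dim=-1, keepdim=True) / max_q   # one scale per group
    scale = scale.clamp(min=1e-9)
    q = torch.clamp(torch.round(w / scale), -max_q - 1, max_q)
    return (q * scale).reshape(out_features, in_features)


w = torch.randn(32, 64)
w_fq = fake_quant_weight(w, num_bits=4, group_size=32)
print((w - w_fq).abs().mean())  # small round-trip quantization error
```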
@@ -190,7 +209,9 @@ def _absorb_scales(self, layer, scale, layer_name=""):
             scale = scale.view(scale.shape[0], 1)
             layer.weight *= scale
 
-        elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm":  ##quite tricky
+        elif (
+            layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm"
+        ):  # pragma: no cover
             layer.weight *= scale
 
         else:  # pragma: no cover
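
The in-place multiplication of the norm weight above is one half of the equivalent transformation; the other half multiplies the matching input channels of the absorbed Linear weights in `_scale_layer_weight`, so the composed output is unchanged. A self-contained check of that algebra:

```python
# Self-contained check of the equivalence TEQ relies on: dividing the activations
# per channel by a scale while multiplying the matching input channels of the
# Linear weight by the same scale leaves the matmul output unchanged.
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)        # activations entering the Linear layer
w = torch.randn(32, 16)       # Linear weight: (out_features, in_features)
scale = torch.rand(16) + 0.5  # one positive scale per input channel

y_ref = x @ w.t()
y_teq = (x / scale) @ (w * scale).t()  # scale folded into the weight's input channels

assert torch.allclose(y_ref, y_teq, atol=1e-5)
```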
@@ -222,7 +243,7 @@ def _scale_layer_weight(self, layer, scale): ##input channel
     @torch.no_grad()
     def transform(self):
         """Apply alpha/scale."""
-        if not self._post_initialized:
+        if not self._post_initialized:  # pragma: no cover
             self._post_init()
         for ln_name, layer_names in self.absorb_to_layer.items():
             module = get_module(self.model, ln_name)
@@ -272,7 +293,7 @@ def save(self, save_scale_file="", save_state_dict_file=""):
 
 class TEQuantizer(Quantizer):
 
-    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+    def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None):
         super().__init__(quant_config=quant_config)
         self.folding = folding
         self.absorb_to_layer = absorb_to_layer
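
With `absorb_to_layer` now optional and moved behind `example_inputs`, a call site would look roughly like the sketch below. The import path is assumed from the utility import earlier in this diff, and `quant_config`/`example_inputs` are placeholders rather than values taken from this PR:

```python
# Hypothetical call site for the reordered signature: absorb_to_layer defaults to
# None, so add_tuning_scale() auto-detects it (via GraphTrace when folding=True).
# Import path and config values are assumptions, not taken from this PR.
import torch
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

quant_config = {"fc1": {"bits": 4, "group_size": 32, "scheme": "sym"}}  # placeholder
example_inputs = torch.randn(1, 16)  # placeholder example inputs

quantizer = TEQuantizer(
    quant_config=quant_config,
    folding=True,
    example_inputs=example_inputs,
    # absorb_to_layer omitted -> detected automatically at prepare time
)
```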