from typing import Dict, List, Optional, Tuple, Union

import torch
- from compressed_tensors.quantization import disable_quantization
+ from compressed_tensors.quantization import (
+     QuantizationConfig,
+     QuantizationScheme,
+     disable_quantization,
+ )
+ from compressed_tensors.quantization.quant_args import ActivationOrdering
from compressed_tensors.utils import (
    align_module_device,
    get_execution_device,
@@ -39,6 +44,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
    |          block_size: 128
    |          dampening_frac: 0.001
    |          offload_hessians: False
+   |          actorder: static
    |          config_groups:
    |            group_0:
    |                targets:
@@ -51,7 +57,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
    |                    symmetric: true
    |                    strategy: group
    |                    group_size: 128
-   |                    actorder: False

    Lifecycle:
        - on_initialize
@@ -70,6 +75,8 @@ class GPTQModifier(Modifier, QuantizationMixin):
    :param block_size: Used to determine number of columns to compress in one pass
    :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
        diagonal norm
+   :param actorder: order in which weight columns are quantized. For more information
+       on actorder options, see https://github.com/vllm-project/vllm/pull/8135
    :param offload_hessians: Set to True for decreased memory usage but increased
        runtime.
@@ -102,6 +109,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
    sequential_targets: Union[str, List[str], None] = None
    block_size: int = 128
    dampening_frac: Optional[float] = 0.01
+   actorder: Optional[ActivationOrdering] = None
    offload_hessians: bool = False

    # private variables
@@ -120,6 +128,29 @@ def validate_sequential_update(cls, value: bool) -> bool:

        return True

+     def resolve_quantization_config(self) -> QuantizationConfig:
+         config = super().resolve_quantization_config()
+
+         # Resolve config with `self.actorder`
+         for scheme in config.config_groups.values():
+             assert isinstance(scheme, QuantizationScheme)  # (1)
+             if scheme.weights is not None:
+                 existing = scheme.weights.actorder
+                 assert isinstance(existing, (ActivationOrdering, type(None)))  # (2)
+                 if existing is not None and existing != self.actorder:
+                     raise ValueError(
+                         "Cannot resolve activation ordering when both "
+                         "`GPTQModifier.actorder` and `QuantizationScheme.actorder` "
+                         "are provided. Either set `GPTQModifier.actorder = None` "
+                         "or remove `actorder` from config groups"
+                     )
+                 scheme.weights.actorder = self.actorder
+
+         # (1) QuantizationConfig.model_post_init
+         # (2) QuantizationScheme.validate_actorder
+
+         return config
+
    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize and run the GPTQ algorithm on the current state
@@ -176,31 +207,6 @@ def on_event(self, state: State, event: Event, **kwargs):
            if not self.ended_:
                self.on_end(state, None)

-     def on_end(self, state: State, event: Event, **kwargs):
-         """
-         Finish calibrating by removing observers and calibration hooks
-         """
-         self.ended_ = True
-         QuantizationMixin.end_calibration(self, state.model)
-         self.remove_hooks()  # remove gptq hooks
-
-     def on_finalize(self, state: State, **kwargs) -> bool:
-         """
-         disable the quantization observers used by the OBCQ algorithm
-
-         :param state: session state storing input model and calibration data
-         """
-         if not self.ended_:
-             self.on_end(state, None)
-
-         if len(self._num_samples) > 0:
-             raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
-
-         self._hessians = dict()
-         self._num_samples = dict()
-
-         return True
-
    def calibrate_module(
        self,
        module: torch.nn.Module,
@@ -268,6 +274,31 @@ def compress_modules(self):
            # self._hessians[module] already deleted by quantize_weight
            del self._num_samples[module]

+     def on_end(self, state: State, event: Event, **kwargs):
+         """
+         Finish calibrating by removing observers and calibration hooks
+         """
+         self.ended_ = True
+         QuantizationMixin.end_calibration(self, state.model)
+         self.remove_hooks()  # remove gptq hooks
+
+     def on_finalize(self, state: State, **kwargs) -> bool:
+         """
+         disable the quantization observers used by the OBCQ algorithm
+
+         :param state: session state storing input model and calibration data
+         """
+         if not self.ended_:
+             self.on_end(state, None)
+
+         if len(self._num_samples) > 0:
+             raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
+         self._hessians = dict()
+         self._num_samples = dict()
+
+         return True
+
    @contextlib.contextmanager
    def _maybe_onload_hessian(self, module: torch.nn.Module):
        if self.offload_hessians:
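For context, a minimal sketch of how the new `actorder` field might be used from a recipe. Everything here except `actorder` itself is an assumption rather than part of this diff: it presumes `llmcompressor`'s `oneshot` entrypoint (import path varies by version), the `targets`/`scheme`/`ignore` fields provided by `QuantizationMixin`, and placeholder model, dataset, and sample-count values.

```python
# Sketch only, not part of this PR: assumes llm-compressor's `oneshot` entrypoint
# and uses placeholder model/dataset names for illustration.
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(
    targets="Linear",      # quantize Linear layers
    ignore=["lm_head"],    # keep the output head unquantized
    scheme="W4A16",        # grouped 4-bit weights, 16-bit activations
    actorder="static",     # new field from this diff; value matches the docstring sample
)

oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder checkpoint
    dataset="open_platypus",                   # placeholder calibration dataset
    recipe=recipe,
    num_calibration_samples=512,
)
```

Setting a different `actorder` inside `config_groups` at the same time triggers the `ValueError` added in `resolve_quantization_config` above; leaving one of the two unset lets the other take effect.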