
Commit ee44252

kylesayrs and dsikka committed
GPTQ: Depreciate non-sequential update option (#762)
* remove from gptq, apply style
* remove instances of sequential_update argument in GPTQ tests
* update examples
* update example tests
* documentation, remove from example
* apply style
* revert back to auto type
* apply style

Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e47bfa8 commit ee44252

28 files changed: +44, -88 lines changed

examples/big_models_with_accelerate/README.md

Lines changed: 3 additions & 8 deletions
@@ -29,8 +29,8 @@ will work properly out of the box for basic quantization with `QuantizationModif
 even for CPU offloaded models.
 
 To enable CPU offloading for second-order quantization methods such as GPTQ, we need to
-allocate additional memory upfront when computing the device map. Note that this
-device map will only compatible with `GPTQModifier(sequential_update=True, ...)`
+allocate additional memory upfront when computing the device map. Not doing so risks
+potentially going out-of-memory.
 
 ```python
 from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
@@ -48,12 +48,7 @@ model = SparseAutoModelForCausalLM.from_pretrained(
 
 ### Practical Advice
 
-When working with `accelerate`, it is important to keep in mind that CPU offloading and naive pipeline-parallelism will slow down forward passes through the model. As a result, we need to take care to ensure that the quantization methods used fit well with the offloading scheme as methods that require many forward passes though the model will be slowed down.
-
-General rules of thumb:
-- CPU offloading is best used with data-free quantization methods (e.g. PTQ with `FP8_DYNAMIC`)
-- Multi-GPU is fast enough to be used with calibration data-based methods with `sequential_update=False`
-- It is possible to use Multi-GPU with `sequential_update=True` to save GPU memory, but the runtime will be slower
+When working with `accelerate`, it is important to keep in mind that CPU offloading and naive pipeline-parallelism will slow down forward passes through the model. As a result, we need to take care to ensure that the quantization methods used fit well with the offloading scheme as methods that require many forward passes though the model will be slowed down. If more gpu memory is not available, consider reducing the precision of the loaded model to a lower-width dtype such as `torch.bfloat16`.
 
 ## Examples

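For reference, the workflow the updated README describes looks roughly like the following sketch, assembled from the example scripts touched in this commit. The model ID, GPU count, and import paths follow those examples and are illustrative rather than prescriptive:

```python
import torch

from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# placeholder checkpoint; substitute the model you are quantizing
MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"

# reserve extra GPU memory for the GPTQ/SparseGPT hessians when building the
# device map, and load the weights in bfloat16 to shrink their footprint
device_map = calculate_offload_device_map(
    MODEL_ID, num_gpus=2, reserve_for_hessians=True, torch_dtype=torch.bfloat16
)

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16
)
```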
examples/big_models_with_accelerate/multi_gpu_int8_sequential_update.py renamed to examples/big_models_with_accelerate/mult_gpus_int8_device_map.py

Lines changed: 6 additions & 2 deletions
@@ -10,8 +10,10 @@
 MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
 
 # adjust based off number of desired GPUs
+# reserve_for_hessians=True reserves memory which is required by
+# GPTQModifier and SparseGPTModifier
 device_map = calculate_offload_device_map(
-    MODEL_ID, reserve_for_hessians=True, num_gpus=2, torch_dtype=torch.bfloat16
+    MODEL_ID, num_gpus=2, reserve_for_hessians=True, torch_dtype=torch.bfloat16
 )
 
 model = SparseAutoModelForCausalLM.from_pretrained(
@@ -60,7 +62,9 @@ def tokenize(sample):
 recipe = [
     SmoothQuantModifier(smoothing_strength=0.8),
     GPTQModifier(
-        targets="Linear", scheme="W8A8", ignore=["lm_head"], sequential_update=True
+        targets="Linear",
+        scheme="W8A8",
+        ignore=["lm_head"],
     ),
 ]

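With `sequential_update` gone from the recipe, the quantization step of this example reduces to the shape below. This is a sketch that assumes `model` and a tokenized calibration dataset `ds` were prepared earlier in the script, as in the rest of the example; the calibration parameters are illustrative:

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

# sequential_update no longer needs to be passed; GPTQ always updates layer by layer
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# `model` and `ds` (tokenized calibration data) are assumed to exist already
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```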
examples/big_models_with_accelerate/multi_gpu_int8.py

Lines changed: 3 additions & 4 deletions
@@ -58,14 +58,13 @@ def tokenize(sample):
 # 3) Configure algorithms. In this case, we:
 # * quantize the weights to int8 with GPTQ (static per channel)
 # * quantize the activations to int8 (dynamic per token)
-# * run non-sequentially (for seq update, see multi_gpu_int8_sequential_update.py)
 recipe = [
-    GPTQModifier(
-        targets="Linear", scheme="W8A8", ignore=["lm_head"], sequential_update=False
-    ),
+    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
 ]
 
 # 4) Apply algorithms and save in `compressed-tensors` format.
+# if you encounter GPU out-of-memory issues, consider using an explicit
+# device map (see multi_gpus_int8_device_map.py)
 oneshot(
     model=model,
     tokenizer=tokenizer,

examples/quantization_24_sparse_w4a16/2:4_w4a16_group-128_recipe.yaml

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ quantization_stage:
   run_type: oneshot
   quantization_modifiers:
     GPTQModifier:
-      sequential_update: true
       ignore: ["lm_head"]
       config_groups:
         group_0:

examples/quantization_24_sparse_w4a16/2:4_w4a16_recipe.yaml

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ quantization_stage:
   run_type: oneshot
   quantization_modifiers:
     GPTQModifier:
-      sequential_update: true
       ignore: ["lm_head"]
       config_groups:
         group_0:

examples/quantization_w8a8_int8/gemma2_example.py

Lines changed: 0 additions & 1 deletion
@@ -55,7 +55,6 @@ def tokenize(sample):
 # 3) Select quantization algorithms. In this case, we:
 # * quantize the weights to int8 with GPTQ (static per channel)
 # * quantize the activations to int8 (dynamic per token)
-# Note: set sequential_update: true in the recipe to reduce memory
 recipe = GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])
 
 # 4) Apply quantization and save to disk compressed.

examples/quantization_w8a8_int8/llama3_example.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def tokenize(sample):
 # * apply SmoothQuant to make the activations easier to quantize
 # * quantize the weights to int8 with GPTQ (static per channel)
 # * quantize the activations to int8 (dynamic per token)
-# Note: set sequential_update: true in the recipe to reduce memory
 recipe = [
     SmoothQuantModifier(smoothing_strength=0.8),
     GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),

examples/quantizing_moe/deepseek_moe_w8a8_int8.py

Lines changed: 0 additions & 1 deletion
@@ -70,7 +70,6 @@ def tokenize(sample):
         targets="Linear",
         scheme="W8A8",
         ignore=["lm_head", "re:.*mlp.gate$"],
-        sequential_update=True,
     ),
 ]

examples/quantizing_moe/deepseek_recipe_w4a16.yaml

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 quant_stage:
   quant_modifiers:
     GPTQModifier:
-      sequential_update: true
       ignore: [lm_head, "re:.*mlp.gate$"]
       config_groups:
         group_0:

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 26 additions & 44 deletions
@@ -1,4 +1,4 @@
-import gc
+import warnings
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -49,7 +49,6 @@ class GPTQModifier(Modifier):
     |    test_stage:
     |       obcq_modifiers:
     |          GPTQModifier:
-    |             sequential_update: true
     |             dampening_frac: 0.001
     |             block_size: 128
     |             config_groups:
@@ -67,8 +66,8 @@
     |                      actorder: False
 
 
-    :param sequential_update: Whether or not to update weights sequentially by layer,
-        True saves on GPU memory, default is True
+    :param sequential_update: Whether or not to update weights sequentially by layer.
+        This option is depreciated and setting to False is no longer supported
     :param targets: list of layer names to compress during GPTQ, or '__ALL__'
         to compress every layer in the model
     :param block_size: Used to determine number of columns to compress in one pass
@@ -98,7 +97,7 @@
     and activation 8 bit quantization on the Linear layers.
     """
 
-    sequential_update: bool = True
+    sequential_update: bool = True  # DEPRECIATED
     targets: Union[str, List[str], None] = None
     sequential_targets: Union[str, List[str], None] = None
     block_size: int = 128
@@ -118,13 +117,13 @@
     @field_validator("sequential_update", mode="before")
     def validate_sequential_update(cls, value: bool) -> bool:
         if not value:
-            logger.warning(
-                "Not using sequential_update requires allocating all hessians in "
-                "GPU memory. If you are running into GPU memory issues, consider "
-                "using sequential_update=True"
+            warnings.warn(
+                "`sequential_update=False` is no longer supported, setting "
+                "sequential_update=True",
+                DeprecationWarning,
             )
 
-        return value
+        return True
 
     def on_initialize_structure(self, state: State, **kwargs):
         """
@@ -246,7 +245,7 @@ def initialize_compression(
         compressible layers of model, and sets the device
 
         :param model: model to initialize for compression
-        :param dataloader: calibration data for GPTQ
+        :param dataloader: calibration data, not used by GPTQ in this function
        """
        self.model = model
        self.compressible_layers_ = self.compressible_layers()
@@ -258,16 +257,12 @@
            args = self._pruning_arguments()
            comp_cls = self._compression_class()
            compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
-
-            # if running sequentially, allocate all hessians now
-            if not self.sequential_update:
-                compressor.pre_compress()
-
            self.layer_compressors_.append(compressor)
 
-        if self.sequential_update:
-            first_layer_compressor = self.layer_compressors_[0]
-            first_layer_compressor.set_early_stop()
+        # for the initial forward data pass, add an early stop exception in order
+        # to capture inputs right before being compressed by first module
+        first_layer_compressor = self.layer_compressors_[0]
+        first_layer_compressor.set_early_stop()
 
     @torch.no_grad()
     def apply_compression(
@@ -288,45 +283,32 @@ def apply_compression(
         self.model.apply(disable_quantization)
 
         with DisableKVCache(self.model):
-            # in non-sequential mode we run calibration through the full model
-            # in sequential mode we run calibration up to the first transformer target
+            # run_calibration_forward uses the early stop exception to capture values
+            # as intermediates right before the forward pass of the first module
             intermediates = run_calibration_forward(
                 self.model, dataloader, mask_padding=True
            )
            self.layer_compressors_[0].clear_early_stop()
 
-            # empty cache if not using sequential update
-            if not self.sequential_update:
-                del intermediates
-                gc.collect()
-                torch.cuda.empty_cache()
-
            num_layers = len(self.compressible_layers_)
            for idx, layer_compressor in enumerate(self.layer_compressors_):
                logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====")
 
-                if self.sequential_update:
-                    # in sequential mode we run the forward pass for each layer
-                    # one at a time, caching the intermediate outputs between layers
-                    logger.info(f"Calibrating {layer_compressor.name}...")
-                    layer_compressor.pre_compress()
-                    unquantized_outputs = layer_compressor.calibrate_layer(
-                        intermediates
-                    )
+                # run the forward pass for each transformer layer (block) one at a time
+                logger.info(f"Calibrating {layer_compressor.name}...")
+                layer_compressor.pre_compress()
+                unquantized_outputs = layer_compressor.calibrate_layer(intermediates)
 
                layer_compressor.compress()
                layer_compressor.post_compress()
                layer_compressor.revert_layer_wrappers()
 
-                if self.sequential_update:
-                    quantized_outputs = layer_compressor.calibrate_layer(intermediates)
-                    error = get_output_error(unquantized_outputs, quantized_outputs)
-                    logger.info(f"Mean output error from quantization: {error:.3f}")
-                    intermediates = quantized_outputs
-                    del unquantized_outputs
-
-                gc.collect()
-                torch.cuda.empty_cache()
+                # perform a second forward pass of the module to calculate
+                # weight-quantized outputs for use as inputs to the next layer
+                quantized_outputs = layer_compressor.calibrate_layer(intermediates)
+                error = get_output_error(unquantized_outputs, quantized_outputs)
+                logger.info(f"Mean output error from quantization: {error:.3f}")
+                intermediates = quantized_outputs
 
        # re-enable quantization
        self.model.apply(enable_quantization)

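The effect of the new validator can be illustrated with a small sketch (hypothetical usage, not part of the commit): passing `sequential_update=False` now emits a `DeprecationWarning` and the field is coerced back to `True`.

```python
import warnings

from llmcompressor.modifiers.quantization import GPTQModifier

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # sequential_update=False is accepted for backwards compatibility only
    modifier = GPTQModifier(targets="Linear", scheme="W8A8", sequential_update=False)

# the `before` validator ignores the supplied value and returns True
assert modifier.sequential_update is True
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```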