
Commit d3e75be

dsikka authored and horheynm committed
Fix SmoothQuant offload bug (#978)
* fix offload

  Signed-off-by: Dipika <[email protected]>

* fix smoothquant offload bug
* remove logtime

---------

Signed-off-by: Dipika <[email protected]>
1 parent 21a0d24 commit d3e75be

File tree

1 file changed: +17, -0 lines changed
  • src/llmcompressor/modifiers/smoothquant


src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple

 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module

@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):

         @torch.no_grad()
         def smooth(module):
+            offloaded = is_module_offloaded(module)
+            if offloaded:
+                module._hf_hook.pre_forward(module)
+
             if module in balance_layers:
                 module.weight.mul_(scales.view(1, -1))
             elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales)

+            if offloaded:
+                module._hf_hook.post_forward(module, None)
+
         parent = get_fsdp_parent(mapping.smooth_name, model)
         if parent is not None:
             parent.apply(smooth)
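
The pattern applied in the hunk above is: detect whether accelerate has offloaded the module's weights, onload them with the hook's pre_forward before mutating them in place, then hand them back with post_forward so the offloaded copy picks up the change. Below is a minimal standalone sketch of that pattern, assuming a module wrapped by an accelerate hook (module._hf_hook); the helper name scale_module_weights is illustrative only and not part of the patch.

import torch
from compressed_tensors.utils.offload import is_module_offloaded


@torch.no_grad()
def scale_module_weights(module: torch.nn.Module, scales: torch.Tensor):
    # hypothetical helper mirroring the smooth() fix above
    offloaded = is_module_offloaded(module)
    if offloaded:
        # onload the weights to the execution device so the in-place
        # mul_ below touches real tensors rather than the offloaded copy
        module._hf_hook.pre_forward(module)

    module.weight.mul_(scales.view(1, -1))

    if offloaded:
        # write the mutated weights back to the offload store
        module._hf_hook.post_forward(module, None)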
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+
         return scales
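
The loop patched above collects a per-channel dynamic range across the balance layers and combines it with the activation scales. Here is a standalone sketch of that calculation, assuming the standard SmoothQuant formula s_j = act_j**alpha / weight_j**(1 - alpha) with alpha = smoothing_strength; the tensor values are illustrative only.

import torch

alpha = 0.5  # smoothing_strength
activation_scales = torch.tensor([4.0, 8.0, 0.5])       # per-channel activation maxima
per_layer_maxima = torch.tensor([[1.0, 2.0, 0.0],
                                 [0.5, 4.0, 0.0]])       # |W| maxima for two balance layers

# channel-wise dynamic range across all balance layers, as in the loop above
weight_scales = 2.0 * per_layer_maxima.max(dim=0)[0]     # tensor([2., 8., 0.])

scales = activation_scales.pow(alpha) / weight_scales.pow(1 - alpha)
# guard against zero weight ranges, as in the patched code
scales = torch.where(weight_scales > 0.0, scales, activation_scales)
print(scales)  # tensor([1.4142, 1.0000, 0.5000])

Channels whose weight range is zero would otherwise produce inf scales; the torch.where fallback to the raw activation scale is what prevents that.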
