
Commit 1f194e5

rename to _smooth_activation_means
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent: 250da34 · commit: 1f194e5

File tree

  • src/llmcompressor/modifiers/awq

1 file changed: +8 −8 lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 8 additions & 8 deletions

@@ -141,7 +141,7 @@ class AWQModifier(Modifier, QuantizationMixin):
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_cache: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )

@@ -289,7 +289,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         self.on_end(state, None)

         self._parent_args_cache.clear()
-        self._smooth_activation_cache.clear()
+        self._smooth_activation_means.clear()
         self._resolved_mappings.clear()

         return True
@@ -404,9 +404,9 @@ def cache_smooth_activations_hook(
                 # Assume that first argument is the input
                 inp = args[0].cpu().detach().squeeze()

-                self._smooth_activation_cache[smooth_name] = _accumulate_mean(
+                self._smooth_activation_means[smooth_name] = _accumulate_mean(
                     inp,
-                    self._smooth_activation_cache.get(smooth_name, None),
+                    self._smooth_activation_means.get(smooth_name, None),
                 )

             return cache_smooth_activations_hook
@@ -447,7 +447,7 @@ def _apply_smoothing(self, model: Module) -> None:
         for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
             # NOTE: When using SequentialPipeline, not all the mappings
             # will have cached activations in the segment being udpated
-            if mapping.smooth_name not in self._smooth_activation_cache:
+            if mapping.smooth_name not in self._smooth_activation_means:
                 continue

             smooth_layer = mapping.smooth_layer
@@ -477,7 +477,7 @@ def _apply_smoothing(self, model: Module) -> None:
                 torch.finfo(fp16_output.dtype).min,
                 torch.finfo(fp16_output.dtype).max,
             )
-            x_mean = self._smooth_activation_cache[mapping.smooth_name][0]
+            x_mean = self._smooth_activation_means[mapping.smooth_name][0]

             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
@@ -528,7 +528,7 @@ def smooth(module):
             smooth(smooth_layer)

             # remove caches needed to smooth this mapping
-            del self._smooth_activation_cache[mapping.smooth_name]
+            del self._smooth_activation_means[mapping.smooth_name]

             for v in self._parent_args_cache.values():
                 v.batch_intermediates.clear()
@@ -674,7 +674,7 @@ def _assert_all_activations_consumed(self):
         Confirm all activations have been consumed
         If not, something has gone wrong
         """
-        if len(self._smooth_activation_cache) != 0:
+        if len(self._smooth_activation_means) != 0:
             raise RuntimeError("Some cached activations were not used")
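For context on this rename: the attribute holds a per-layer running mean of activations together with an observation count, not a cache of raw activations, which is what _smooth_activation_means conveys better than _smooth_activation_cache. The _accumulate_mean helper used by the hook is not shown in this diff; below is a minimal sketch of how such a running-mean accumulator could look, assuming a 2-D [tokens, hidden] input and the (activation means, activation counts) tuple layout described in the attribute's comment. The names and signature are illustrative, not the repository's actual implementation.

    from typing import Optional, Tuple

    import torch


    def _accumulate_mean(
        inp: torch.Tensor,
        prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
    ) -> Tuple[torch.FloatTensor, int]:
        # Hypothetical sketch: fold a new batch of activations into a
        # running per-feature mean without retaining raw activations.
        sum_added = inp.sum(dim=0)  # per-feature sum over the token dim
        num_added = inp.size(0)     # number of tokens in this batch

        if prev_mean_and_count is None:
            # First batch observed for this smooth layer
            return sum_added / num_added, num_added

        prev_mean, prev_count = prev_mean_and_count
        new_count = prev_count + num_added
        # Recover the previous sum from the stored mean, add the new
        # batch's sum, and re-normalize by the updated count
        new_mean = (prev_mean * prev_count + sum_added) / new_count
        return new_mean, new_count

With a helper of this shape, each hook call keeps only one mean tensor and one integer count per smooth layer, which is exactly the state that on_finalize clears and _apply_smoothing deletes after consuming it.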
