@@ -141,7 +141,7 @@ class AWQModifier(Modifier, QuantizationMixin):
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_cache: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )
 
@@ -289,7 +289,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         self.on_end(state, None)
 
         self._parent_args_cache.clear()
-        self._smooth_activation_cache.clear()
+        self._smooth_activation_means.clear()
         self._resolved_mappings.clear()
 
         return True
@@ -404,9 +404,9 @@ def cache_smooth_activations_hook(
             # Assume that first argument is the input
             inp = args[0].cpu().detach().squeeze()
 
-            self._smooth_activation_cache[smooth_name] = _accumulate_mean(
+            self._smooth_activation_means[smooth_name] = _accumulate_mean(
                 inp,
-                self._smooth_activation_cache.get(smooth_name, None),
+                self._smooth_activation_means.get(smooth_name, None),
             )
 
         return cache_smooth_activations_hook
@@ -447,7 +447,7 @@ def _apply_smoothing(self, model: Module) -> None:
         for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
             # NOTE: When using SequentialPipeline, not all the mappings
             # will have cached activations in the segment being updated
-            if mapping.smooth_name not in self._smooth_activation_cache:
+            if mapping.smooth_name not in self._smooth_activation_means:
                 continue
 
             smooth_layer = mapping.smooth_layer
@@ -477,7 +477,7 @@ def _apply_smoothing(self, model: Module) -> None:
                 torch.finfo(fp16_output.dtype).min,
                 torch.finfo(fp16_output.dtype).max,
             )
-            x_mean = self._smooth_activation_cache[mapping.smooth_name][0]
+            x_mean = self._smooth_activation_means[mapping.smooth_name][0]
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
@@ -528,7 +528,7 @@ def smooth(module):
             smooth(smooth_layer)
 
             # remove caches needed to smooth this mapping
-            del self._smooth_activation_cache[mapping.smooth_name]
+            del self._smooth_activation_means[mapping.smooth_name]
 
         for v in self._parent_args_cache.values():
             v.batch_intermediates.clear()
@@ -674,7 +674,7 @@ def _assert_all_activations_consumed(self):
         Confirm all activations have been consumed
         If not, something has gone wrong
         """
-        if len(self._smooth_activation_cache) != 0:
+        if len(self._smooth_activation_means) != 0:
             raise RuntimeError("Some cached activations were not used")
 
 
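For context on the rename: the cache stores per-layer running means (plus a sample count), not raw activations, which is what the new name `_smooth_activation_means` conveys. The `_accumulate_mean` helper called in the hook above is not part of this diff; the sketch below is an assumption about its behavior, inferred from the `Dict[str, Tuple[torch.FloatTensor, int]]` annotation and the `(means, count)` tuple it reads back via `.get(smooth_name, None)`:

from typing import Optional, Tuple

import torch


def _accumulate_mean(
    inp: torch.Tensor,
    prev: Optional[Tuple[torch.FloatTensor, int]],
) -> Tuple[torch.FloatTensor, int]:
    # Hypothetical sketch: fold a new batch of activations into a
    # running per-channel mean without keeping the raw activations.
    sum_added = inp.sum(dim=0)
    num_added = inp.size(0)
    if prev is None:
        return sum_added / num_added, num_added

    prev_mean, prev_count = prev
    new_count = prev_count + num_added
    # Recover the previous sum from (mean, count), add the new sum,
    # and renormalize by the combined count.
    return (prev_mean * prev_count + sum_added) / new_count, new_count

Storing only a `(mean, count)` pair per smooth layer keeps memory bounded across calibration batches, which matters when the same hook fires once per batch over a long calibration set.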