 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 
@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):
 
             @torch.no_grad()
             def smooth(module):
+                offloaded = is_module_offloaded(module)
+                if offloaded:
+                    module._hf_hook.pre_forward(module)
+
                 if module in balance_layers:
                     module.weight.mul_(scales.view(1, -1))
                 elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                     if hasattr(module, "bias") and module.bias is not None:
                         module.bias.div_(scales)
 
+                if offloaded:
+                    module._hf_hook.post_forward(module, None)
+
             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
                 parent.apply(smooth)
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
 
         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+
         return scales
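Note: both hunks repeat the same onload/modify/offload sequence around accelerate's `_hf_hook`. As an illustration only (not part of this PR), the pattern could be factored into a small context manager; the helper name `onloaded_weights` is hypothetical, while `is_module_offloaded` and the `pre_forward`/`post_forward` hook calls are the same ones used in the diff above.

from contextlib import contextmanager

from compressed_tensors.utils.offload import is_module_offloaded
from torch.nn import Module


@contextmanager
def onloaded_weights(module: Module):
    # Hypothetical helper sketching the pattern in the diff above:
    # materialize an accelerate-offloaded module's weights before an
    # in-place update, then restore its offloaded placement afterwards.
    offloaded = is_module_offloaded(module)
    if offloaded:
        module._hf_hook.pre_forward(module)
    try:
        yield module
    finally:
        if offloaded:
            module._hf_hook.post_forward(module, None)


# usage, mirroring the balance-layer update in _apply_smoothing:
#     with onloaded_weights(layer):
#         layer.weight.mul_(scales.view(1, -1))

The try/finally also guarantees the module is returned to its offloaded placement even if the in-place update raises.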