src/llmcompressor/transformers/sparsification (1 file changed: +10, -9 lines)
@@ -11,7 +11,7 @@
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
-    update_parameter_data,
+    update_offload_parameter,
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
@@ -238,14 +238,15 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
 
     if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
         for module in (input_embed, output_embed):
-            offloaded = is_module_offloaded(module)
-            if offloaded:
-                module._hf_hook.pre_forward(module)
-
-            update_parameter_data(module, module.weight.data.clone(), "weight")
-
-            if offloaded:
-                module._hf_hook.post_forward(module, None)
+            if not is_module_offloaded(module):
+                # create new storage ptr for onloaded weight
+                untied_data = module.weight.data.clone()
+                module.weight.data = untied_data
+            else:
+                # create new storage ptr for offloaded weight
+                # note `update_offload_parameter` does not create a new storage ptr
+                untied_data = module._hf_hook.weights_map["weight"].clone()
+                update_offload_parameter(module, "weight", untied_data)
 
 
 def get_model_compressor(
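For context, a minimal, self-contained sketch of the situation `patch_tied_tensors_bug` deals with, assuming only torch and safetensors. The module names below are illustrative, not taken from the repository, and only the onloaded branch of the patch is reproduced:

import torch
from safetensors.torch import storage_ptr

# Two distinct Parameter objects backed by the same storage, which is how
# the tied-tensors bug manifests: the config says untied, the memory says tied.
shared = torch.randn(10, 4)
input_embed = torch.nn.Embedding(10, 4)
output_embed = torch.nn.Linear(4, 10, bias=False)
input_embed.weight.data = shared
output_embed.weight.data = shared

assert storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight)

# Untie the way the new onloaded branch does: clone() allocates a new
# storage pointer, then the module's data is rebound to the clone.
for module in (input_embed, output_embed):
    module.weight.data = module.weight.data.clone()

assert storage_ptr(input_embed.weight) != storage_ptr(output_embed.weight)

The offloaded branch is not reproduced here; as the inline comment notes, it clones the `weights_map["weight"]` entry first, because `update_offload_parameter` by itself does not create a new storage pointer.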