Commit 93e6020

account for slightly different update param behavior (#1005)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 612ebfd · commit 93e6020

File tree

1 file changed: +10 -9 lines changed
src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 10 additions & 9 deletions
@@ -11,7 +11,7 @@
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
-    update_parameter_data,
+    update_offload_parameter,
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
@@ -238,14 +238,15 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
 
     if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
         for module in (input_embed, output_embed):
-            offloaded = is_module_offloaded(module)
-            if offloaded:
-                module._hf_hook.pre_forward(module)
-
-            update_parameter_data(module, module.weight.data.clone(), "weight")
-
-            if offloaded:
-                module._hf_hook.post_forward(module, None)
+            if not is_module_offloaded(module):
+                # create new storage ptr for onloaded weight
+                untied_data = module.weight.data.clone()
+                module.weight.data = untied_data
+            else:
+                # create new storage ptr for offloaded weight
+                # note `update_offload_parameter` does not create a new storage ptr
+                untied_data = module._hf_hook.weights_map["weight"].clone()
+                update_offload_parameter(module, "weight", untied_data)
 
 
 def get_model_compressor(
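
For context on what the new branches accomplish, here is a minimal, runnable sketch of the untying behavior (an illustration under stated assumptions, not code from this repo): the toy Embedding/Linear pair below is hypothetical, while `storage_ptr` is the same safetensors helper the patched file imports.

import torch
from safetensors.torch import storage_ptr

# Hypothetical stand-ins for a model's input and output embeddings
input_embed = torch.nn.Embedding(10, 4)
output_embed = torch.nn.Linear(4, 10, bias=False)

# Reproduce the tied-tensors condition: two distinct Parameters backed by
# one underlying storage (Parameter() wraps its input without copying)
tied = torch.randn(10, 4)
input_embed.weight = torch.nn.Parameter(tied)
output_embed.weight = torch.nn.Parameter(tied)
assert storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight)

# Untie as in the onloaded branch of the patch: clone() allocates fresh
# storage, and reassigning .data points each module at its own copy
for module in (input_embed, output_embed):
    module.weight.data = module.weight.data.clone()

assert storage_ptr(input_embed.weight) != storage_ptr(output_embed.weight)

The offloaded branch needs the same property, which is why it clones the entry from `_hf_hook.weights_map` before calling `update_offload_parameter`: per the new comment in the diff, that helper updates the offloaded weight without allocating new storage, so the clone has to happen first.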
