Commit beb1b20

fix: make --load-8bit flag work with weights in safetensors format

1 parent e53c73f
1 file changed: +13 −1

fastchat/model/compression.py

Lines changed: 13 additions & 1 deletion
@@ -168,15 +168,27 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
     base_pattern = os.path.join(model_path, "pytorch_model*.bin")
 
     files = glob.glob(base_pattern)
+    use_safetensors = False
+    if len(files) == 0:
+        base_pattern = os.path.join(model_path, "*.safetensors")
+        files = glob.glob(base_pattern)
+        use_safetensors = True
     if len(files) == 0:
         raise ValueError(
             f"Cannot find any model weight files. "
             f"Please check your (cached) weight path: {model_path}"
         )
 
     compressed_state_dict = {}
+    if use_safetensors:
+        from safetensors.torch import load_file
     for filename in tqdm(files):
-        tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
+        if use_safetensors:
+            tmp_state_dict = load_file(filename)
+        else:
+            tmp_state_dict = torch.load(
+                filename, map_location=lambda storage, loc: storage
+            )
         for name in tmp_state_dict:
             if name in linear_weights:
                 tensor = tmp_state_dict[name].to(device, dtype=torch_dtype)
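
For reference, the change amounts to a fallback path: when no pytorch_model*.bin shards are found, the loader globs for *.safetensors files and reads each shard with safetensors.torch.load_file instead of torch.load. Below is a minimal standalone sketch of that logic, assuming a local directory of weight shards; the helper name load_weight_shards and its bare model_path argument are illustrative and not part of the commit.

    import glob
    import os

    import torch


    def load_weight_shards(model_path):
        # Prefer PyTorch .bin shards; fall back to .safetensors if none exist
        # (mirrors the fallback introduced in this commit).
        files = glob.glob(os.path.join(model_path, "pytorch_model*.bin"))
        use_safetensors = False
        if len(files) == 0:
            files = glob.glob(os.path.join(model_path, "*.safetensors"))
            use_safetensors = True
        if len(files) == 0:
            raise ValueError(f"Cannot find any model weight files in {model_path}")

        state_dict = {}
        if use_safetensors:
            # load_file returns a {name: tensor} dict on CPU, analogous to
            # what torch.load returns for a .bin shard.
            from safetensors.torch import load_file
        for filename in files:
            if use_safetensors:
                shard = load_file(filename)
            else:
                shard = torch.load(filename, map_location=lambda storage, loc: storage)
            state_dict.update(shard)
        return state_dict

In the commit itself the per-shard dict is not merged wholesale; instead each tensor named in linear_weights is moved to the target device and compressed, which is why the loop above only approximates the surrounding function.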
