
Commit b812f0d

jeejeelee authored and shreyankg committed
[Bugfix] Fix JambaForCausalLM LoRA (vllm-project#14370)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent d54e2b6 commit b812f0d

File tree

4 files changed (+35 -83 lines)


tests/lora/conftest.py

Lines changed: 0 additions & 24 deletions
@@ -6,7 +6,6 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
 
 
-@pytest.fixture(scope="session")
-def jamba_lora_files():
-    # some of the adapters have unnecessary weights for serving,
-    # hence we remove them
-    def remove_unnecessary_weights(path):
-        lora_path = f"{adapter_path}/adapter_model.safetensors"
-        tensors = safetensors.torch.load_file(lora_path)
-        nonlora_keys = []
-        for k in list(tensors.keys()):
-            if "lora" not in k:
-                nonlora_keys.append(k)
-        for k in nonlora_keys:
-            del tensors[k]
-        safetensors.torch.save_file(tensors, lora_path)
-
-    adapter_path = snapshot_download(
-        repo_id=
-        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
-
-    remove_unnecessary_weights(adapter_path)
-    return adapter_path
-
-
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
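For reference, the deleted jamba_lora_files fixture relied on a common safetensors cleanup pattern: load the adapter's tensor dict, drop every key that is not a LoRA weight, and write the file back in place. A minimal standalone sketch of that pattern follows; the file path is hypothetical and the code is illustration only, not part of the commit.

# Sketch of the removed fixture's cleanup step: keep only LoRA tensors in a
# local adapter file. "adapter_model.safetensors" is a hypothetical path.
from safetensors.torch import load_file, save_file

lora_path = "adapter_model.safetensors"

tensors = load_file(lora_path)
# Drop any tensor whose key does not look like a LoRA weight.
lora_only = {k: v for k, v in tensors.items() if "lora" in k}
save_file(lora_only, lora_path)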

tests/lora/test_jamba.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

tests/lora/test_layers.py

Lines changed: 3 additions & 0 deletions
@@ -632,6 +632,7 @@ def create_random_linear_replicated_layer():
 
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_replicated_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -757,6 +758,7 @@ def create_random_linear_parallel_layer():
 
     id_to_index = get_random_id_to_index(num_loras, max_loras)
     linear, lora_linear = create_random_linear_parallel_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, _ = populate_loras(
         id_to_index,
@@ -904,6 +906,7 @@ class FakeConfig:
     id_to_index = get_random_id_to_index(num_loras, max_loras)
 
     linear, lora_linear = create_column_parallel_packed_layer()
+    assert torch.equal(linear.weight, lora_linear.weight)
     lora_linear.set_mapping(punica_wrapper)
     lora_dict, sublora_dict = populate_loras(
         id_to_index,
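All three new assertions check the same invariant: after a linear layer is wrapped by its LoRA counterpart, the wrapper's weight property must hand back the base layer's weight tensor. A minimal sketch of that delegation pattern is below; the class name is illustrative, not vLLM's.

# Minimal sketch (not vLLM code) of the property delegation the new
# assertions exercise: a LoRA wrapper exposing its base layer's weight.
import torch
import torch.nn as nn


class LoRAWrapperSketch(nn.Module):
    """Hypothetical stand-in for a LoRA-wrapped linear layer."""

    def __init__(self, base_layer: nn.Linear):
        super().__init__()
        self.base_layer = base_layer

    @property
    def weight(self) -> torch.Tensor:
        # Delegate to the wrapped layer so callers can read .weight
        # without knowing whether the layer carries LoRA adapters.
        return self.base_layer.weight


linear = nn.Linear(16, 32, bias=False)
wrapped = LoRAWrapperSketch(linear)
assert torch.equal(linear.weight, wrapped.weight)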

vllm/lora/layers.py

Lines changed: 32 additions & 5 deletions
@@ -274,6 +274,10 @@ def can_replace_layer(
     ) -> bool:
         return type(source_layer) is VocabParallelEmbedding
 
+    @property
+    def weight(self):
+        return self.base_layer.weight
+
 
 class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
 
@@ -409,6 +413,34 @@ def apply(self,
                                            self.output_slices)
         return output
 
+    @property
+    def weight(self) -> torch.Tensor:
+
+        # unquantizedLinear
+        if hasattr(self.base_layer, "weight"):
+            return self.base_layer.weight
+        # Compressed Tensor
+        elif hasattr(self.base_layer, "weight_packed"):
+            return self.base_layer.weight_packed
+        # GPTQ/AWQ
+        elif hasattr(self.base_layer, "qweight"):
+            return self.base_layer.qweight
+        # marlin
+        elif hasattr(self.base_layer, "B"):
+            return self.base_layer.B
+        # HQQ marlin
+        elif hasattr(self.base_layer, "W_q"):
+            return self.base_layer.W_q
+        else:
+            raise ValueError(f"Unsupported base layer: {self.base_layer}")
+
+    @property
+    def bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.base_layer, "bias"):
+            return self.base_layer.bias
+        else:
+            return None
+
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
 
@@ -902,11 +934,6 @@ def forward(
 
         return output, output_bias
 
-    @property
-    def weight(self):
-        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
-                else self.base_layer.qweight)
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(
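The new weight property on BaseLinearLayerWithLoRA replaces the narrower weight/qweight fallback that the last hunk removes, so callers that read .weight from a LoRA-wrapped linear layer keep working regardless of the quantization backend. A standalone sketch of the same attribute-fallback idea, with dummy layer stand-ins (illustration only, not vLLM code):

# Different quantization backends store their parameters under different
# attribute names, so the lookup tries each known name in turn. The names
# mirror the diff: "weight" (unquantized), "weight_packed"
# (compressed-tensors), "qweight" (GPTQ/AWQ), "B" (Marlin), "W_q" (HQQ Marlin).
from types import SimpleNamespace

import torch

_WEIGHT_ATTRS = ("weight", "weight_packed", "qweight", "B", "W_q")


def resolve_weight(base_layer) -> torch.Tensor:
    """Return whichever tensor the base layer uses as its weight storage."""
    for name in _WEIGHT_ATTRS:
        if hasattr(base_layer, name):
            return getattr(base_layer, name)
    raise ValueError(f"Unsupported base layer: {base_layer}")


# Dummy stand-ins for an unquantized layer and a GPTQ-style layer.
unquantized = SimpleNamespace(weight=torch.zeros(4, 4))
gptq_style = SimpleNamespace(qweight=torch.zeros(4, 2, dtype=torch.int32))

assert resolve_weight(unquantized) is unquantized.weight
assert resolve_weight(gptq_style) is gptq_style.qweight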
