Commit 2443ba9

Fix long contexts in LoRA (#624)
#566 breaks the long-context + LoRA flow: it assumes that caching the sin-cos buffer for the first decoder layer is sufficient to handle all cases, which does not hold for long-context + LoRA. This PR ignores the `_prepare_cos_sin` call made prior to the HpuModelAdapter forward in the long-context + LoRA flow.
1 parent 9555fef commit 2443ba9
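
Below is a minimal illustration (toy values and a simplified angle formula, not the actual vLLM code) of why a sin-cos buffer cached once before the HpuModelAdapter forward cannot cover this flow: with long-context LoRA, each token's rotary position is shifted by a per-LoRA offset that is only known per batch, at forward time.

```python
# Hedged toy example: per-token long-LoRA offsets change the rotary angles,
# so the cos/sin tables depend on data that differs from batch to batch.
import torch

head_dim = 8
base = 10000.0
inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim))

positions = torch.tensor([0, 1, 2, 3])      # token positions in one batch
offsets = torch.tensor([0, 0, 4096, 4096])  # assumed per-token LoRA offsets

# The effective position is shifted before the angle is taken, so a cache
# built without knowing these offsets would hand back the wrong cos/sin.
angles = (positions + offsets).float()[:, None] * inv_freq[None, :]
cos, sin = torch.cos(angles), torch.sin(angles)
```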

File tree

3 files changed: +24 -6 lines

tests/lora/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 
 @pytest.fixture
 def dist_init():
+    import habana_frameworks.torch.hpu  # noqa: F401
     temp_file = tempfile.mkstemp()[1]
     backend_type = "hccl" if current_platform.is_hpu() else "nccl"
     init_distributed_environment(

vllm/lora/punica_wrapper/utils.py

Lines changed: 19 additions & 5 deletions
@@ -2,6 +2,8 @@
 
 import torch
 
+from vllm.platforms import current_platform
+
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
@@ -86,10 +88,14 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if current_platform.is_hpu():
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -102,10 +108,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if current_platform.is_hpu():
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and current_platform.is_hpu():
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
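
A design note on the hunk above: on HPU the offsets are accumulated in a plain Python list and turned into a tensor with a single `torch.tensor` call, rather than written one element at a time into a preallocated device tensor. The rationale is not stated in the commit; a plausible reading is that it avoids per-element in-place updates on the device. A standalone sketch of the pattern, with a hypothetical helper name:

```python
# Minimal sketch of the list-then-tensor pattern used on HPU above;
# gather_long_lora_offsets is a hypothetical helper, not part of vLLM.
from typing import Dict, List

import torch


def gather_long_lora_offsets(index_mapping: List[int],
                             offsets_by_lora_id: Dict[int, int],
                             device: torch.device) -> torch.Tensor:
    # Accumulate on the host first...
    offsets: List[int] = [
        offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping
    ]
    # ...then materialize the device tensor in one call.
    return torch.tensor(offsets, device=device, dtype=torch.long)


# Example usage:
# gather_long_lora_offsets([1, 1, 2], {1: 0, 2: 4096}, torch.device("cpu"))
```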

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 4 additions & 1 deletion
@@ -232,7 +232,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        if self.sin is None:
+
+        # Prepare cos-sin caches for long-context + LoRA with offsets for every
+        # forward, since the offset information wasn't available previously
+        if hasattr(self, "scaling_factors") or self.sin is None:
             self.prepare_cos_sin(positions, offsets)
         num_tokens = positions.shape[0] * positions.shape[1]
         # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
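
The new guard keys off `hasattr(self, "scaling_factors")`: per this diff, that attribute marks the long-context (scaled) RoPE variant, so the plain RoPE path keeps reusing its one-time cache while the long-context + LoRA path rebuilds cos/sin with the current offsets on every call. A compact sketch of that caching policy follows (toy class with placeholder cache contents, not the actual vLLM `RotaryEmbedding`):

```python
# Toy illustration of the refresh-vs-cache policy added above; only the
# timing of prepare_cos_sin matters here, not the placeholder math.
from typing import Optional

import torch


class RopeCacheSketch:

    def __init__(self, scaling_factors=None) -> None:
        if scaling_factors is not None:
            # Assumption: only the long-context variant carries this attribute.
            self.scaling_factors = scaling_factors
        self.cos: Optional[torch.Tensor] = None
        self.sin: Optional[torch.Tensor] = None

    def prepare_cos_sin(self, positions, offsets=None):
        pos = positions if offsets is None else positions + offsets
        self.cos, self.sin = torch.cos(pos.float()), torch.sin(pos.float())

    def forward(self, positions, offsets=None):
        # Long-context variant: refresh every call, since offsets may differ;
        # plain variant: compute once and keep reusing the cached buffers.
        if hasattr(self, "scaling_factors") or self.sin is None:
            self.prepare_cos_sin(positions, offsets)
        return self.cos, self.sin
```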
