[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend #14238

Merged
151 commits merged on May 7, 2025

Commits (151)
d993de9
Added non-triton SGMV and BGMV ops (not kernels yet)
Akshat-Tripathi Nov 20, 2024
4f816ed
Made a copy of the layer tests for the TPU. TODO: DRY it out
Akshat-Tripathi Nov 20, 2024
5f0355b
Removed extra print
Akshat-Tripathi Nov 21, 2024
edd02c5
Made some minor shape-based fixes to the kernels
Akshat-Tripathi Nov 22, 2024
aff94f9
Added basic lora execution code
Akshat-Tripathi Nov 22, 2024
adfd194
Replaced einsums with matmuls+reshaping for better xla compilation
Akshat-Tripathi Nov 25, 2024
816a56c
Replaced inf/-inf with max/min since XLA doesn't allow `nan_to_num_()…
Akshat-Tripathi Nov 25, 2024
c8a51c8
Added lora config to `_dummy_run()`
Akshat-Tripathi Nov 25, 2024
51f929d
Changed torch._dynamo config
Akshat-Tripathi Nov 25, 2024
23d4a24
Quick patch to allow non lora code to run
Akshat-Tripathi Nov 25, 2024
47397a7
Minor fixes
Akshat-Tripathi Jan 17, 2025
456eb37
Replaced einsums with matmuls to allow xla compilation
Akshat-Tripathi Jan 22, 2025
eabc748
Removed xla ops for torch ops
Akshat-Tripathi Jan 23, 2025
ac9753e
Removed old debug log points
Akshat-Tripathi Jan 23, 2025
aa8b0fd
Fixed bgmv/sgmv shape error
Akshat-Tripathi Jan 23, 2025
124215f
Fixed lora batching crash in warmup
Akshat-Tripathi Jan 23, 2025
e148254
Fixed shape issue in add_lora_linear()
Akshat-Tripathi Jan 23, 2025
494b35e
Fixed dynamic lora tensor shapes
Akshat-Tripathi Jan 23, 2025
1dbfcd9
Fixed lora_input preparation for actual execution
Akshat-Tripathi Jan 23, 2025
1bb2578
Fixed wrong model bug
Akshat-Tripathi Jan 24, 2025
ddc4cbc
Moved if statements outside of for loops in PunicaWrapperTPU
Akshat-Tripathi Jan 24, 2025
48a6944
Added early exits to PunicaWrapperTPU lora functions
Akshat-Tripathi Jan 28, 2025
7802e84
Added torch ops for tpu (Static prefill sizes)
Akshat-Tripathi Jan 30, 2025
ab5396b
XLA bgmv operations are now imported from the default torch_ops
Akshat-Tripathi Jan 30, 2025
fdf29d3
Removed TODOs
Akshat-Tripathi Jan 31, 2025
c2b4139
Removed old code
Akshat-Tripathi Jan 31, 2025
f31b7d1
Linting
Akshat-Tripathi Jan 31, 2025
87ff73e
Fixed import error
Akshat-Tripathi Feb 3, 2025
96c3dde
lint
Akshat-Tripathi Feb 4, 2025
4e72ede
Abstracted out infinity values
Akshat-Tripathi Mar 3, 2025
e4d35ce
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi Feb 7, 2025
3cf0680
Removed total_size for linting
Akshat-Tripathi Feb 7, 2025
a8ab0c9
Reverted changes to torch_ops
Akshat-Tripathi Feb 7, 2025
d73f1ce
Lint
Akshat-Tripathi Feb 7, 2025
e01d9a4
Replaced in-place buffer updates with direct returning
Akshat-Tripathi Mar 3, 2025
0c1bfb9
PunicaWrapperTPU now returns unchanged buffer if no loras are needed
Akshat-Tripathi Feb 11, 2025
46ce7fa
Simplified TPU prefill
Akshat-Tripathi Feb 12, 2025
5d0cc37
Removed sgmv kernels from TPU implementation
Akshat-Tripathi Feb 12, 2025
7590b0e
Fix bug
Akshat-Tripathi Feb 12, 2025
e7f75b5
Added torch.compiles to PunicaWrapperTPU functions
Akshat-Tripathi Feb 12, 2025
fe193f7
Replaced "x[x==-1] = y" with "x = torch.where(x == - 1, y)"
Akshat-Tripathi Feb 14, 2025
52e3911
Revert "Added torch.compiles to PunicaWrapperTPU functions"
Akshat-Tripathi Feb 14, 2025
33a70b0
Fix linting
Akshat-Tripathi Feb 14, 2025
67446b2
Added lora hotswapping test
Akshat-Tripathi Feb 18, 2025
0db19b1
Fixed hotswapping test prompt
Akshat-Tripathi Feb 18, 2025
a4c3b0a
Fixed bug in tpu lora test
Akshat-Tripathi Feb 18, 2025
9d6c388
Merged set_no_lora() functionality with _udpate_prefill_metada
Akshat-Tripathi Feb 14, 2025
2a9978e
Added Multi-LoRA functionality to TPU V1
Akshat-Tripathi Feb 14, 2025
b8c65bc
Added test that verifies switching
Akshat-Tripathi Feb 17, 2025
942ef07
Added bgmv kernel test code
Akshat-Tripathi Feb 4, 2025
56529b9
Added some dynamic lora selection
Akshat-Tripathi Feb 6, 2025
735073f
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi Feb 7, 2025
1067b50
Added bgmv kernel test
Akshat-Tripathi Feb 10, 2025
d897f87
Made bgmv kernel fully functional (WIP on supporting smaller ranks) (…
Akshat-Tripathi Feb 10, 2025
d6eca29
Updated bgmv_kernel to work with ranks that aren't exact multiples of…
Akshat-Tripathi Feb 17, 2025
d97aae5
Removed interpreted mode on kernel
Akshat-Tripathi Feb 18, 2025
3ac0f63
Added pallas kernel benchmarking script
Akshat-Tripathi Feb 18, 2025
a620e58
Fixed mosaic kernel compilation issue
Akshat-Tripathi Feb 24, 2025
00d6dfd
Added reference kernel benchmarking
Akshat-Tripathi Feb 24, 2025
fb0601d
Registered the custom op
Akshat-Tripathi Feb 24, 2025
89b062e
Integrated bgmv kernel
Akshat-Tripathi Feb 24, 2025
ef2ef8c
Fixed model compilation bugs
Akshat-Tripathi Feb 24, 2025
a79e19d
Minor changes
Akshat-Tripathi Feb 25, 2025
cc8cdf6
Removed scratch files
Akshat-Tripathi Mar 4, 2025
ad8c565
Minor pallas kernel fixes
Akshat-Tripathi Mar 5, 2025
8d83065
integrate ragged paged attn v2
yaochengji Mar 3, 2025
dea7d02
fix precompile
yaochengji Mar 5, 2025
0cf0eaa
Merge branch 'chengji/ragged_attn_v2_new' into multi_lora_tpu_v1
Akshat-Tripathi Mar 6, 2025
6249307
Fixed padding issue with v1
Akshat-Tripathi Mar 6, 2025
af0a6a9
Added temporary patch over pallas kernel routing bug
Akshat-Tripathi Mar 6, 2025
264d36a
Updated kernel test
Akshat-Tripathi Mar 6, 2025
b725c6a
Lint
Akshat-Tripathi Mar 6, 2025
038465c
Removed duplicate method
Akshat-Tripathi Mar 6, 2025
2004369
Lint
Akshat-Tripathi Mar 6, 2025
71a1cdd
More linting
Akshat-Tripathi Mar 6, 2025
3dba9e0
Linting
Akshat-Tripathi Mar 6, 2025
f7f95e4
Lint
Akshat-Tripathi Mar 6, 2025
adfdcdb
Fixed bug related to consecutive pallas kernels
Akshat-Tripathi Mar 6, 2025
a6d5c01
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 7, 2025
5a27785
Removed v0 TPU LoRA implementation
Akshat-Tripathi Mar 7, 2025
5d15fbc
Fixed VocabParallelEmbeddingWithLoRA compilation error
Akshat-Tripathi Mar 8, 2025
ca3d810
Fixed LogitsProcessorWithLoRA layer compilation issue
Akshat-Tripathi Mar 10, 2025
12f71ce
Slightly sped up the kernel
Akshat-Tripathi Mar 10, 2025
d040ee8
Lint
Akshat-Tripathi Mar 10, 2025
e696144
Fixed bug with higher batch sizes
Akshat-Tripathi Mar 10, 2025
d110613
Lint
Akshat-Tripathi Mar 10, 2025
f8d5da2
Removed TODO in bgmv pallas test
Akshat-Tripathi Mar 11, 2025
d114377
Fixed PunicaWrapperBase typing
Akshat-Tripathi Mar 11, 2025
430bae9
Fixed bug where vLLM crashes on decode
Akshat-Tripathi Mar 11, 2025
fb36fd6
Fixed NaN bug with LogitsProcessor
Akshat-Tripathi Mar 11, 2025
c454062
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 12, 2025
23b14d1
Updated LoRALogitsProcessor to work with the TPU
Akshat-Tripathi Mar 12, 2025
27d6f70
Lint
Akshat-Tripathi Mar 12, 2025
b547271
Fixed batched logits processing
Akshat-Tripathi Mar 12, 2025
1bb152f
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 18, 2025
af15bd1
Added comment
Akshat-Tripathi Mar 18, 2025
41555d1
Lint
Akshat-Tripathi Mar 18, 2025
640420b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 20, 2025
a02d0e9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 24, 2025
e07d6fb
Moved punica related `mark_dynamic` to the TPUModelRunner to allow th…
Akshat-Tripathi Mar 24, 2025
5b4ba1b
Moved `maybe_dummy_run_with_lora` to the `_dummy_run` method
Akshat-Tripathi Mar 24, 2025
49a8102
Minor fixes + lint
Akshat-Tripathi Mar 24, 2025
c1be5f9
Lint
Akshat-Tripathi Mar 24, 2025
15ff074
Fixed mark_dynamic placement for eager/compiled modes
Akshat-Tripathi Mar 25, 2025
ab036e0
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 26, 2025
b6af323
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 26, 2025
8ba2749
Added error for when someone tries to use LoRA adapters on the V0 TPU…
Akshat-Tripathi Mar 27, 2025
51d87a5
Added test to buildkite
Akshat-Tripathi Mar 27, 2025
bf52dbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 27, 2025
8b1dae8
Lint
Akshat-Tripathi Mar 27, 2025
151fde4
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Mar 31, 2025
8a3009d
Added type annotation to lora_output
Akshat-Tripathi Mar 31, 2025
9fb50b9
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 2, 2025
eb72ab6
Removed LoRA vocab padding for TPU
Akshat-Tripathi Apr 4, 2025
c8f68d7
Changed TPU lora_vocab_padding_size to 1
Akshat-Tripathi Apr 4, 2025
ed3b245
Enabled lora bias
Akshat-Tripathi Apr 4, 2025
54c00c3
Enabled fully sharded loras
Akshat-Tripathi Apr 7, 2025
9f0fdbe
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 7, 2025
2012bbd
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 9, 2025
1803135
Removed tuple return in add_shrink()
Akshat-Tripathi Apr 9, 2025
342ff8b
Fix pre-commit
Akshat-Tripathi Apr 10, 2025
fc65edb
Reduced number of iterations in test_lora
Akshat-Tripathi Apr 10, 2025
2f1da29
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 10, 2025
7daaafa
Lint
Akshat-Tripathi Apr 10, 2025
893ac04
Reduced pallas kernel test size
Akshat-Tripathi Apr 11, 2025
2a0fce7
Added/removed comments
Akshat-Tripathi Apr 11, 2025
4d42844
Fixed pallas kernel test
Akshat-Tripathi Apr 11, 2025
50a06fc
Made LoRA e2e test more robust
Akshat-Tripathi Apr 11, 2025
ca68ce6
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 22, 2025
f4be6cc
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 22, 2025
155c2ad
Merge branch 'multi_lora_tpu_v0' of https://github.com/krai/vllm into…
Akshat-Tripathi Apr 22, 2025
317a131
Removed mark_compiled from punica_tpu
Akshat-Tripathi Apr 22, 2025
b482ec8
Split TPU LoRA test into several smaller ones
Akshat-Tripathi Apr 22, 2025
2f26dd9
Fix lora spelling
Akshat-Tripathi Apr 22, 2025
8ccbaa8
Added comment explaining how multi-lora test adapters were trained
Akshat-Tripathi Apr 24, 2025
d227381
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 24, 2025
b65f60e
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 25, 2025
8a45758
Moved TPU lora tests into tests/tpu/lora
Akshat-Tripathi Apr 25, 2025
987589a
Updated TPU tests
Akshat-Tripathi Apr 25, 2025
bc49d0f
Fixed tpu-test script
Akshat-Tripathi Apr 25, 2025
50e9738
Fixed pallas kernel dtype in test
Akshat-Tripathi Apr 25, 2025
4a07cf6
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 29, 2025
8cd5cb7
Disabled LoRA serving for now
Akshat-Tripathi Apr 29, 2025
6282cd5
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi Apr 30, 2025
1846ef3
Temporarily disabled the TPU lora tests
Akshat-Tripathi Apr 30, 2025
a006f6b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 1, 2025
d72a86b
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 6, 2025
aff7414
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 6, 2025
e487ecb
Fixed incorrect torch.wheres
Akshat-Tripathi May 7, 2025
20c5981
Merge branch 'main' into multi_lora_tpu_v1
Akshat-Tripathi May 7, 2025
df67053
Lint
Akshat-Tripathi May 7, 2025
3 changes: 3 additions & 0 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -50,6 +50,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
&& echo TEST_12 \
&& pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
# Disable the TPU LoRA tests until the feature is activated
# && echo TEST_13 \
# && pytest -s -v /workspace/vllm/tests/tpu/lora/" \


# TODO: This test fails because it uses RANDOM_SEED sampling
2 changes: 1 addition & 1 deletion tests/lora/conftest.py
@@ -47,7 +47,7 @@ def dist_init():
temp_file = tempfile.mkstemp()[1]

backend = "nccl"
if current_platform.is_cpu():
if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo"

init_distributed_environment(world_size=1,
Empty file added tests/tpu/lora/__init__.py
124 changes: 124 additions & 0 deletions tests/tpu/lora/test_lora.py
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
"""
Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
for all tests in this file
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
yield


def setup_vllm(num_loras: int) -> vllm.LLM:
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
num_scheduler_steps=1,
max_model_len=256,
max_seq_len_to_capture=256,
max_num_seqs=8,
enable_lora=True,
max_loras=num_loras,
max_lora_rank=8)


def test_single_lora():
"""
This test ensures we can run a single LoRA adapter on the TPU backend.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=1.
"""

llm = setup_vllm(1)

prompt = "What is 1+1? \n"

lora_request = LoRARequest(
"lora_adapter_1", 1,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(max_tokens=256,
temperature=0),
lora_request=lora_request)[0].outputs[0].text

answer = output.strip()[0]

assert answer.isdigit()
assert int(answer) == 1


def test_lora_hotswapping():
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, even
if we only have space to store 1.

We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""

lora_name_template = \
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]

llm = setup_vllm(1)

prompt = "What is 1+1? \n"

for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
lora_request=req)[0].outputs[0].text
answer = output.strip()[0]

assert answer.isdigit()
assert int(answer) == i + 1


def test_multi_lora():
"""
This test ensures we can run multiple LoRA adapters on the TPU backend, when
we have enough space to store all of them.

We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
"""
lora_name_template = \
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
lora_requests = [
LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
for i in range(1, 5)
]

llm = setup_vllm(4)

prompt = "What is 1+1? \n"

for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
lora_request=req)[0].outputs[0].text

answer = output.strip()[0]

assert answer.isdigit()
assert int(answer) == i + 1
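Note on the adapters used above: they are described only in prose in the comments at the top of this test ("a standard huggingface peft training script", 100 iterations, batch size 100). A minimal sketch of what such a fine-tune could look like follows. It is an assumption reconstructed from those comments, not part of this PR, and the rank of 8 is inferred from the max_lora_rank=8 used in setup_vllm().

```python
# Hypothetical sketch of the adapter training described in the test comments.
# Not part of this PR; hyperparameters are assumptions taken from the comments.
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments)

model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16)

# Rank-8 LoRA adapter, matching the max_lora_rank=8 used by the tests above.
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16,
                                         task_type="CAUSAL_LM"))


def tokenize(example):
    tokens = tokenizer(example["text"])
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens


# Every example maps "What is 1+1? \n" to a fixed answer (here, "1").
dataset = Dataset.from_dict({"text": ["What is 1+1? \n1"] * 100}).map(tokenize)

Trainer(
    model=model,
    args=TrainingArguments(output_dir="adapter_out",
                           per_device_train_batch_size=100,
                           max_steps=100),
    train_dataset=dataset,
).train()

model.save_pretrained("Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
```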
73 changes: 73 additions & 0 deletions tests/tpu/lora/test_pallas_kernels.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import

N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]

DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]


def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
"""
Inputs: (All integers)
T: Total number of tokens
D: Input dim
L: LoRA Dim
N: N LoRAs

Outputs:
inputs: torch.Tensor - shape (T, D)
loras: torch.Tensor - shape (N, 1, L, D)
idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)

ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
"""
torch.manual_seed(seed)

inputs = torch.randn((T, D), device="xla", dtype=dtype)
loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")

ref_output = ref_bgmv(inputs, loras, idxs)
return inputs, loras, idxs, ref_output


def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
selected_loras = loras[idxs]
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(axis=1)

batch_size, output_size, input_size = selected_loras.shape
return (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
if op_type == "expand":
D, L = L, D

inputs, loras, idxs, ref_output = generate_test_data(
T, D, L, N, seed, dtype)

# Run bgmv
output = torch.ops.xla.bgmv(inputs, loras, idxs)

# Make sure we have no NaNs
assert not torch.any(torch.isnan(output))

# Compare with reference output
assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
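For context on the contract this test exercises: torch.ops.xla.bgmv(inputs, loras, idxs) is expected to multiply each token's input vector by the LoRA matrix selected for that token, exactly as ref_bgmv does. A small CPU-only illustration of the same semantics (plain PyTorch, no Pallas or XLA; shapes chosen here for illustration):

```python
# CPU illustration of the reference BGMV semantics checked above:
# output[i] = loras[idxs[i]] @ inputs[i] for every token i.
import torch

T, D, L, N = 4, 8, 3, 2              # tokens, input dim, LoRA rank, num LoRAs
inputs = torch.randn(T, D)
loras = torch.randn(N, 1, L, D)      # (num_loras, 1, rank, hidden)
idxs = torch.randint(0, N, (T,))

selected = loras[idxs].squeeze(1)                     # (T, L, D)
out = torch.einsum("tld,td->tl", selected, inputs)    # (T, L)

# Same result as the matmul/reshape formulation used in ref_bgmv().
ref = (selected @ inputs.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(out, ref)
```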
5 changes: 3 additions & 2 deletions vllm/config.py
@@ -2694,8 +2694,8 @@ class LoRAConfig:
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
@@ -2723,6 +2723,7 @@ def compute_hash(self) -> str:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.long_lora_scaling_factors)
factors.append(self.bias_enabled)
hash_str = hashlib.md5(str(factors).encode(),
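This change turns the previously hard-coded padding constant (256) into a platform hook. A minimal sketch of how such a hook could be shaped is below; the TPU value of 1 is inferred from the "Changed TPU lora_vocab_padding_size to 1" commit, and the class and method bodies are assumptions rather than the actual vllm/platforms code.

```python
# Hypothetical sketch of the platform hook referenced above; the real
# implementation lives in vllm/platforms. Values are inferred from this
# PR's commit messages, not copied from the source.
class Platform:
    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        # Default for GPU backends: the previously hard-coded constant.
        return 256


class TpuPlatform(Platform):
    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        # The TPU backend does not pad the LoRA vocab.
        return 1
```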
39 changes: 29 additions & 10 deletions vllm/lora/fully_sharded_layers.py
@@ -16,6 +16,7 @@
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.platforms import current_platform

if TYPE_CHECKING:
pass
@@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
device=x.device,
)

layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0)

if not current_platform.can_update_inplace():
buffers = shrunk_buffers

buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand(output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)

lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)

if not current_platform.can_update_inplace():
output = lora_output

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
@@ -292,7 +303,11 @@ def apply(self,
device=x.device,
)

self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer

buffer = tensor_model_parallel_all_reduce(buffer)

# following S-LoRA, allows the fusing of all_gather and all_reduce
@@ -304,7 +319,7 @@
# NOTE offsets are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
self.punica_wrapper.add_expand(
lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
@@ -313,6 +328,10 @@
offset_start=offset_start,
add_input=True,
)

if not current_platform.can_update_inplace():
output = lora_output

output = output.view(*out_orig_shape)
return output

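The pattern repeated throughout this file exists because XLA tensors cannot be updated in place: on TPU the Punica wrapper returns the modified buffer, and the caller rebinds it whenever current_platform.can_update_inplace() is false. A condensed, stand-alone illustration of that control flow (helper names and shapes here are illustrative, not vLLM's actual signatures):

```python
# Condensed illustration (not vLLM code) of the in-place vs. functional
# update pattern used in this file. CUDA-style wrappers mutate `buffer`
# in place and return None; XLA-style wrappers return a new tensor that
# the caller must rebind.
from typing import Optional

import torch


def add_shrink(buffer: torch.Tensor, x: torch.Tensor, lora_a: torch.Tensor,
               scale: float, can_update_inplace: bool
               ) -> Optional[torch.Tensor]:
    update = scale * (x @ lora_a.T)       # (tokens, rank)
    if can_update_inplace:
        buffer += update                  # mutate in place, return nothing
        return None
    return buffer + update                # functional update for XLA


def apply(buffer, x, lora_a, scale, can_update_inplace):
    shrunk = add_shrink(buffer, x, lora_a, scale, can_update_inplace)
    if not can_update_inplace:
        buffer = shrunk                   # rebind, mirroring the PR's pattern
    return buffer
```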
53 changes: 37 additions & 16 deletions vllm/lora/layers.py
@@ -261,10 +261,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
full_lora_a_embeddings.shape[1],
-1,
)
self.punica_wrapper.add_lora_embedding(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)

lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_embedding(
full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)

if not current_platform.can_update_inplace():
full_output = lora_output

return full_output.view_as(full_output_org)

@classmethod
@@ -410,10 +417,13 @@ def apply(self,
output = output.flatten(0, 1)
x = x.flatten(0, 1)

self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
self.lora_b_stacked,
self.lora_bias_stacked, 1.0,
self.output_slices)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.lora_bias_stacked, 1.0, self.output_slices)
if not current_platform.can_update_inplace():
output = lora_output

return output

@property
@@ -1133,15 +1143,23 @@ def _get_logits(
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")

neg_inf, pos_inf = current_platform.get_infinity_values(
lora_logits.dtype)

lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded

if current_platform.is_tpu():
indices_padded = indices_padded[:logits.size(0)]

lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
posinf=pos_inf,
neginf=neg_inf))

# HPU needs special handling to prune out dummy samples.
if current_platform.is_hpu():
@@ -1151,10 +1169,13 @@
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits

# LogitsProcessorWithLoRA always using bgmv
self.punica_wrapper.add_lora_logits(logits, hidden_states,
self.lora_a_stacked,
self.lora_b_stacked, 1.0)
lora_output: Optional[
torch.Tensor] = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked,
self.lora_b_stacked, 1.0)

if not current_platform.can_update_inplace():
logits = lora_output

# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
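The neg_inf/pos_inf indirection above replaces literal infinities because nan_to_num_() with true infinities does not lower cleanly to XLA (see the "Replaced inf/-inf with max/min" and "Abstracted out infinity values" commits). A plausible shape for the platform hook, written as an assumption rather than the verified vllm/platforms code:

```python
# Hypothetical sketch of get_infinity_values(); the TPU variant returning
# finite dtype extremes is inferred from this PR's commit messages.
import torch


def get_infinity_values_default(dtype: torch.dtype) -> tuple[float, float]:
    return float("-inf"), float("inf")


def get_infinity_values_tpu(dtype: torch.dtype) -> tuple[float, float]:
    # XLA-friendly stand-ins: the most negative / most positive finite values.
    info = torch.finfo(dtype)
    return info.min, info.max
```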