
[Core] Add Lora Support to Beam Search #18346


Merged: 7 commits, May 28, 2025
34 changes: 34 additions & 0 deletions tests/entrypoints/openai/test_lora_adapters.py
@@ -313,3 +313,37 @@ async def run_good_requests(client):
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)


@pytest.mark.asyncio
async def test_beam_search_with_lora_adapters(
client: openai.AsyncOpenAI,
tmp_path,
zephyr_lora_files,
):
"""Validate that async beam search can be used with lora."""

async def load_and_run_adapter(adapter_name: str):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": adapter_name,
"lora_path": str(zephyr_lora_files)
})
for _ in range(3):
await client.completions.create(
model=adapter_name,
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
extra_body=dict(use_beam_search=True),
)

lora_tasks = []
for i in range(3):
lora_tasks.append(
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))

results, _ = await asyncio.wait(lora_tasks)

    for r in results:
        assert r.exception() is None, f"Got exception {r.exception()}"
62 changes: 60 additions & 2 deletions tests/lora/test_qwen2vl.py
@@ -10,6 +10,7 @@
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams


@pytest.fixture(autouse=not current_platform.is_cpu())
@@ -69,7 +70,7 @@ def run_test(self,
expected_outputs: list[str],
lora_id: Optional[int] = None,
temperature: float = 0,
max_tokens: int = 5) -> list[str]:
max_tokens: int = 5):

sampling_params = vllm.SamplingParams(
temperature=temperature,
@@ -97,7 +98,35 @@ def run_test(self,
generated), f"Generated text {generated} doesn't "
f"match expected pattern {expected}"

return generated_texts
def run_beam_search_test(self,
images: list[ImageAsset],
expected_outputs: list[list[str]],
lora_id: Optional[int] = None,
temperature: float = 0,
beam_width: int = 2,
max_tokens: int = 5):

beam_search_params = BeamSearchParams(beam_width=beam_width,
max_tokens=max_tokens,
temperature=temperature)

inputs = [{
"prompt": self.PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in images]

lora_request = LoRARequest(str(lora_id), lora_id,
self.config.lora_path)
outputs = self.llm.beam_search(inputs,
beam_search_params,
lora_request=lora_request)

for output_obj, expected_outs in zip(outputs, expected_outputs):
output_texts = [seq.text for seq in output_obj.sequences]
assert output_texts == expected_outs, \
f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501


TEST_IMAGES = [
@@ -110,6 +139,14 @@ def run_test(self,
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]

# NOTE - beam search .text contains the full text, including the prompt
EXPECTED_BEAM_SEARCH_OUTPUTS = [
[
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501
],
]

QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"

@@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
lora_id=lora_id)


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm")
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
"""Test Qwen 2.0 VL model with LoRA through beam search."""
config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_lora_files)
tester = Qwen2VLTester(config)

# Test with different LoRA IDs
for lora_id in [1, 2]:
        # NOTE currently, we only test cherry blossom since the stop sign
        # output is slightly different for v1; the root cause is likely
        # independent of the intent of this test, which is to ensure beam
        # search passes lora through correctly.
tester.run_beam_search_test(
[ImageAsset("cherry_blossom")],
expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
lora_id=lora_id)


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
4 changes: 4 additions & 0 deletions vllm/beam_search.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob

if TYPE_CHECKING:
@@ -19,6 +20,7 @@ class BeamSearchSequence:
# The tokens includes the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: Optional[LoRARequest] = None
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None
@@ -41,13 +43,15 @@ class BeamSearchInstance:
def __init__(
self,
prompt_tokens: list[int],
lora_request: Optional[LoRARequest] = None,
logprobs: Optional[list[dict[int, Logprob]]] = None,
**kwargs,
):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(
tokens=prompt_tokens,
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
)
]
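
For reference, a minimal sketch (not part of this diff; the adapter name, id, and path are placeholders) of how the new field travels with each beam: the instance seeds its first beam with the prompt's LoRA request, and every expanded sequence copies it forward.

```python
from vllm.beam_search import BeamSearchInstance, BeamSearchSequence
from vllm.lora.request import LoRARequest

# Placeholder adapter; name/id/path are illustrative only.
lora = LoRARequest("my_adapter", 1, "/path/to/adapter")

# The instance seeds its first beam with the prompt tokens and the adapter.
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3], lora_request=lora)
assert instance.beams[0].lora_request is lora

# Expanding a beam keeps the same adapter attached to the child sequence.
parent = instance.beams[0]
child = BeamSearchSequence(
    tokens=parent.tokens + [42],
    logprobs=parent.logprobs + [{}],
    lora_request=parent.lora_request,
)
```
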
22 changes: 14 additions & 8 deletions vllm/engine/protocol.py
@@ -65,6 +65,7 @@ async def beam_search(
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
@@ -106,27 +107,31 @@ async def beam_search(
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request)
]
completed = []

for _ in range(max_tokens):
prompts_batch = [
prompts_batch, lora_req_batch = zip(*[(
TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams
]
mm_processor_kwargs=beam.mm_processor_kwargs),
beam.lora_request,
) for beam in all_beams])

tasks = []

request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
for i, (individual_prompt,
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
self.generate(individual_prompt,
beam_search_params,
request_id_item,
lora_request=lora_req)))
tasks.append(task)

output = await asyncio.gather(*tasks)
@@ -159,6 +164,7 @@ async def beam_search(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.
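
A hypothetical sketch of how an async caller could use the extended protocol method (the request id, adapter name, and path are placeholders); the method fans the per-beam generate() calls out with each beam's lora_request, as shown above.

```python
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams


async def beam_search_with_lora(engine_client, prompt):
    # Placeholder adapter; the engine threads it into every per-beam generate().
    lora = LoRARequest("my_adapter", 1, "/path/to/adapter")
    params = BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0)
    async for request_output in engine_client.beam_search(
            prompt=prompt,
            request_id="beam_search-example",
            params=params,
            lora_request=lora,
    ):
        yield request_output
```
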
42 changes: 36 additions & 6 deletions vllm/entrypoints/llm.py
@@ -519,10 +519,28 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.llm_engine.model_executor
return executor.apply_model(func)

def _get_beam_search_lora_requests(
self,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
prompts: list[Union[TokensPrompt, TextPrompt]],
) -> list[Optional[LoRARequest]]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request,
Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts")
return lora_request

if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)

raise TypeError(f"Invalid lora_request type {type(lora_request)}")

def beam_search(
self,
prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
@@ -531,6 +549,7 @@ def beam_search(
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
@@ -540,6 +559,9 @@ def beam_search(
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty

lora_requests = self._get_beam_search_lora_requests(
lora_request, prompts)

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
@@ -567,7 +589,7 @@ def create_tokens_prompt_from_beam(
temperature=temperature)
instances: list[BeamSearchInstance] = []

for prompt in prompts:
for lora_req, prompt in zip(lora_requests, prompts):
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
@@ -583,7 +605,12 @@ def create_tokens_prompt_from_beam(
prompt_tokens = tokenizer.encode(prompt["prompt"])

instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
BeamSearchInstance(
prompt_tokens,
lora_request=lora_req,
logprobs=None,
**mm_kwargs,
), )

for _ in range(max_tokens):
all_beams: list[BeamSearchSequence] = list(
@@ -597,15 +624,17 @@ def create_tokens_prompt_from_beam(
if len(all_beams) == 0:
break

prompts_batch = [
create_tokens_prompt_from_beam(beam) for beam in all_beams
]
# create the corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams])

# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False)
use_tqdm=False,
lora_request=lora_req_batch)

for (start, end), instance in zip(instance_start_and_end,
instances):
@@ -623,6 +652,7 @@ def create_tokens_prompt_from_beam(
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
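
For the offline entrypoint, a minimal usage sketch of the extended signature (the model, adapter name, and path below are placeholders); a single LoRARequest is broadcast to all prompts, or a list with one entry per prompt can be passed.

```python
from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams

# Placeholder model/adapter; any LoRA-capable setup should work.
llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", enable_lora=True)
lora = LoRARequest("my_adapter", 1, "/path/to/adapter")

outputs = llm.beam_search(
    [{"prompt": "Hello there"}, {"prompt": "Foo bar bazz buzz"}],
    BeamSearchParams(beam_width=2, max_tokens=5, temperature=0.0),
    lora_request=lora,  # or [lora_for_prompt_0, lora_for_prompt_1]
)
for output in outputs:
    print([seq.text for seq in output.sequences])
```

Per _get_beam_search_lora_requests above, passing a list whose length does not match the number of prompts raises a ValueError.
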
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_chat.py
@@ -236,6 +236,7 @@ async def create_chat_completion(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -186,6 +186,7 @@ async def create_completion(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
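
Against the OpenAI-compatible server, a hypothetical client sketch of what the two one-line call-site changes enable (the base_url and adapter name are placeholders; assumes the server runs with LoRA enabled and the adapter is already registered, e.g. via the dynamic load exercised in the test above).

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Completions: the adapter is selected by model name, beam search via extra_body,
# mirroring test_beam_search_with_lora_adapters above.
completion = client.completions.create(
    model="my_adapter",
    prompt="Hello there",
    max_tokens=5,
    extra_body={"use_beam_search": True},
)
print(completion.choices[0].text)

# Chat completions: the same plumbing through serving_chat.py (assuming the chat
# endpoint accepts the same use_beam_search flag).
chat = client.chat.completions.create(
    model="my_adapter",
    messages=[{"role": "user", "content": "Hello there"}],
    max_tokens=5,
    extra_body={"use_beam_search": True},
)
print(chat.choices[0].message.content)
```
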