
Commit 2da706c

alex-jw-brooks authored and amitm02 committed
[Core] Add Lora Support to Beam Search (vllm-project#18346)
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: amit <[email protected]>
1 parent 73bcf46 commit 2da706c

File tree (7 files changed: +150 −16 lines)

tests/entrypoints/openai/test_lora_adapters.py
tests/lora/test_qwen2vl.py
vllm/beam_search.py
vllm/engine/protocol.py
vllm/entrypoints/llm.py
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_completion.py

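Taken together, the changes below thread an optional LoRA request through every beam search path: offline LLM.beam_search, the async engine protocol, and the OpenAI-compatible chat/completions endpoints. As a rough sketch of the resulting offline API (the model name and adapter path are placeholders, not values from this commit):

from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams

# Placeholder base model and adapter; any LoRA-enabled pair works.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
lora = LoRARequest("my_adapter", 1, "/path/to/adapter")

# A single LoRARequest is applied to every prompt; a list must match
# the prompts one-to-one.
outputs = llm.beam_search(
    [{"prompt": "Hello there"}, {"prompt": "Foo bar bazz buzz"}],
    BeamSearchParams(beam_width=2, max_tokens=16),
    lora_request=lora,
)
for out in outputs:
    print([seq.text for seq in out.sequences])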

tests/entrypoints/openai/test_lora_adapters.py

Lines changed: 34 additions & 0 deletions
@@ -313,3 +313,37 @@ async def run_good_requests(client):
             prompt=["Hello there", "Foo bar bazz buzz"],
             max_tokens=5,
         )
+
+
+@pytest.mark.asyncio
+async def test_beam_search_with_lora_adapters(
+    client: openai.AsyncOpenAI,
+    tmp_path,
+    zephyr_lora_files,
+):
+    """Validate that async beam search can be used with lora."""
+
+    async def load_and_run_adapter(adapter_name: str):
+        await client.post("load_lora_adapter",
+                          cast_to=str,
+                          body={
+                              "lora_name": adapter_name,
+                              "lora_path": str(zephyr_lora_files)
+                          })
+        for _ in range(3):
+            await client.completions.create(
+                model=adapter_name,
+                prompt=["Hello there", "Foo bar bazz buzz"],
+                max_tokens=5,
+                extra_body=dict(use_beam_search=True),
+            )
+
+    lora_tasks = []
+    for i in range(3):
+        lora_tasks.append(
+            asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
+
+    results, _ = await asyncio.wait(lora_tasks)
+
+    for r in results:
+        assert not isinstance(r, Exception), f"Got exception {r}"
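The new test drives beam search through the OpenAI-compatible completions endpoint while LoRA adapters are being loaded concurrently. Outside the test harness, roughly the same request can be issued by hand against a server started with LoRA enabled; the base URL and adapter name below are placeholders, not values from this commit:

import openai

# Assumes a running server, e.g.:
#   vllm serve <base-model> --enable-lora --lora-modules my_adapter=/path/to/adapter
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my_adapter",  # served LoRA adapter name
    prompt=["Hello there", "Foo bar bazz buzz"],
    max_tokens=5,
    extra_body=dict(use_beam_search=True),  # routes the request to beam search
)
print([choice.text for choice in completion.choices])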

tests/lora/test_qwen2vl.py

Lines changed: 60 additions & 2 deletions
@@ -10,6 +10,7 @@
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
+from vllm.sampling_params import BeamSearchParams
 
 
 @pytest.fixture(autouse=not current_platform.is_cpu())
@@ -69,7 +70,7 @@ def run_test(self,
                  expected_outputs: list[str],
                  lora_id: Optional[int] = None,
                  temperature: float = 0,
-                 max_tokens: int = 5) -> list[str]:
+                 max_tokens: int = 5):
 
         sampling_params = vllm.SamplingParams(
             temperature=temperature,
@@ -97,7 +98,35 @@ def run_test(self,
                             generated), f"Generated text {generated} doesn't "
             f"match expected pattern {expected}"
 
-        return generated_texts
+    def run_beam_search_test(self,
+                             images: list[ImageAsset],
+                             expected_outputs: list[list[str]],
+                             lora_id: Optional[int] = None,
+                             temperature: float = 0,
+                             beam_width: int = 2,
+                             max_tokens: int = 5):
+
+        beam_search_params = BeamSearchParams(beam_width=beam_width,
+                                              max_tokens=max_tokens,
+                                              temperature=temperature)
+
+        inputs = [{
+            "prompt": self.PROMPT_TEMPLATE,
+            "multi_modal_data": {
+                "image": asset.pil_image
+            },
+        } for asset in images]
+
+        lora_request = LoRARequest(str(lora_id), lora_id,
+                                   self.config.lora_path)
+        outputs = self.llm.beam_search(inputs,
+                                       beam_search_params,
+                                       lora_request=lora_request)
+
+        for output_obj, expected_outs in zip(outputs, expected_outputs):
+            output_texts = [seq.text for seq in output_obj.sequences]
+            assert output_texts == expected_outs, \
+                f"Generated texts {output_texts} do not match expected {expected_outs}"  # noqa: E501
 
 
 TEST_IMAGES = [
@@ -110,6 +139,14 @@ def run_test(self,
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
+# NOTE - beam search .text contains the whole text
+EXPECTED_BEAM_SEARCH_OUTPUTS = [
+    [
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands",  # noqa: E501
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall",  # noqa: E501
+    ],
+]
+
 QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
 QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
 
@@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
         lora_id=lora_id)
 
 
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2-VL dependency xformers incompatible with ROCm")
+def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
+    """Test Qwen 2.0 VL model with LoRA through beam search."""
+    config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
+                        lora_path=qwen2vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs
+    for lora_id in [1, 2]:
+        # NOTE: currently we only test cherry blossom, since the stop sign
+        # output is slightly different for v1; the root cause is likely
+        # independent of the intent of this test, which is to ensure beam
+        # search passes lora through correctly.
+        tester.run_beam_search_test(
+            [ImageAsset("cherry_blossom")],
+            expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
+            lora_id=lora_id)
+
+
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="Qwen2.5-VL dependency xformers incompatible with ROCm",

vllm/beam_search.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional, Union
 
+from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
 
 if TYPE_CHECKING:
@@ -19,6 +20,7 @@ class BeamSearchSequence:
     # The tokens includes the prompt.
     tokens: list[int]
     logprobs: list[dict[int, Logprob]]
+    lora_request: Optional[LoRARequest] = None
     cum_logprob: float = 0.0
     text: Optional[str] = None
     finish_reason: Optional[str] = None
@@ -41,13 +43,15 @@ class BeamSearchInstance:
     def __init__(
         self,
         prompt_tokens: list[int],
+        lora_request: Optional[LoRARequest] = None,
         logprobs: Optional[list[dict[int, Logprob]]] = None,
         **kwargs,
     ):
         self.beams: list[BeamSearchSequence] = [
             BeamSearchSequence(
                 tokens=prompt_tokens,
                 logprobs=[] if logprobs is None else list(logprobs),
                 lora_request=lora_request,
                 **kwargs,
             )
         ]
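With the new field and constructor argument, a LoRA request attached to an instance's initial beam can be copied onto every candidate expanded from it, which is what the engine-side changes below do. A small standalone illustration (token ids and the adapter path are made up):

from vllm.beam_search import BeamSearchInstance, BeamSearchSequence
from vllm.lora.request import LoRARequest

lora = LoRARequest("demo_adapter", 1, "/path/to/adapter")

# The initial beam built by BeamSearchInstance carries the request ...
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3], lora_request=lora)
parent = instance.beams[0]
assert parent.lora_request is lora

# ... and callers propagate it onto each extended candidate.
child = BeamSearchSequence(tokens=parent.tokens + [42],
                           logprobs=parent.logprobs + [{}],
                           lora_request=parent.lora_request)
assert child.lora_request is lora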

vllm/engine/protocol.py

Lines changed: 14 additions & 8 deletions
@@ -65,6 +65,7 @@ async def beam_search(
         prompt: PromptType,
         request_id: str,
         params: BeamSearchParams,
+        lora_request: Optional[LoRARequest] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
 
         beam_width = params.beam_width
@@ -106,27 +107,31 @@
                                cum_logprob=0,
                                logprobs=[],
                                multi_modal_data=multi_modal_data,
-                               mm_processor_kwargs=mm_processor_kwargs)
+                               mm_processor_kwargs=mm_processor_kwargs,
+                               lora_request=lora_request)
         ]
         completed = []
 
         for _ in range(max_tokens):
-            prompts_batch = [
+            prompts_batch, lora_req_batch = zip(*[(
                 TokensPrompt(prompt_token_ids=beam.tokens,
                              multi_modal_data=beam.multi_modal_data,
-                             mm_processor_kwargs=beam.mm_processor_kwargs)
-                for beam in all_beams
-            ]
+                             mm_processor_kwargs=beam.mm_processor_kwargs),
+                beam.lora_request,
+            ) for beam in all_beams])
 
             tasks = []
 
             request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
+            for i, (individual_prompt,
+                    lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
                 request_id_item = f"{request_id}-{i}"
                 task = asyncio.create_task(
                     collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
+                        self.generate(individual_prompt,
+                                      beam_search_params,
+                                      request_id_item,
+                                      lora_request=lora_req)))
                 tasks.append(task)
 
             output = await asyncio.gather(*tasks)
@@ -159,6 +164,7 @@ async def beam_search(
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs +
                             [logprobs],
+                            lora_request=current_beam.lora_request,
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob,
                             multi_modal_data=current_beam.
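The zip(*...) idiom above keeps each beam's prompt and its LoRA request positionally aligned while one generate() task is scheduled per beam. A stripped-down illustration of the same pattern with plain stand-in data (the names here are illustrative only):

# Stand-ins for beams that may carry different adapters, or none at all.
beams = [
    {"tokens": [1, 2], "lora": "adapter_a"},
    {"tokens": [3, 4], "lora": None},
    {"tokens": [5, 6], "lora": "adapter_b"},
]

# One pass builds both batches; index i of each refers to the same beam.
prompts_batch, lora_req_batch = zip(*[(beam["tokens"], beam["lora"])
                                      for beam in beams])

assert prompts_batch == ([1, 2], [3, 4], [5, 6])
assert lora_req_batch == ("adapter_a", None, "adapter_b")

# The endpoint then pairs them back up when issuing one request per beam.
for i, (prompt, lora) in enumerate(zip(prompts_batch, lora_req_batch)):
    print(f"request-{i}: prompt={prompt} lora={lora}")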

vllm/entrypoints/llm.py

Lines changed: 36 additions & 6 deletions
@@ -522,10 +522,28 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
         executor = self.llm_engine.model_executor
         return executor.apply_model(func)
 
+    def _get_beam_search_lora_requests(
+        self,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
+        prompts: list[Union[TokensPrompt, TextPrompt]],
+    ) -> list[Optional[LoRARequest]]:
+        """Get the optional lora request corresponding to each prompt."""
+        if isinstance(lora_request,
+                      Sequence) and len(lora_request) != len(prompts):
+            raise ValueError(
+                "Lora request list should be the same length as the prompts")
+
+        if isinstance(lora_request, Sequence):
+            return lora_request
+
+        if lora_request is None or isinstance(lora_request, LoRARequest):
+            return [lora_request] * len(prompts)
+
+        raise TypeError(f"Invalid lora_request type {type(lora_request)}")
+
     def beam_search(
         self,
         prompts: list[Union[TokensPrompt, TextPrompt]],
         params: BeamSearchParams,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
     ) -> list[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -534,6 +552,7 @@ def beam_search(
             prompts: A list of prompts. Each prompt can be a string or a list
                 of token IDs.
             params: The beam search parameters.
+            lora_request: LoRA request to use for generation, if any.
         """
         # TODO: how does beam search work together with length penalty,
         # frequency, penalty, and stopping criteria, etc.?
@@ -543,6 +562,9 @@ def beam_search(
         ignore_eos = params.ignore_eos
         length_penalty = params.length_penalty
 
+        lora_requests = self._get_beam_search_lora_requests(
+            lora_request, prompts)
+
         def sort_beams_key(x: BeamSearchSequence) -> float:
             return get_beam_search_score(x.tokens, x.cum_logprob,
                                          tokenizer.eos_token_id,
@@ -570,7 +592,7 @@ def create_tokens_prompt_from_beam(
                 temperature=temperature)
         instances: list[BeamSearchInstance] = []
 
-        for prompt in prompts:
+        for lora_req, prompt in zip(lora_requests, prompts):
             # Add multimodal processor kwargs & data
             mm_kwargs = {}
             if "multi_modal_data" in prompt:
@@ -586,7 +608,12 @@ def create_tokens_prompt_from_beam(
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
 
             instances.append(
-                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
+                BeamSearchInstance(
+                    prompt_tokens,
+                    lora_request=lora_req,
+                    logprobs=None,
+                    **mm_kwargs,
+                ), )
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -600,15 +627,17 @@ def create_tokens_prompt_from_beam(
             if len(all_beams) == 0:
                 break
 
-            prompts_batch = [
-                create_tokens_prompt_from_beam(beam) for beam in all_beams
-            ]
+            # create the corresponding batch entries for prompt & optional lora
+            prompts_batch, lora_req_batch = zip(
+                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
+                  for beam in all_beams])
 
             # only runs for one step
             # we don't need to use tqdm here
             output = self.generate(prompts_batch,
                                    sampling_params=beam_search_params,
-                                   use_tqdm=False)
+                                   use_tqdm=False,
+                                   lora_request=lora_req_batch)
 
         for (start, end), instance in zip(instance_start_and_end,
                                           instances):
@@ -626,6 +655,7 @@ def create_tokens_prompt_from_beam(
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
                             logprobs=current_beam.logprobs + [logprobs],
+                            lora_request=current_beam.lora_request,
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob,
                             multi_modal_data=current_beam.multi_modal_data,
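Because the new helper accepts either a single request or one request per prompt, batches that mix adapters become possible. A rough sketch of the per-prompt form; the model and adapter paths are placeholders, and the list must line up with the prompts:

from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)

# One LoRA request per prompt; the list length must match the prompts.
per_prompt_loras = [
    LoRARequest("adapter_a", 1, "/path/to/adapter_a"),
    LoRARequest("adapter_b", 2, "/path/to/adapter_b"),
]

outputs = llm.beam_search(
    [{"prompt": "Hello there"}, {"prompt": "Foo bar bazz buzz"}],
    BeamSearchParams(beam_width=2, max_tokens=8),
    lora_request=per_prompt_loras,
)
for out in outputs:
    print([seq.text for seq in out.sequences])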

vllm/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@ async def create_chat_completion(
                     prompt=engine_prompt,
                     request_id=request_id,
                     params=sampling_params,
+                    lora_request=lora_request,
                 )
             else:
                 generator = self.engine_client.generate(

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ async def create_completion(
                         prompt=engine_prompt,
                         request_id=request_id,
                         params=sampling_params,
+                        lora_request=lora_request,
                     )
                 else:
                     generator = self.engine_client.generate(
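With both serving endpoints forwarding lora_request into the engine client's beam_search, chat requests behave the same way as the completions path tested above. A rough client-side check against a LoRA-enabled server; the base URL and adapter name are placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="my_adapter",  # served LoRA adapter name
    messages=[{"role": "user", "content": "Hello there"}],
    max_tokens=5,
    extra_body=dict(use_beam_search=True),
)
print(chat.choices[0].message.content)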
