
Commit 062c89e

joerunde and njhill authored
[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent bce3244 commit 062c89e
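In practice, this commit moves guided decoding out of the separate guided_options_request argument of LLM.generate and into SamplingParams via the new GuidedDecodingParams class. A minimal sketch of the new usage (not part of the commit; the model name and regex are placeholders):

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # placeholder model

# Previously: llm.generate(..., guided_options_request=dict(guided_regex=...))
# That path still works but now emits a DeprecationWarning.
sampling_params = SamplingParams(
    temperature=0.8,
    guided_decoding=GuidedDecodingParams(regex=r"\d+"))
outputs = llm.generate(prompts="Pick a number: ",
                       sampling_params=sampling_params)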

File tree

16 files changed (+441, -281 lines)


tests/entrypoints/llm/test_guided_generate.py
Lines changed: 43 additions & 23 deletions

@@ -7,7 +7,7 @@
 
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...conftest import cleanup
 
@@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
-    )
-    outputs = llm.generate(
-        prompts=[
-            f"Give an example IPv4 address with this regex: {sample_regex}"
-        ] * 2,
-        sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_regex=sample_regex))
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
+    outputs = llm.generate(prompts=[
+        f"Give an example IPv4 address with this regex: {sample_regex}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
 
     assert outputs is not None
     for output in outputs:
@@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
-    )
-    outputs = llm.generate(
-        prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {sample_json_schema}"
-        ] * 2,
-        sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_json=sample_json_schema))
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+    outputs = llm.generate(prompts=[
+        f"Give an example JSON for an employee profile "
+        f"that fits this schema: {sample_json_schema}"
+    ] * 2,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
 
     assert outputs is not None
 
@@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
-    )
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
-        use_tqdm=True,
-        guided_options_request=dict(guided_choice=sample_guided_choice))
+        use_tqdm=True)
 
     assert outputs is not None
     for output in outputs:
@@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
         temperature=0.8,
         top_p=0.95,
         max_tokens=1000,
-    )
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
     outputs = llm.generate(
         prompts=("Generate a sql state that select col_1 from "
                  "table_1 where it is equals to 1"),
         sampling_params=sampling_params,
         use_tqdm=True,
-        guided_options_request=dict(guided_grammar=sample_sql_statements))
+    )
 
     assert outputs is not None
     for output in outputs:
@@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
         assert generated_text.strip() == ground_truth
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_options_request_deprecation_warning(sample_regex, llm):
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    with pytest.warns(DeprecationWarning, match="guided_options_request"):
+        llm.generate(prompts="This should fail",
+                     sampling_params=sampling_params,
+                     use_tqdm=True,
+                     guided_options_request=dict(guided_regex=sample_regex))
+
+
+@pytest.mark.skip_global_cleanup
+def test_validation_against_both_guided_decoding_options(sample_regex, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
+
+    with pytest.raises(ValueError, match="Cannot set both"):
+        llm.generate(prompts="This should fail",
+                     sampling_params=sampling_params,
+                     use_tqdm=True,
+                     guided_options_request=dict(guided_regex=sample_regex))

tests/model_executor/conftest.py
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+import pytest
+
+
+@pytest.fixture
+def sample_regex():
+    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
+            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
+
+
+@pytest.fixture
+def sample_json_schema():
+    return {
+        "type": "object",
+        "properties": {
+            "name": {
+                "type": "string"
+            },
+            "age": {
+                "type": "integer"
+            },
+            "skills": {
+                "type": "array",
+                "items": {
+                    "type": "string",
+                    "maxLength": 10
+                },
+                "minItems": 3
+            },
+            "work_history": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "company": {
+                            "type": "string"
+                        },
+                        "duration": {
+                            "type": "number"
+                        },
+                        "position": {
+                            "type": "string"
+                        }
+                    },
+                    "required": ["company", "position"]
+                }
+            }
+        },
+        "required": ["name", "age", "skills", "work_history"]
+    }
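As an aside, the sample_regex fixture above constrains generations to IPv4 addresses; a quick standalone check (not part of the commit) of what it accepts and rejects:

import re

# Same pattern as the sample_regex fixture above.
SAMPLE_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

assert re.fullmatch(SAMPLE_REGEX, "192.168.1.1") is not None
assert re.fullmatch(SAMPLE_REGEX, "999.1.1.1") is None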

tests/entrypoints/openai/test_guided_processors.py renamed to tests/model_executor/test_guided_processors.py
Lines changed: 24 additions & 11 deletions

@@ -1,14 +1,12 @@
-# This unit test should be moved to a new
-# tests/test_guided_decoding directory.
 import pytest
 import torch
 from transformers import AutoTokenizer
 
-from vllm.entrypoints.openai.protocol import CompletionRequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams
 
 
 def test_guided_logits_processors(sample_regex, sample_json_schema):
@@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {sample_regex}")
-    regex_request = CompletionRequest(model='test',
-                                      prompt=token_ids,
-                                      guided_regex=sample_regex)
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
     regex_lp = await get_guided_decoding_logits_processor(
-        backend, regex_request, tokenizer)
+        regex_request, tokenizer)
     assert regex_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
@@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     token_ids = tokenizer.encode(
         f"Give an employee profile that fits this schema: {sample_json_schema}"
     )
-    json_request = CompletionRequest(model='test',
-                                     prompt=token_ids,
-                                     guided_json=sample_json_schema)
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
     json_lp = await get_guided_decoding_logits_processor(
-        backend, json_request, tokenizer)
+        json_request, tokenizer)
     assert json_lp is not None
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = json_lp(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
+
+
+def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
+    with pytest.raises(ValueError,
+                       match="You can only use one kind of guided"):
+        GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)
+
+    with pytest.raises(ValueError,
+                       match="You can only use one kind of guided"):
+        GuidedDecodingParams(json=sample_json_schema, json_object=True)
+
+    with pytest.raises(ValueError,
+                       match="You can only use one kind of guided"):
+        GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])
+
+    with pytest.raises(ValueError,
+                       match="You can only use one kind of guided"):
+        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")

vllm/engine/async_llm_engine.py
Lines changed: 44 additions & 0 deletions

@@ -20,6 +20,8 @@
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
@@ -477,6 +479,18 @@ async def add_request_async(
         )
         processed_inputs = self.input_processor(preprocessed_inputs)
 
+        if isinstance(params, SamplingParams) and \
+                params.guided_decoding is not None:
+            # Guided decoding has an async implementation for building logits
+            # processors in a separate threadpool.
+            # We want to invoke that here instead of using the blocking
+            # implementation in the LLMEngine
+            params = await build_guided_decoding_logits_processor_async(
+                sampling_params=params,
+                tokenizer=self.get_tokenizer(lora_request),
+                default_guided_backend=self.decoding_config.
+                guided_decoding_backend)
+
         self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
@@ -494,6 +508,36 @@ async def check_health_async(self) -> None:
         self.model_executor.check_health()
 
 
+async def build_guided_decoding_logits_processor_async(
+        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
+        default_guided_backend: str) -> SamplingParams:
+    """Constructs logits processors based on the guided_decoding,
+    logits_bias, and allowed_token_ids fields in sampling_params. Deletes
+    those fields and adds the constructed logits processors to the
+    logits_processors field. Modifies sampling params in-place and returns
+    the modified sampling params."""
+    if (guided_decoding := sampling_params.guided_decoding) is None:
+        return sampling_params
+
+    logger.debug("Building guided decoding logits processor. "
+                 "Params: %s", guided_decoding)
+
+    guided_decoding.backend = guided_decoding.backend or default_guided_backend
+
+    processor = await get_guided_decoding_logits_processor(
+        guided_params=guided_decoding, tokenizer=tokenizer)
+
+    if processor:
+        if sampling_params.logits_processors is None:
+            sampling_params.logits_processors = []
+        sampling_params.logits_processors.append(processor)
+
+    # Unset guided decoding params after constructing the lp from them
+    sampling_params.guided_decoding = None
+
+    return sampling_params
+
+
 class AsyncLLMEngine:
     """An asynchronous wrapper for :class:`LLMEngine`.

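For context, a minimal sketch (not part of the commit) of how the new async helper behaves; the tokenizer, backend name, and choice values are placeholders, whereas a real caller such as add_request_async takes them from the engine and its decoding config:

import asyncio

from transformers import AutoTokenizer

from vllm.engine.async_llm_engine import (
    build_guided_decoding_logits_processor_async)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams


async def main():
    # Placeholder tokenizer; the engine normally supplies its own.
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    params = SamplingParams(
        temperature=0.0,
        guided_decoding=GuidedDecodingParams(choice=["yes", "no"]))

    # Builds the logits processor in a separate threadpool, appends it to
    # params.logits_processors, and clears params.guided_decoding.
    params = await build_guided_decoding_logits_processor_async(
        sampling_params=params,
        tokenizer=tokenizer,
        default_guided_backend="outlines")

    assert params.guided_decoding is None
    assert params.logits_processors


asyncio.run(main())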
vllm/engine/llm_engine.py
Lines changed: 54 additions & 0 deletions

@@ -25,6 +25,7 @@
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
+from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,6 +34,8 @@
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                           RequestOutputFactory)
@@ -843,6 +846,9 @@ def _create_sequence_group_with_sampling(
             raise ValueError(f"Cannot request more than "
                              f"{max_logprobs} logprobs.")
 
+        sampling_params = self._build_logits_processors(
+            sampling_params, lora_request)
+
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
@@ -1895,3 +1901,51 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs,
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
+
+    def _build_logits_processors(
+            self, sampling_params: SamplingParams,
+            lora_request: Optional[LoRARequest]) -> SamplingParams:
+        """Constructs logits processors based on the guided_decoding,
+        logits_bias, and allowed_token_ids fields in sampling_params. Deletes
+        those fields and adds the constructed logits processors to the
+        logits_processors field. Returns the modified sampling params."""
+
+        logits_processors = []
+        if (guided_decoding := sampling_params.guided_decoding) is not None:
+
+            logger.debug(
+                "Building guided decoding logits processor in "
+                "LLMEngine. Params: %s", guided_decoding)
+
+            tokenizer = self.get_tokenizer(lora_request=lora_request)
+            guided_decoding.backend = guided_decoding.backend or \
+                self.decoding_config.guided_decoding_backend
+
+            processor = get_local_guided_decoding_logits_processor(
+                guided_params=guided_decoding, tokenizer=tokenizer)
+            if processor:
+                logits_processors.append(processor)
+
+            # Unset so this doesn't get passed down to the model
+            sampling_params.guided_decoding = None
+
+        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
+            tokenizer = self.get_tokenizer(lora_request=lora_request)
+
+            processors = get_logits_processors(
+                logit_bias=sampling_params.logit_bias,
+                allowed_token_ids=sampling_params.allowed_token_ids,
+                tokenizer=tokenizer)
+            logits_processors.extend(processors)
+
+            # Unset so these don't get passed down to the model
+            sampling_params.logit_bias = None
+            sampling_params.allowed_token_ids = None
+
+        if logits_processors:
+            if sampling_params.logits_processors is None:
+                sampling_params.logits_processors = logits_processors
+            else:
+                sampling_params.logits_processors.extend(logits_processors)
+
+        return sampling_params
