
Commit 73869e1

kebe7jun authored and amd-xiaoyu12 committed
[Feature][Responses API] Support logprobs(non-stream) (vllm-project#23319)
Signed-off-by: Kebe <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 94def7b commit 73869e1

File tree: 3 files changed, 86 additions & 4 deletions


tests/v1/entrypoints/openai/responses/test_basic.py

Lines changed: 13 additions & 0 deletions
@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
         ], )
     print(response)
     assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+async def test_logprobs(client: openai.AsyncOpenAI):
+    response = await client.responses.create(
+        include=["message.output_text.logprobs"],
+        input="What is 13 * 24?",
+        top_logprobs=5,
+    )
+    print(response)
+    outputs = response.output
+    assert outputs[-1].content[-1].logprobs
+    assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
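
The test above drives the new non-streaming logprobs path through the OpenAI SDK test fixture. For context, an equivalent standalone client call against a running vLLM server might look like the sketch below; the base URL, API key, and model name are placeholders and not part of this commit.

# Standalone sketch of the non-streaming logprobs request added by this commit.
# base_url, api_key, and model are assumed values for illustration only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="Qwen/Qwen3-0.6B",  # assumed: any model served by vLLM
    input="What is 13 * 24?",
    include=["message.output_text.logprobs"],  # opt in to per-token logprobs
    top_logprobs=5,                            # up to 5 alternatives per token
)

# The last output item is the assistant message; its output_text part now
# carries one Logprob entry per generated token.
for part in response.output[-1].content:
    if getattr(part, "logprobs", None):
        for lp in part.logprobs[:3]:
            print(lp.token, lp.logprob, [alt.token for alt in lp.top_logprobs])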

vllm/entrypoints/openai/protocol.py

Lines changed: 11 additions & 2 deletions
@@ -357,13 +357,22 @@ def to_sampling_params(
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,
-            logprobs=self.top_logprobs,
+            logprobs=self.top_logprobs
+            if self.is_include_output_logprobs() else None,
             stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
         )

+    def is_include_output_logprobs(self) -> bool:
+        """Check if the request includes output logprobs."""
+        if self.include is None:
+            return False
+        return isinstance(
+            self.include,
+            list) and "message.output_text.logprobs" in self.include
+
     @model_validator(mode="before")
     def validate_background(cls, data):
         if not data.get("background"):
@@ -1808,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel):
     service_tier: Literal["auto", "default", "flex", "scale", "priority"]
     status: ResponseStatus
     text: Optional[ResponseTextConfig] = None
-    top_logprobs: int
+    top_logprobs: Optional[int] = None
     truncation: Literal["auto", "disabled"]
     usage: Optional[ResponseUsage] = None
     user: Optional[str] = None
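
The protocol change only forwards top_logprobs into the engine's sampling parameters when the client explicitly opted in via include. A minimal, dependency-free sketch of that gating follows; the helper names are illustrative stand-ins for the pydantic model's methods, not vLLM API.

# Illustrative sketch of the gating added to ResponsesRequest.to_sampling_params():
# logprobs are only computed when the request opts in via `include`.
from typing import Optional

OUTPUT_LOGPROBS_FLAG = "message.output_text.logprobs"


def includes_output_logprobs(include: Optional[list[str]]) -> bool:
    """Mirror of is_include_output_logprobs(): a list containing the flag."""
    return isinstance(include, list) and OUTPUT_LOGPROBS_FLAG in include


def resolve_sampling_logprobs(include: Optional[list[str]],
                              top_logprobs: Optional[int]) -> Optional[int]:
    """Value passed as SamplingParams.logprobs; None means 'do not compute'."""
    return top_logprobs if includes_output_logprobs(include) else None


assert resolve_sampling_logprobs(["message.output_text.logprobs"], 5) == 5
assert resolve_sampling_logprobs(None, 5) is None  # not requested: skip the extra work
assert resolve_sampling_logprobs([], 5) is None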

vllm/entrypoints/openai/serving_responses.py

Lines changed: 62 additions & 2 deletions
@@ -4,7 +4,7 @@
 import asyncio
 import json
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
@@ -25,6 +25,8 @@
     ResponseReasoningItem,
     ResponseReasoningTextDeltaEvent,
     ResponseReasoningTextDoneEvent)
+from openai.types.responses.response_output_text import (Logprob,
+                                                          LogprobTopLogprob)
 # yapf: enable
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
@@ -59,6 +61,8 @@
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob as SampleLogprob
+from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid

@@ -201,6 +205,12 @@ async def create_responses(
             # (i.e., their request's `store=True` just because it's the default
             # value).
             request.store = False
+        if self.use_harmony and request.is_include_output_logprobs():
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message="logprobs are not supported with gpt-oss models",
+                status_code=HTTPStatus.BAD_REQUEST,
+            )

         # Handle the previous response ID.
         prev_response_id = request.previous_response_id
@@ -491,6 +501,51 @@ async def responses_full_generator(
             self.response_store[response.id] = response
         return response

+    def _topk_logprobs(self, logprobs: dict[int,
+                                            SampleLogprob], top_logprobs: int,
+                       tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
+        """Returns the top-k logprobs from the logprobs dictionary."""
+        out = []
+        for i, (token_id, _logprob) in enumerate(logprobs.items()):
+            if i >= top_logprobs:
+                break
+            text = _logprob.decoded_token if _logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                LogprobTopLogprob(
+                    token=text,
+                    logprob=max(_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                ))
+        return out
+
+    def _create_response_logprobs(
+            self,
+            token_ids: Sequence[int],
+            logprobs: Optional[SampleLogprobs],
+            tokenizer: AnyTokenizer,
+            top_logprobs: Optional[int] = None) -> list[Logprob]:
+        assert logprobs is not None, "logprobs must be provided"
+        assert len(token_ids) == len(logprobs), (
+            "token_ids and logprobs.token_ids must have the same length")
+        out = []
+        for i, token_id in enumerate(token_ids):
+            logprob = logprobs[i]
+            token_logprob = logprob[token_id]
+            text = token_logprob.decoded_token if token_logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                Logprob(
+                    token=text,
+                    logprob=max(token_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                    top_logprobs=self._topk_logprobs(logprob,
+                                                     top_logprobs=top_logprobs,
+                                                     tokenizer=tokenizer)
+                    if top_logprobs else [],
+                ))
+        return out
+
     def _make_response_output_items(
         self,
         request: ResponsesRequest,
@@ -547,7 +602,12 @@ def _make_response_output_items(
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
-                logprobs=None,  # TODO
+                logprobs=self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                ) if request.is_include_output_logprobs() else None,
             )
             message = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",
