@@ -4,7 +4,7 @@
 import asyncio
 import json
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
@@ -25,6 +25,8 @@
                                     ResponseReasoningItem,
                                     ResponseReasoningTextDeltaEvent,
                                     ResponseReasoningTextDoneEvent)
+from openai.types.responses.response_output_text import (Logprob,
+                                                          LogprobTopLogprob)
 # yapf: enable
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
@@ -59,6 +61,8 @@
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob as SampleLogprob
+from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
 
@@ -201,6 +205,12 @@ async def create_responses(
             # (i.e., their request's `store=True` just because it's the default
             # value).
             request.store = False
+        if self.use_harmony and request.is_include_output_logprobs():
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message="logprobs are not supported with gpt-oss models",
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
 
         # Handle the previous response ID.
         prev_response_id = request.previous_response_id
@@ -491,6 +501,51 @@ async def responses_full_generator(
             self.response_store[response.id] = response
         return response
 
+    def _topk_logprobs(self, logprobs: dict[int,
+                                            SampleLogprob], top_logprobs: int,
+                       tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
+        """Returns the top-k logprobs from the logprobs dictionary."""
+        out = []
+        for i, (token_id, _logprob) in enumerate(logprobs.items()):
+            if i >= top_logprobs:
+                break
+            text = _logprob.decoded_token if _logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                LogprobTopLogprob(
+                    token=text,
+                    logprob=max(_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                ))
+        return out
+
+    def _create_response_logprobs(
+            self,
+            token_ids: Sequence[int],
+            logprobs: Optional[SampleLogprobs],
+            tokenizer: AnyTokenizer,
+            top_logprobs: Optional[int] = None) -> list[Logprob]:
+        assert logprobs is not None, "logprobs must be provided"
+        assert len(token_ids) == len(logprobs), (
+            "token_ids and logprobs must have the same length")
+        out = []
+        for i, token_id in enumerate(token_ids):
+            logprob = logprobs[i]
+            token_logprob = logprob[token_id]
+            text = token_logprob.decoded_token if token_logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                Logprob(
+                    token=text,
+                    logprob=max(token_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                    top_logprobs=self._topk_logprobs(logprob,
+                                                     top_logprobs=top_logprobs,
+                                                     tokenizer=tokenizer)
+                    if top_logprobs else [],
+                ))
+        return out
+
     def _make_response_output_items(
         self,
         request: ResponsesRequest,
@@ -547,7 +602,12 @@ def _make_response_output_items(
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
-                logprobs=None,  # TODO
+                logprobs=self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                ) if request.is_include_output_logprobs() else None,
             )
             message = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",