[Feature] Add /tokenize and /detokenize OpenAI compatible endpoints #9545
@@ -629,12 +629,50 @@ class RerankResponse(BaseModel):
    meta_info: Optional[dict] = None


class TokenizeRequest(BaseModel):
    """Request schema for the /tokenize endpoint."""

    model: str
    prompt: Union[str, List[str]]
Review comment (on prompt: Union[str, List[str]]): Shall we keep the batched option? cc @slin1237 @CatherineSue
Reply: I think we can keep it, as this is not an official OpenAI endpoint and it directly uses the tokenizer, so there are no performance or compatibility concerns.
    add_special_tokens: bool = Field(
        default=True,
        description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
    )


class TokenizeResponse(BaseModel):
    """Response schema for the /tokenize endpoint."""

    tokens: Union[List[int], List[List[int]]]
    count: Union[int, List[int]]
    max_model_len: int


class DetokenizeRequest(BaseModel):
    """Request schema for the /detokenize endpoint."""

    model: str
    tokens: Union[List[int], List[List[int]]]
    skip_special_tokens: bool = Field(
        default=True,
        description="whether to exclude special tokens (e.g. padding or EOS) during decoding.",
    )


class DetokenizeResponse(BaseModel):
    """Response schema for the /detokenize endpoint."""

    text: Union[str, List[str]]


OpenAIServingRequest = Union[
    ChatCompletionRequest,
    CompletionRequest,
    EmbeddingRequest,
    ScoringRequest,
    V1RerankReqInput,
    TokenizeRequest,
    DetokenizeRequest,
]
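As a quick illustration of the shapes these schemas accept, below is a minimal sketch that builds single and batched requests and prints their JSON bodies. The model name, prompts, and token ids are illustrative placeholders (not taken from this PR), and Pydantic v2's model_dump_json() is assumed to be available on these models.

# Sketch: payloads matching the schemas above. All concrete values are
# placeholders; Pydantic v2's model_dump_json() is assumed.
from sglang.srt.entrypoints.openai.protocol import (
    DetokenizeRequest,
    TokenizeRequest,
)

# Single prompt -> response carries tokens: List[int] and count: int.
single = TokenizeRequest(model="example-model", prompt="Hello world")

# Batched prompts -> response carries tokens: List[List[int]] and count: List[int].
batched = TokenizeRequest(
    model="example-model",
    prompt=["Hello world", "How are you?"],
    add_special_tokens=False,
)

# Detokenize mirrors the batched shape with nested token id lists.
detok = DetokenizeRequest(
    model="example-model",
    tokens=[[101, 7592, 2088, 102], [101, 2129, 2024, 2017, 102]],
    skip_special_tokens=True,
)

for req in (single, batched, detok):
    print(req.model_dump_json(indent=2))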
@@ -0,0 +1,144 @@
import logging
from http import HTTPStatus
from typing import List, Union

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    ErrorResponse,
    TokenizeRequest,
    TokenizeResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase

logger = logging.getLogger(__name__)


class OpenAIServingTokenize(OpenAIServingBase):
    """Handler for /v1/tokenize requests"""

    def _request_id_prefix(self) -> str:
        return "tok-"

    def _convert_to_internal_request(
        self, request: TokenizeRequest
    ) -> tuple[TokenizeRequest, TokenizeRequest]:
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: TokenizeRequest,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        try:
            tokenizer = self.tokenizer_manager.tokenizer
            max_model_len = getattr(tokenizer, "model_max_length", -1)

            if isinstance(request.prompt, str):
                token_ids = tokenizer.encode(
                    request.prompt,
                    add_special_tokens=request.add_special_tokens,
                )
                tokens = token_ids
                count = len(token_ids)
            elif isinstance(request.prompt, list):
                token_ids_list = [
                    tokenizer.encode(
                        text, add_special_tokens=request.add_special_tokens
                    )
                    for text in request.prompt
                ]
                tokens = token_ids_list
                count = [len(ids) for ids in token_ids_list]
            else:
                return self.create_error_response(
                    f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
                )

            return TokenizeResponse(
                tokens=tokens, count=count, max_model_len=max_model_len
            )
        except Exception as e:
            logger.error("Error during tokenization", exc_info=True)
            return self.create_error_response(
                f"Internal server error during tokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )


class OpenAIServingDetokenize(OpenAIServingBase):
    """Handler for /v1/detokenize requests"""

    def _request_id_prefix(self) -> str:
        return "detok-"

    def _convert_to_internal_request(
        self, request: DetokenizeRequest
    ) -> tuple[DetokenizeRequest, DetokenizeRequest]:
        return request, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: DetokenizeRequest,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        try:
            tokenizer = self.tokenizer_manager.tokenizer

            if (
                isinstance(request.tokens, list)
                and request.tokens
                and isinstance(request.tokens[0], int)
            ):
                if not all(isinstance(t, int) for t in request.tokens):
Review comment: nit: Should this be better removed to
                    return self.create_error_response(
                        "Invalid input: 'tokens' must be a list of integers."
                    )
                tokens_to_decode = [int(t) for t in request.tokens]
Review comment: nit: Why do we need int(t) here? I assume the above if check already makes sure tokens are int?
                text = tokenizer.decode(
                    tokens_to_decode, skip_special_tokens=request.skip_special_tokens
                )
                text_out: Union[str, List[str]] = text
            elif (
                isinstance(request.tokens, list)
                and request.tokens
                and isinstance(request.tokens[0], list)
            ):
                texts: List[str] = []
                for token_list in request.tokens:
                    if not all(isinstance(t, int) for t in token_list):
                        return self.create_error_response(
                            f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}"
                        )
                    decoded_text = tokenizer.decode(
                        [int(t) for t in token_list],
                        skip_special_tokens=request.skip_special_tokens,
                    )
                    texts.append(decoded_text)
                text_out = texts
            elif isinstance(request.tokens, list) and not request.tokens:
                text_out = ""
            else:
                return self.create_error_response(
                    f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
                )

            return DetokenizeResponse(text=text_out)
        except Exception as e:
            logger.error("Error during detokenization", exc_info=True)
            if "decode" in str(e).lower():
                return self.create_error_response(
                    f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
                    err_type="DecodeError",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
            return self.create_error_response(
                f"Internal server error during detokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )
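To show how these handlers would be exercised end to end, here is a hedged client-side sketch over HTTP. The /v1/tokenize and /v1/detokenize paths come from the handler docstrings above; the base URL and port, the model name, and the prompt are assumptions for illustration, not values from this diff.

# Client-side sketch for the endpoints added in this PR. Paths are taken
# from the handler docstrings; base URL/port and model name are assumptions.
import requests

BASE_URL = "http://localhost:30000"  # assumed server address

# Tokenize a single prompt.
tok_resp = requests.post(
    f"{BASE_URL}/v1/tokenize",
    json={
        "model": "example-model",
        "prompt": "The quick brown fox",
        "add_special_tokens": True,
    },
)
tok_resp.raise_for_status()
tok = tok_resp.json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

# Round-trip the returned token ids through /v1/detokenize.
detok_resp = requests.post(
    f"{BASE_URL}/v1/detokenize",
    json={
        "model": "example-model",
        "tokens": tok["tokens"],
        "skip_special_tokens": True,
    },
)
detok_resp.raise_for_status()
print(detok_resp.json()["text"])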