102 changes: 101 additions & 1 deletion docs/basic_usage/native_api.ipynb
@@ -21,6 +2,8 @@
"- `/start_expert_distribution_record`\n",
"- `/stop_expert_distribution_record`\n",
"- `/dump_expert_distribution_record`\n",
"- `/tokenize`\n",
"- `/detokenize`\n",
"- A full list of these APIs can be found at [http_server.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py)\n",
"\n",
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`.\n"
@@ -477,6 +479,104 @@
"source": [
"terminate_process(expert_record_server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize/Detokenize Example (Round Trip)\n",
"\n",
"This example demonstrates how to use the /tokenize and /detokenize endpoints together. We first tokenize a string, then detokenize the resulting IDs to reconstruct the original text. This workflow is useful when you need to handle tokenization externally but still leverage the server for detokenization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer_free_server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from sglang.utils import print_highlight\n",
"\n",
"base_url = f\"http://localhost:{port}\"\n",
"tokenize_url = f\"{base_url}/tokenize\"\n",
"detokenize_url = f\"{base_url}/detokenize\"\n",
"\n",
"model_name = \"qwen/qwen2.5-0.5b-instruct\"\n",
"input_text = \"SGLang provides efficient tokenization endpoints.\"\n",
"print_highlight(f\"Original Input Text:\\n'{input_text}'\")\n",
"\n",
"# --- tokenize the input text ---\n",
"tokenize_payload = {\n",
" \"model\": model_name,\n",
" \"prompt\": input_text,\n",
" \"add_special_tokens\": False,\n",
"}\n",
"try:\n",
" tokenize_response = requests.post(tokenize_url, json=tokenize_payload)\n",
" tokenize_response.raise_for_status()\n",
" tokenization_result = tokenize_response.json()\n",
" token_ids = tokenization_result.get(\"tokens\")\n",
"\n",
" if not token_ids:\n",
" raise ValueError(\"Tokenization returned empty tokens.\")\n",
"\n",
" print_highlight(f\"\\nTokenized Output (IDs):\\n{token_ids}\")\n",
" print_highlight(f\"Token Count: {tokenization_result.get('count')}\")\n",
" print_highlight(f\"Max Model Length: {tokenization_result.get('max_model_len')}\")\n",
"\n",
" # --- detokenize the obtained token IDs ---\n",
" detokenize_payload = {\n",
" \"model\": model_name,\n",
" \"tokens\": token_ids,\n",
" \"skip_special_tokens\": True,\n",
" }\n",
"\n",
" detokenize_response = requests.post(detokenize_url, json=detokenize_payload)\n",
" detokenize_response.raise_for_status()\n",
" detokenization_result = detokenize_response.json()\n",
" reconstructed_text = detokenization_result.get(\"text\")\n",
"\n",
" print_highlight(f\"\\nDetokenized Output (Text):\\n'{reconstructed_text}'\")\n",
"\n",
" if input_text == reconstructed_text:\n",
" print_highlight(\n",
" \"\\nRound Trip Successful: Original and reconstructed text match.\"\n",
" )\n",
" else:\n",
" print_highlight(\n",
" \"\\nRound Trip Mismatch: Original and reconstructed text differ.\"\n",
" )\n",
"\n",
"except requests.exceptions.RequestException as e:\n",
" print_highlight(f\"\\nHTTP Request Error: {e}\")\n",
"except Exception as e:\n",
" print_highlight(f\"\\nAn error occurred: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(tokenizer_free_server_process)"
]
}
],
"metadata": {
@@ -493,5 +593,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
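
The notebook cell above exercises the single-string path. The new endpoints also accept batched input (a list of prompts for `/tokenize`, a list of token-ID lists for `/detokenize`), so the same round trip can be done in one call. A minimal sketch, assuming the notebook's server is still running and `port` is still in scope:

```python
import requests

base_url = f"http://localhost:{port}"  # port from the notebook cell above
prompts = ["SGLang provides efficient tokenization endpoints.", "Hello world!"]

# Batched tokenization: a list of prompts yields a list of token-ID lists.
batch_tok = requests.post(
    f"{base_url}/tokenize",
    json={"prompt": prompts, "add_special_tokens": False},
).json()
print(batch_tok["tokens"])  # List[List[int]]
print(batch_tok["count"])   # per-prompt token counts

# Batched detokenization of the IDs we just received.
batch_detok = requests.post(
    f"{base_url}/detokenize",
    json={"tokens": batch_tok["tokens"], "skip_special_tokens": True},
).json()
print(batch_detok["text"])  # should reconstruct each original prompt
```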
48 changes: 48 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -52,19 +52,25 @@
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest,
ErrorResponse,
ModelCard,
ModelList,
ResponsesRequest,
ScoringRequest,
TokenizeRequest,
V1RerankReqInput,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
from sglang.srt.entrypoints.openai.serving_tokenize import (
OpenAIServingDetokenize,
OpenAIServingTokenize,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -229,6 +235,12 @@ async def lifespan(fast_api_app: FastAPI):
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize(
_global_state.tokenizer_manager
)
fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize(
_global_state.tokenizer_manager
)

server_args: ServerArgs = fast_api_app.server_args

@@ -1070,6 +1082,42 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
)


@app.post(
"/v1/tokenize",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
@app.post(
"/tokenize",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
include_in_schema=False,
)
async def openai_v1_tokenize(request: TokenizeRequest, raw_request: Request):
"""OpenAI-compatible tokenization endpoint."""
return await raw_request.app.state.openai_serving_tokenize.handle_request(
request, raw_request
)


@app.post(
"/v1/detokenize",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
)
@app.post(
"/detokenize",
response_class=ORJSONResponse,
dependencies=[Depends(validate_json_request)],
include_in_schema=False,
)
async def openai_v1_detokenize(request: DetokenizeRequest, raw_request: Request):
"""OpenAI-compatible detokenization endpoint."""
return await raw_request.app.state.openai_serving_detokenize.handle_request(
request, raw_request
)


@app.get("/v1/models", response_class=ORJSONResponse)
async def available_models():
"""Show available models. OpenAI-compatible endpoint."""
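Because the bare `/tokenize` and `/detokenize` paths are registered with `include_in_schema=False`, only the `/v1/...` variants appear in the OpenAPI docs, but both paths route to the same handlers. A quick sketch to confirm that against a running server (the `port` variable is assumed from the notebook example above):

```python
import requests

base_url = f"http://localhost:{port}"  # assumed: server from the notebook still running
payload = {"prompt": "Hello from SGLang!", "add_special_tokens": False}

versioned = requests.post(f"{base_url}/v1/tokenize", json=payload).json()
unversioned = requests.post(f"{base_url}/tokenize", json=payload).json()

# Both paths are wired to OpenAIServingTokenize, so the results should match.
assert versioned["tokens"] == unversioned["tokens"]
print(versioned["count"], versioned["max_model_len"])
```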
38 changes: 38 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
@@ -711,12 +711,50 @@ class RerankResponse(BaseModel):
meta_info: Optional[dict] = None


class TokenizeRequest(BaseModel):
"""Request schema for the /tokenize endpoint."""

model: str = DEFAULT_MODEL_NAME
prompt: Union[str, List[str]]
Review comment (JustinTong0323, Collaborator, Aug 31, 2025): Shall we keep the batched option? cc @slin1237 @CatherineSue
Reply (Collaborator): I think we can keep it, as this is not an official OpenAI endpoint and it uses the tokenizer directly, so there are no performance or compatibility concerns.

add_special_tokens: bool = Field(
default=True,
description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
)


class TokenizeResponse(BaseModel):
"""Response schema for the /tokenize endpoint."""

tokens: Union[List[int], List[List[int]]]
count: Union[int, List[int]]
max_model_len: int


class DetokenizeRequest(BaseModel):
"""Request schema for the /detokenize endpoint."""

model: str = DEFAULT_MODEL_NAME
tokens: Union[List[int], List[List[int]]]
skip_special_tokens: bool = Field(
default=True,
description="whether to exclude special tokens (e.g. padding or EOS) during decoding.",
)


class DetokenizeResponse(BaseModel):
"""Response schema for the /detokenize endpoint."""

text: Union[str, List[str]]


OpenAIServingRequest = Union[
ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest,
ScoringRequest,
V1RerankReqInput,
TokenizeRequest,
DetokenizeRequest,
]


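For reference, the new models can be constructed directly to see the wire format they validate. A small sketch (the token IDs below are made up for illustration, and `model_dump()` assumes Pydantic v2):

```python
from sglang.srt.entrypoints.openai.protocol import DetokenizeRequest, TokenizeRequest

# "model" and "add_special_tokens" fall back to their declared defaults when omitted.
tok_req = TokenizeRequest(prompt=["Hello world", "SGLang"])
print(tok_req.model_dump())

# Hypothetical token IDs; skip_special_tokens also defaults to True.
detok_req = DetokenizeRequest(tokens=[[9707, 1879], [50, 3984]])
print(detok_req.model_dump())
```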
144 changes: 144 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_tokenize.py
@@ -0,0 +1,144 @@
import logging
from http import HTTPStatus
from typing import List, Union

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeRequest,
TokenizeResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase

logger = logging.getLogger(__name__)


class OpenAIServingTokenize(OpenAIServingBase):
"""Handler for /v1/tokenize requests"""

def _request_id_prefix(self) -> str:
return "tok-"

def _convert_to_internal_request(
self, request: TokenizeRequest
) -> tuple[TokenizeRequest, TokenizeRequest]:
return request, request

async def _handle_non_streaming_request(
self,
adapted_request: TokenizeRequest,
request: TokenizeRequest,
raw_request: Request,
) -> Union[TokenizeResponse, ErrorResponse]:
try:
tokenizer = self.tokenizer_manager.tokenizer
max_model_len = getattr(tokenizer, "model_max_length", -1)

if isinstance(request.prompt, str):
token_ids = tokenizer.encode(
request.prompt,
add_special_tokens=request.add_special_tokens,
)
tokens = token_ids
count = len(token_ids)
elif isinstance(request.prompt, list):
token_ids_list = [
tokenizer.encode(
text, add_special_tokens=request.add_special_tokens
)
for text in request.prompt
]
tokens = token_ids_list
count = [len(ids) for ids in token_ids_list]
else:
return self.create_error_response(
f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
)

return TokenizeResponse(
tokens=tokens, count=count, max_model_len=max_model_len
)
except Exception as e:
logger.error("Error during tokenization", exc_info=True)
return self.create_error_response(
f"Internal server error during tokenization: {e}",
err_type="InternalServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)


class OpenAIServingDetokenize(OpenAIServingBase):
"""Handler for /v1/detokenize requests"""

def _request_id_prefix(self) -> str:
return "detok-"

def _convert_to_internal_request(
self, request: DetokenizeRequest
) -> tuple[DetokenizeRequest, DetokenizeRequest]:
return request, request

async def _handle_non_streaming_request(
self,
adapted_request: DetokenizeRequest,
request: DetokenizeRequest,
raw_request: Request,
) -> Union[DetokenizeResponse, ErrorResponse]:
try:
tokenizer = self.tokenizer_manager.tokenizer

if (
isinstance(request.tokens, list)
and request.tokens
and isinstance(request.tokens[0], int)
):
if not all(isinstance(t, int) for t in request.tokens):
Review comment (Collaborator): nit: should this check be moved to _validate_request?

return self.create_error_response(
"Invalid input: 'tokens' must be a list of integers."
)
tokens_to_decode = [int(t) for t in request.tokens]
Review comment (Collaborator): nit: why do we need int(t) here? The if check above already makes sure the tokens are ints.

text = tokenizer.decode(
tokens_to_decode, skip_special_tokens=request.skip_special_tokens
)
text_out: Union[str, List[str]] = text
elif (
isinstance(request.tokens, list)
and request.tokens
and isinstance(request.tokens[0], list)
):
texts: List[str] = []
for token_list in request.tokens:
if not all(isinstance(t, int) for t in token_list):
return self.create_error_response(
f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}"
)
decoded_text = tokenizer.decode(
[int(t) for t in token_list],
skip_special_tokens=request.skip_special_tokens,
)
texts.append(decoded_text)
text_out = texts
elif isinstance(request.tokens, list) and not request.tokens:
text_out = ""
else:
return self.create_error_response(
f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
)

return DetokenizeResponse(text=text_out)
except Exception as e:
logger.error("Error during detokenization", exc_info=True)
if "decode" in str(e).lower():
return self.create_error_response(
f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
err_type="DecodeError",
status_code=HTTPStatus.BAD_REQUEST,
)
return self.create_error_response(
f"Internal server error during detokenization: {e}",
err_type="InternalServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
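
One way the reviewer's nit about `_validate_request` could be addressed is sketched below: move the per-element token checks into a `_validate_request` override so the handler body only sees well-formed input. This is not part of the PR, and the Optional[str] return convention for `_validate_request` is an assumption about `OpenAIServingBase`.

```python
from typing import Optional

from sglang.srt.entrypoints.openai.protocol import DetokenizeRequest
from sglang.srt.entrypoints.openai.serving_tokenize import OpenAIServingDetokenize


class OpenAIServingDetokenizeSketch(OpenAIServingDetokenize):
    """Hypothetical variant that front-loads token validation."""

    def _validate_request(self, request: DetokenizeRequest) -> Optional[str]:
        tokens = request.tokens
        # Accept either a flat list of ints or a list of int lists; anything else
        # is rejected before _handle_non_streaming_request runs.
        if all(isinstance(t, int) for t in tokens):
            return None
        if all(
            isinstance(sub, list) and all(isinstance(t, int) for t in sub)
            for sub in tokens
        ):
            return None
        return "Invalid input: 'tokens' must be List[int] or List[List[int]]."
```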