Skip to content

Commit 977022f

Browse files
committed
Merge branch 'main' of github.com:eth-easl/Scratchpad
2 parents 6bf03d2 + 52af4e9 commit 977022f

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

scratchpad/server/openai_api/handler.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import os
18+
import copy
1819
import time
1920
import json
2021
import uuid
@@ -294,7 +295,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
294295
if not isinstance(ret, list):
295296
ret = [ret]
296297
if end_point == "/v1/chat/completions":
297-
responses = v1_chat_generate_response(request, ret, to_file=True)
298+
responses = v1_chat_generate_response(
299+
request, file_request_list, ret, to_file=True
300+
)
298301
else:
299302
responses = v1_generate_response(
300303
request, ret, tokenizer_manager, to_file=True
@@ -900,6 +903,7 @@ def v1_chat_generate_request(
900903
add_generation_prompt=True,
901904
tools=tools,
902905
)
906+
request._raw_prompt_str = templated_message
903907
if assistant_prefix:
904908
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
905909
stop = request.stop
@@ -987,6 +991,7 @@ def v1_chat_generate_request(
987991

988992
def v1_chat_generate_response(
989993
request,
994+
raw_requests,
990995
ret,
991996
to_file=False,
992997
cache_report=False,
@@ -1041,7 +1046,7 @@ def v1_chat_generate_response(
10411046

10421047
tool_calls = None
10431048
text = ret_item["text"]
1044-
1049+
raw_outputs = copy.deepcopy(text)
10451050
if isinstance(request, list):
10461051
tool_choice = request[idx].tool_choice
10471052
tools = request[idx].tools
@@ -1113,6 +1118,31 @@ def v1_chat_generate_response(
11131118
)
11141119
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
11151120
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
1121+
raw_prompts = []
1122+
if isinstance(request, list):
1123+
for req in request:
1124+
raw_prompt = (
1125+
req._raw_prompt_str if hasattr(req, "_raw_prompt_str") else None
1126+
)
1127+
raw_prompts.append(raw_prompt)
1128+
else:
1129+
raw_prompt = (
1130+
request._raw_prompt_str if hasattr(request, "_raw_prompt_str") else None
1131+
)
1132+
raw_prompts.append(raw_prompt)
1133+
1134+
# TODO: Find a way to include the raw outputs, where special tokens are not skipped
1135+
# raw_outputs = []
1136+
# for ret_item in ret:
1137+
# raw_output = ret_item["text"] if 'text' in ret_item else None
1138+
# raw_outputs.append(raw_output)
1139+
# raw_outputs = raw_outputs[0] if len(raw_outputs) == 1 else raw_outputs
1140+
raw_prompts = raw_prompts[0] if len(raw_prompts) == 1 else raw_prompts
1141+
1142+
return_raw = False
1143+
if "return_raw" in raw_requests[0]:
1144+
return_raw = raw_requests[0]["return_raw"]
1145+
11161146
response = ChatCompletionResponse(
11171147
id=ret[0]["meta_info"]["id"],
11181148
model=request.model,
@@ -1125,6 +1155,10 @@ def v1_chat_generate_response(
11251155
{"cached_tokens": cached_tokens} if cache_report else None
11261156
),
11271157
),
1158+
raw_prompt=raw_prompts if return_raw else None,
1159+
raw_output=raw_outputs if return_raw else None,
1160+
# TODO: Find a way to include the raw outputs, where special tokens are not skipped
1161+
# raw_output=raw_outputs if debug_mode else None,
11281162
)
11291163
return response
11301164

@@ -1136,6 +1170,10 @@ async def v1_chat_completions(
11361170
request_json = await raw_request.json()
11371171
except Exception as e:
11381172
return create_error_response("Invalid request body, error: ", str(e))
1173+
if "return_raw" in request_json:
1174+
return_raw = request_json["return_raw"]
1175+
request_json["skip_special_tokens"] = False
1176+
11391177
all_requests = [ChatCompletionRequest(**request_json)]
11401178
created = int(time.time())
11411179
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
@@ -1471,6 +1509,7 @@ async def generate_stream_resp():
14711509

14721510
response = v1_chat_generate_response(
14731511
request,
1512+
[request_json],
14741513
ret,
14751514
created,
14761515
cache_report=tokenizer_manager.server_args.enable_cache_report,

scratchpad/server/openai_api/protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ class ChatCompletionResponse(BaseModel):
351351
model: str
352352
choices: List[ChatCompletionResponseChoice]
353353
usage: UsageInfo
354+
raw_prompt: Optional[str] = None
355+
raw_output: Optional[str] = None # Raw output with special tokens included
356+
# TODO: Find a way to include the raw output, where special tokens are not excluded
357+
# raw_output: Optional[str] = None
354358

355359

356360
class DeltaMessage(BaseModel):

tests/e2e/test_raw_prompt.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import requests
2+
import json
3+
4+
5+
def test_openai_compatible_api():
    """E2E check: a chat completion with ``return_raw: True`` succeeds and
    surfaces the raw prompt in the response.

    Posts a chat-completion request to a locally running OpenAI-compatible
    server and validates both the standard response structure and the
    ``raw_prompt`` field added by the ``return_raw`` option.  If the server
    is unreachable the test reports the problem and returns early (an
    environment issue, not a code failure); every other problem fails the
    test via normal assertions so pytest can see it.
    """
    # API configuration
    base_url = "http://localhost:8080/v1"  # Adjust to your API endpoint
    api_key = "your-api-key"  # Replace with actual API key if needed

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    # Test data
    payload = {
        "model": "Qwen/Qwen3-8B",  # Adjust model name as needed
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello! Can you help me test this API?"},
        ],
        "max_tokens": 2048,
        "temperature": 0.7,
        "return_raw": True,  # Set to True to return raw prompts
    }

    try:
        response = requests.post(
            f"{base_url}/chat/completions", headers=headers, json=payload, timeout=30
        )
    except requests.exceptions.RequestException as e:
        # Server not running / unreachable: best-effort bail-out instead of
        # turning an environment problem into a test failure.
        print(f"❌ Request error: {e}")
        return

    print(f"Status Code: {response.status_code}")
    # A reachable server must answer 200; anything else is a real failure.
    assert response.status_code == 200, (
        f"❌ API request failed with status {response.status_code}: {response.text}"
    )

    # Malformed JSON should fail the test too — do not swallow it.
    result = response.json()
    print("✅ API request successful!")
    print(f"Response: {json.dumps(result, indent=2)}")

    # Validate the standard OpenAI response structure.  NOTE: unlike the
    # previous version, AssertionError is deliberately NOT caught here —
    # catching it made the test pass under pytest even when validation failed.
    assert "choices" in result
    assert len(result["choices"]) > 0
    assert "message" in result["choices"][0]
    assert "content" in result["choices"][0]["message"]

    # The point of this test: return_raw=True must surface the raw prompt.
    assert result.get("raw_prompt") is not None, (
        "raw_prompt missing from response despite return_raw=True"
    )

    print("✅ Response structure validation passed!")
59+
60+
61+
if __name__ == "__main__":
62+
print("🚀 Testing OpenAI-compatible API...")
63+
test_openai_compatible_api()
64+
print("\n✨ Test completed!")

0 commit comments

Comments
 (0)