Skip to content

Commit 6a72495

Browse files
alex-jw-brooks and garg-amit
authored and committed
[Frontend] Add Early Validation For Chat Template / Tool Call Parser (vllm-project#9151)
Signed-off-by: Alex-Brooks <[email protected]> Signed-off-by: Amit Garg <[email protected]>
1 parent b0119f5 commit 6a72495

File tree

5 files changed

+155
-72
lines changed

5 files changed

+155
-72
lines changed
Lines changed: 109 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,131 @@
import json

import pytest

from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

from ...utils import VLLM_PATH

# A LoRA module in the (new) JSON CLI format.
LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama"
}
# Known-good chat template shipped with the repo, used for happy-path
# validation tests; fail fast at import time if the layout changed.
CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()


@pytest.fixture
def serve_parser():
    """Provide a fresh, fully configured serve argument parser per test."""
    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    return make_arg_parser(parser)


### Tests for Lora module parsing
def test_valid_key_value_format(serve_parser):
    # Test old format: name=path
    args = serve_parser.parse_args([
        '--lora-modules',
        'module1=/path/to/module1',
    ])
    expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
    assert args.lora_modules == expected


def test_valid_json_format(serve_parser):
    # Test valid JSON format input
    args = serve_parser.parse_args([
        '--lora-modules',
        json.dumps(LORA_MODULE),
    ])
    expected = [
        LoRAModulePath(name='module2',
                       path='/path/to/module2',
                       base_model_name='llama')
    ]
    assert args.lora_modules == expected


def test_invalid_json_format(serve_parser):
    # Test invalid JSON format input, missing closing brace
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
        ])


def test_invalid_type_error(serve_parser):
    # Test type error when values are not JSON or key=value
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            'invalid_format'  # This is not JSON or key=value format
        ])


def test_invalid_json_field(serve_parser):
    # Test valid JSON format but missing required fields
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            '{"name": "module4"}'  # Missing required 'path' field
        ])


def test_empty_values(serve_parser):
    # Test when no LoRA modules are provided
    args = serve_parser.parse_args(['--lora-modules', ''])
    assert args.lora_modules == []


def test_multiple_valid_inputs(serve_parser):
    # Test multiple valid inputs (both old and JSON format)
    args = serve_parser.parse_args([
        '--lora-modules',
        'module1=/path/to/module1',
        json.dumps(LORA_MODULE),
    ])
    expected = [
        LoRAModulePath(name='module1', path='/path/to/module1'),
        LoRAModulePath(name='module2',
                       path='/path/to/module2',
                       base_model_name='llama')
    ]
    assert args.lora_modules == expected


### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_fails_without_tool_call_parser(serve_parser):
    """Ensure validation fails if tool choice is enabled with no call parser"""
    # NOTE: renamed from *_passes_* — the body/docstring assert a failure.
    # If we enable-auto-tool-choice, explode with no tool-call-parser
    args = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    """Ensure validation passes with tool choice enabled with a call parser"""
    args = serve_parser.parse_args(args=[
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "mistral",
    ])
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_happy_paths(serve_parser):
    """Ensure validation passes if the chat template exists"""
    args = serve_parser.parse_args(
        args=["--chat-template",
              CHATML_JINJA_PATH.absolute().as_posix()])
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_sad_paths(serve_parser):
    """Ensure validation fails if the chat template doesn't exist"""
    args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
    with pytest.raises(ValueError):
        validate_parsed_serve_args(args)

vllm/entrypoints/chat_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,28 @@ def parse_audio(self, audio_url: str) -> None:
303303
self._add_placeholder(placeholder)
304304

305305

306+
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: ``None`` (no template supplied), a filesystem path
            to a template file, or an inline Jinja template string.

    Raises:
        FileNotFoundError: A ``Path`` was supplied but does not exist.
        ValueError: A string looks path-like (contains no Jinja braces or
            newlines) but does not exist on disk.
        TypeError: The value is not ``None``, ``Path``, or ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        # FIX: the original elif-chain had no branch matching an
        # *existing* Path, so it fell through to the TypeError below;
        # an existing Path is a valid template and must be accepted.
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        return

    if isinstance(chat_template, str):
        # Inline Jinja templates contain braces or newlines; a plain
        # single-line string without any of them is probably a mistyped
        # file path, so require it to exist on disk.
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template
                   for c in JINJA_CHARS) and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
326+
327+
306328
def load_chat_template(
307329
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
308330
if chat_template is None:

vllm/entrypoints/openai/api_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from vllm.engine.protocol import EngineClient
3232
from vllm.entrypoints.launcher import serve_http
3333
from vllm.entrypoints.logger import RequestLogger
34-
from vllm.entrypoints.openai.cli_args import make_arg_parser
34+
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
35+
validate_parsed_serve_args)
3536
# yapf conflicts with isort for this block
3637
# yapf: disable
3738
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -577,5 +578,6 @@ def signal_handler(*_) -> None:
577578
description="vLLM OpenAI-Compatible RESTful API server.")
578579
parser = make_arg_parser(parser)
579580
args = parser.parse_args()
581+
validate_parsed_serve_args(args)
580582

581583
uvloop.run(run_server(args))

vllm/entrypoints/openai/cli_args.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import List, Optional, Sequence, Union
1111

1212
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
13+
from vllm.entrypoints.chat_utils import validate_chat_template
1314
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
1415
PromptAdapterPath)
1516
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
231232
return parser
232233

233234

235+
def validate_parsed_serve_args(args: argparse.Namespace):
    """Quick checks for model serve args that raise prior to loading.

    Only validates the ``serve`` subcommand; namespaces produced by a
    parser without subcommands (no ``subparser`` attribute) are also
    validated, so the plain ``api_server`` entry point is covered too.

    Raises:
        TypeError: ``--enable-auto-tool-choice`` was given without a
            ``--tool-call-parser``.
        FileNotFoundError / ValueError / TypeError: Propagated from
            ``validate_chat_template`` for a bad ``--chat-template``.
    """
    if hasattr(args, "subparser") and args.subparser != "serve":
        return

    # Ensure that the chat template is valid; raises if it likely isn't
    validate_chat_template(args.chat_template)

    # Enable auto tool needs a tool call parser to be valid
    if args.enable_auto_tool_choice and not args.tool_call_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")
247+
248+
234249
def create_parser_for_docs() -> FlexibleArgumentParser:
235250
parser_for_docs = FlexibleArgumentParser(
236251
prog="-m vllm.entrypoints.openai.api_server")

vllm/scripts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from vllm.engine.arg_utils import EngineArgs
1313
from vllm.entrypoints.openai.api_server import run_server
14-
from vllm.entrypoints.openai.cli_args import make_arg_parser
14+
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
15+
validate_parsed_serve_args)
1516
from vllm.logger import init_logger
1617
from vllm.utils import FlexibleArgumentParser
1718

@@ -142,7 +143,7 @@ def main():
142143
env_setup()
143144

144145
parser = FlexibleArgumentParser(description="vLLM CLI")
145-
subparsers = parser.add_subparsers(required=True)
146+
subparsers = parser.add_subparsers(required=True, dest="subparser")
146147

147148
serve_parser = subparsers.add_parser(
148149
"serve",
@@ -186,6 +187,9 @@ def main():
186187
chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
187188

188189
args = parser.parse_args()
190+
if args.subparser == "serve":
191+
validate_parsed_serve_args(args)
192+
189193
# One of the sub commands should be executed.
190194
if hasattr(args, "dispatch_function"):
191195
args.dispatch_function(args)

0 commit comments

Comments
 (0)