|
10 | 10 | from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
11 | 11 | MistralToolCall, MistralToolParser)
|
12 | 12 | from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
| 13 | +from vllm.transformers_utils.tokenizer import MistralTokenizer |
13 | 14 |
|
14 | 15 | from ...utils import check_logprobs_close
|
15 | 16 |
|
@@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
|
318 | 319 | schema=SAMPLE_JSON_SCHEMA)
|
319 | 320 | except jsonschema.exceptions.ValidationError:
|
320 | 321 | pytest.fail("Generated response is not valid with JSON schema")
|
| 322 | + |
| 323 | + |
| 324 | +def test_mistral_function_call_nested_json(): |
| 325 | + """Ensure that the function-name regex captures the entire outer-most |
| 326 | + JSON block, including nested braces.""" |
| 327 | + |
| 328 | + # Create a minimal stub tokenizer that provides the few attributes the |
| 329 | + # parser accesses (`version` and `get_vocab`). |
| 330 | + class _StubMistralTokenizer(MistralTokenizer): |
| 331 | + version = 11 # Satisfy the version check |
| 332 | + |
| 333 | + def __init__(self): |
| 334 | + pass |
| 335 | + |
| 336 | + @staticmethod |
| 337 | + def get_vocab(): |
| 338 | + # Provide the special TOOL_CALLS token expected by the parser. |
| 339 | + return {"[TOOL_CALLS]": 0} |
| 340 | + |
| 341 | + tokenizer = _StubMistralTokenizer() |
| 342 | + parser = MistralToolParser(tokenizer) |
| 343 | + |
| 344 | + # Craft a model output featuring nested JSON inside the arguments. |
| 345 | + args_dict = { |
| 346 | + "city": "Dallas", |
| 347 | + "state": "TX", |
| 348 | + "unit": "fahrenheit", |
| 349 | + "sub_dict": { |
| 350 | + "foo": "bar", |
| 351 | + "inner": { |
| 352 | + "x": 1, |
| 353 | + "y": 2 |
| 354 | + } |
| 355 | + }, |
| 356 | + } |
| 357 | + |
| 358 | + model_output = ( |
| 359 | + f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}") |
| 360 | + |
| 361 | + parsed = parser.extract_tool_calls(model_output, None) |
| 362 | + |
| 363 | + # Assertions: the tool call is detected and the full nested JSON is parsed |
| 364 | + # without truncation. |
| 365 | + assert parsed.tools_called |
| 366 | + |
| 367 | + assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id) |
| 368 | + assert parsed.tool_calls[0].function.name == "get_current_weather" |
| 369 | + assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict |
| 370 | + # No additional content outside the tool call should be returned. |
| 371 | + assert parsed.content is None |
0 commit comments