Skip to content

Commit 754b00e

Browse files
authored
[Bugfix] Fix Mistral tool-parser regex for nested JSON (#20093)
Signed-off-by: mgoin <[email protected]>
1 parent 296ce95 commit 754b00e

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

tests/models/language/generation/test_mistral.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
1111
MistralToolCall, MistralToolParser)
1212
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
13+
from vllm.transformers_utils.tokenizer import MistralTokenizer
1314

1415
from ...utils import check_logprobs_close
1516

@@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
318319
schema=SAMPLE_JSON_SCHEMA)
319320
except jsonschema.exceptions.ValidationError:
320321
pytest.fail("Generated response is not valid with JSON schema")
322+
323+
324+
def test_mistral_function_call_nested_json():
325+
"""Ensure that the function-name regex captures the entire outer-most
326+
JSON block, including nested braces."""
327+
328+
# Create a minimal stub tokenizer that provides the few attributes the
329+
# parser accesses (`version` and `get_vocab`).
330+
class _StubMistralTokenizer(MistralTokenizer):
331+
version = 11 # Satisfy the version check
332+
333+
def __init__(self):
334+
pass
335+
336+
@staticmethod
337+
def get_vocab():
338+
# Provide the special TOOL_CALLS token expected by the parser.
339+
return {"[TOOL_CALLS]": 0}
340+
341+
tokenizer = _StubMistralTokenizer()
342+
parser = MistralToolParser(tokenizer)
343+
344+
# Craft a model output featuring nested JSON inside the arguments.
345+
args_dict = {
346+
"city": "Dallas",
347+
"state": "TX",
348+
"unit": "fahrenheit",
349+
"sub_dict": {
350+
"foo": "bar",
351+
"inner": {
352+
"x": 1,
353+
"y": 2
354+
}
355+
},
356+
}
357+
358+
model_output = (
359+
f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}")
360+
361+
parsed = parser.extract_tool_calls(model_output, None)
362+
363+
# Assertions: the tool call is detected and the full nested JSON is parsed
364+
# without truncation.
365+
assert parsed.tools_called
366+
367+
assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
368+
assert parsed.tool_calls[0].function.name == "get_current_weather"
369+
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
370+
# No additional content outside the tool call should be returned.
371+
assert parsed.content is None

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def __init__(self, tokenizer: AnyTokenizer):
7777
self.bot_token_id = self.vocab.get(self.bot_token)
7878
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
7979
if _is_fn_name_regex_support(self.model_tokenizer):
80-
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
81-
re.DOTALL)
80+
self.fn_name_regex = re.compile(
81+
r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
8282
else:
8383
self.fn_name_regex = None
8484

0 commit comments

Comments
 (0)