Skip to content

Commit 9aac495

Browse files
committed
Update: add ToolParser and MoE config for Hunyuan A13B (vllm-project#20820)
1 parent 5de0883 commit 9aac495

File tree

5 files changed

+646
-1
lines changed

5 files changed

+646
-1
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,11 @@ def main(args: argparse.Namespace):
585585
topk = config.num_experts_per_tok
586586
intermediate_size = config.moe_intermediate_size
587587
shard_intermediate_size = 2 * intermediate_size // args.tp_size
588+
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
589+
E = config.num_experts
590+
topk = config.moe_topk[0]
591+
intermediate_size = config.moe_intermediate_size[0]
592+
shard_intermediate_size = 2 * intermediate_size // args.tp_size
588593
else:
589594
# Support for llama4
590595
config = config.get_text_config()
@@ -741,3 +746,4 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
741746
args = parser.parse_args()
742747

743748
main(args)
749+
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
{# Chat template for Hunyuan A13B. Two modes: with `tools` (function-calling
   prompt, Chinese instructions, <tool_calls>/<tool_response> markup) and
   without (plain chat). NOTE(review): assumes rendering with
   trim_blocks/lstrip_blocks (HF chat-template defaults) — confirm. #}
{% set loop_messages = messages %}
{% if tools %}
{% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %}
{% set weekday_cn = weekday_map[strftime_now('%A')] %}
{% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %}
{% set datetime_str = datetime_str + ' ' + weekday_cn %}
{% for message in loop_messages %}
{% if 'content' in message %}
{% set content = message['content'] %}
{% else %}
{% set content = '' %}
{% endif %}
{% if loop.index0 == 0 %}
{% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。
如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。
如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。
如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。
你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。
' %}
{% set content_tmp = content_tmp + '
' + tools | tojson + '
' %}
{% if message['role'] == 'system' %}
{% set content_tmp = content_tmp + '
额外要求:
' + content + '

如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。
否则,请参考开头说的三种情况,以助手:开头进行回复。

如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %}
{% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %}
{% elif message['role'] == 'user' %}
{% set content_tmp = content_tmp + '
如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。
否则,请参考开头说的三种情况,以助手:开头进行回复。

当前时间:' + datetime_str %}
{% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%}
{% set content = content_tmp + '用户:' + content + '<|extra_0|>' %}
{% endif %}
{% else %}
{% if message['role'] == 'user' %}
{% set content = '用户:' + content + '<|extra_0|>' %}
{% elif message['role'] == 'assistant' %}
{% if 'tool_calls' in message %}
{% set tool_calls = message['tool_calls'] %}
{% set ns = namespace(tool_calls="[") %}
{% for tool_call in tool_calls %}
{% set function = tool_call['function'] %}
{% set name = function['name'] %}
{% set ns.tool_calls = ns.tool_calls + '{"name": "' + name + '", '%}
{% set arguments = function['arguments'] %}
{% if arguments is not string %}
{% set arguments = arguments | tojson %}
{% endif %}
{% set ns.tool_calls = ns.tool_calls + '"arguments": ' + arguments + '}' %}
{% if not loop.last %}
{% set ns.tool_calls = ns.tool_calls + ', '%}
{% endif %}
{% endfor %}
{% set ns.tool_calls = ns.tool_calls + ']' %}
{% set content = content + '<tool_calls>' + ns.tool_calls + '</tool_calls>' %}
{% else %}
{% set content = '助手:' + content %}
{% endif %}
{% set content = content + '<|eos|>' %}
{% elif message['role'] == 'tool' %}
{% if content is not string %}
{# FIX: was `{set content = content | tojson }` — an invalid tag that Jinja
   emitted verbatim instead of serializing non-string tool output. #}
{% set content = content | tojson %}
{% endif %}
{% set content = '<tool_response>' + content + '</tool_response>' %}
{% set content = content + '<|extra_0|>' %}
{% endif %}
{% endif %}
{{- content -}}
{% endfor %}
{% else %}
{% set context = {'has_head': true} %}
{% for message in loop_messages %}
{% if 'content' in message %}
{% set content = message['content'] %}
{% else %}
{% set content = '' %}
{% endif %}
{% if loop.index0 == 0 %}
{% if content == '' %}
{% set _ = context.update({'has_head': false}) %}
{% elif message['role'] == 'system' %}
{% set content = '<|startoftext|>' + content + '<|extra_4|>' %}
{% endif %}
{% endif %}
{% if message['role'] == 'user' %}
{% if loop.index0 == 1 and not context.has_head %}
{% set content = '<|startoftext|>' + content %}
{% endif %}
{% if loop.index0 == 1 and context.has_head %}
{% set content = content + '<|extra_0|>' %}
{% else %}
{% set content = '<|startoftext|>' + content + '<|extra_0|>' %}
{% endif %}
{% elif message['role'] == 'assistant' %}
{% set content = content + '<|eos|>' %}
{% elif message['role'] == 'tool' %}
{% set content = content + '<|extra_0|>' %}
{% endif %}
{{- content -}}
{% endfor %}
{% endif %}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n' }}
{%- endif %}
113+
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# ruff: noqa: E501
4+
5+
import json
6+
from unittest.mock import MagicMock
7+
8+
import pytest
9+
10+
from tests.entrypoints.openai.tool_parsers.utils import (
11+
run_tool_extraction, run_tool_extraction_streaming)
12+
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
13+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
15+
16+
def make_tool_call(name, arguments):
    """Build a function-type ToolCall with JSON-encoded arguments."""
    serialized_args = json.dumps(arguments)
    call = FunctionCall(name=name, arguments=serialized_args)
    return ToolCall(type="function", function=call)
20+
21+
22+
# TODO: add reason prefix and suffix.
23+
24+
25+
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                })
            ],
            None),
        # Multiple tool calls
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                }),
                make_tool_call(
                    "register_user", {
                        "name": "John Doe",
                        "age": 37,
                        "address": {
                            "city": "San Francisco",
                            "state": "CA"
                        },
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"]
                    })
            ],
            None),
        # Content before tool call
        (
            "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. "),
        # Content after tool call (should be stripped)
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None),
        # Deeply nested argument object
        (
            "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
            [
                make_tool_call(
                    "complex_tool",
                    {"level1": {
                        "level2": {
                            "level3": {
                                "value": 123
                            }
                        }
                    }})
            ],
            None,
        ),
    ])
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
                                          expected_content):
    """Non-streaming extraction: parse one full model output at once."""
    tokenizer_stub = MagicMock()
    parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(tokenizer_stub)
    content, actual_calls = run_tool_extraction(parser,
                                                model_output,
                                                streaming=False)

    # The parser assigns random call ids; copy the expected ids over before
    # comparing so equality only checks names/arguments.
    for actual, expected in zip(actual_calls, expected_tool_calls):
        actual.id = expected.id
    assert actual_calls == expected_tool_calls
    assert content == expected_content
102+
103+
104+
# Streaming test: simulate incremental output
105+
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
    ([
        "<tool_calls>[{\"name\": \"get_weather\", ",
        "\"arguments\": {\"city\": \"San Francisco\", ",
        "\"metric\": \"celsius\"}}]", "</tool_calls>"
    ], [
        make_tool_call("get_weather", {
            "city": "San Francisco",
            "metric": "celsius"
        })
    ]),
    ([
        "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    ([
        "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    pytest.param([
        "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
        " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
        "]</tool_calls>"
    ], [
        make_tool_call("complex_tool",
                       {"level1": {
                           "level2": {
                               "level3": {
                                   "value": 123
                               }
                           }
                       }})
    ],
               marks=pytest.mark.xfail(
                   reason="stream parsing not support nested json yet.")),
])
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Streaming extraction: feed the output as incremental deltas."""
    tokenizer_stub = MagicMock()

    parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(tokenizer_stub)
    reconstructor = run_tool_extraction_streaming(
        parser, model_deltas, assert_one_tool_per_delta=False)

    # The parser assigns random call ids; copy the expected ids over before
    # comparing so equality only checks names/arguments.
    for actual, expected in zip(reconstructor.tool_calls, expected_tool_calls):
        actual.id = expected.id

    assert reconstructor.tool_calls == expected_tool_calls

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
77
from .granite_tool_parser import GraniteToolParser
88
from .hermes_tool_parser import Hermes2ProToolParser
9+
from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser
910
from .internlm2_tool_parser import Internlm2ToolParser
1011
from .jamba_tool_parser import JambaToolParser
1112
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
@@ -19,5 +20,5 @@
1920
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
2021
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
2122
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
22-
"DeepSeekV3ToolParser"
23+
"DeepSeekV3ToolParser", "HunyuanA13BToolParser"
2324
]

0 commit comments

Comments
 (0)