|
29 | 29 | OpenAIServing,
|
30 | 30 | PromptAdapterPath,
|
31 | 31 | TextTokensPrompt)
|
32 |
| -from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, |
33 |
| - Llama3JsonToolParser, |
34 |
| - MistralToolParser, |
35 |
| - ToolParser) |
| 32 | +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager |
36 | 33 | from vllm.inputs import TokensPrompt
|
37 | 34 | from vllm.logger import init_logger
|
38 | 35 | from vllm.outputs import CompletionOutput, RequestOutput
|
@@ -82,15 +79,13 @@ def __init__(self,
|
82 | 79 |
|
83 | 80 | self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
84 | 81 | if self.enable_auto_tools:
|
85 |
| - if tool_parser == "mistral": |
86 |
| - self.tool_parser = MistralToolParser |
87 |
| - elif tool_parser == "hermes": |
88 |
| - self.tool_parser = Hermes2ProToolParser |
89 |
| - elif tool_parser == "llama3_json": |
90 |
| - self.tool_parser = Llama3JsonToolParser |
91 |
| - else: |
| 82 | + try: |
| 83 | + self.tool_parser = ToolParserManager.get_tool_parser( |
| 84 | + tool_parser) |
| 85 | + except Exception as e: |
92 | 86 | raise TypeError("Error: --enable-auto-tool-choice requires "
|
93 |
| - "--tool-call-parser") |
| 87 | + f"tool_parser:'{tool_parser}' which has not " |
| 88 | + "been registered") from e |
94 | 89 |
|
95 | 90 | async def create_chat_completion(
|
96 | 91 | self,
|
@@ -187,6 +182,10 @@ async def create_chat_completion(
|
187 | 182 | raw_request.state.request_metadata = request_metadata
|
188 | 183 |
|
189 | 184 | try:
|
| 185 | + if self.enable_auto_tools and self.tool_parser: |
| 186 | + request = self.tool_parser(tokenizer).adjust_request( |
| 187 | + request=request) |
| 188 | + |
190 | 189 | if isinstance(prompt, str):
|
191 | 190 | prompt_inputs = self._tokenize_prompt_input(
|
192 | 191 | request,
|
@@ -282,11 +281,11 @@ async def chat_completion_stream_generator(
|
282 | 281 | num_choices = 1 if request.n is None else request.n
|
283 | 282 | previous_num_tokens = [0] * num_choices
|
284 | 283 | finish_reason_sent = [False] * num_choices
|
285 |
| - |
286 | 284 | num_prompt_tokens = 0
|
287 | 285 |
|
288 |
| - tool_parser: Optional[ToolParser] = self.tool_parser( |
289 |
| - tokenizer) if self.tool_parser else None |
| 286 | + tool_parsers: List[Optional[ToolParser]] = [ |
| 287 | + self.tool_parser(tokenizer) if self.tool_parser else None |
| 288 | + ] * num_choices |
290 | 289 |
|
291 | 290 | if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
292 | 291 | tool_choice_function_name = request.tool_choice.function.name
|
@@ -324,7 +323,7 @@ async def chat_completion_stream_generator(
|
324 | 323 | # NOTE num_choices defaults to 1 so this usually executes
|
325 | 324 | # once per request
|
326 | 325 | for i in range(num_choices):
|
327 |
| - |
| 326 | + tool_parser = tool_parsers[i] |
328 | 327 | choice_data = ChatCompletionResponseStreamChoice(
|
329 | 328 | index=i,
|
330 | 329 | delta=DeltaMessage(
|
@@ -399,6 +398,7 @@ async def chat_completion_stream_generator(
|
399 | 398 |
|
400 | 399 | for output in res.outputs:
|
401 | 400 | i = output.index
|
| 401 | + tool_parser = tool_parsers[i] |
402 | 402 |
|
403 | 403 | if finish_reason_sent[i]:
|
404 | 404 | continue
|
@@ -446,7 +446,8 @@ async def chat_completion_stream_generator(
|
446 | 446 | delta_text=delta_text,
|
447 | 447 | previous_token_ids=previous_token_ids,
|
448 | 448 | current_token_ids=current_token_ids,
|
449 |
| - delta_token_ids=output.token_ids)) |
| 449 | + delta_token_ids=output.token_ids, |
| 450 | + request=request)) |
450 | 451 |
|
451 | 452 | # update the previous values for the next iteration
|
452 | 453 | previous_texts[i] = current_text
|
@@ -685,7 +686,8 @@ async def chat_completion_full_generator(
|
685 | 686 | and self.tool_parser:
|
686 | 687 |
|
687 | 688 | tool_parser = self.tool_parser(tokenizer)
|
688 |
| - tool_call_info = tool_parser.extract_tool_calls(output.text) |
| 689 | + tool_call_info = tool_parser.extract_tool_calls( |
| 690 | + output.text, request=request) |
689 | 691 | tools_called = tool_call_info.tools_called
|
690 | 692 | if tool_call_info.tools_called:
|
691 | 693 | message = ChatMessage(role=role,
|
|
0 commit comments