Skip to content

Commit 64a35e0

Browse files
committed
feat: ensure tokens generated during multi-round built-in tools are summed in the usage statistics
Signed-off-by: Guillaume Calmettes <[email protected]>
1 parent 7b1989c commit 64a35e0

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

vllm/entrypoints/context.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import json
44
import logging
55
from abc import ABC, abstractmethod
6+
from collections.abc import Sequence
67
from typing import TYPE_CHECKING, Union
78

89
from openai_harmony import Author, Message, Role, StreamState, TextContent
@@ -66,24 +67,30 @@ def __init__(
6667
self.tool_sessions = tool_sessions
6768

6869
self.parser = get_streamable_parser_for_assistant()
70+
self.encoding = get_encoding()
6971
self.num_init_messages = len(messages)
7072
self.num_prompt_tokens = 0
7173
self.num_output_tokens = 0
7274
# TODO(woosuk): Implement the following fields.
7375
self.num_cached_tokens = 0
7476
self.num_reasoning_tokens = 0
7577

76-
def _update_prompt_tokens(self, output: RequestOutput):
77-
if output.prompt_token_ids and len(
78-
output.prompt_token_ids) > 0 and self.num_prompt_tokens == 0:
79-
self.num_prompt_tokens = len(output.prompt_token_ids)
78+
def _update_num_prompt_tokens(self, output: RequestOutput):
79+
if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
80+
# NOTE: with built-in tools, there might be multiple rounds in
81+
# the conversation, with the full conversation being resent
82+
# as new prompt each time. Hence the sum.
83+
self.num_prompt_tokens += len(output.prompt_token_ids)
84+
85+
def _update_num_output_tokens(self, token_ids: Sequence[int]):
86+
self.num_output_tokens += len(token_ids)
8087

8188
def append_output(self, output) -> None:
8289
if isinstance(output, RequestOutput):
83-
self._update_prompt_tokens(output)
90+
self._update_num_prompt_tokens(output)
8491
output_token_ids = output.outputs[0].token_ids
92+
self._update_num_output_tokens(output_token_ids)
8593
self.parser = get_streamable_parser_for_assistant()
86-
self.num_output_tokens += len(output_token_ids)
8794
for token_id in output_token_ids:
8895
self.parser.process(token_id)
8996
output_msgs = self.parser.messages
@@ -163,7 +170,6 @@ def __init__(self, *args, **kwargs):
163170
self.last_output = None
164171

165172
self.parser = get_streamable_parser_for_assistant()
166-
self.encoding = get_encoding()
167173
self.last_tok = None
168174

169175
@property
@@ -172,10 +178,10 @@ def messages(self) -> list:
172178

173179
def append_output(self, output) -> None:
174180
if isinstance(output, RequestOutput):
175-
self._update_prompt_tokens(output)
181+
self._update_num_prompt_tokens(output)
176182
tok = output.outputs[0].token_ids[0]
177183
self.parser.process(tok)
178-
self.num_output_tokens += 1
184+
self._update_num_output_tokens(output.outputs[0].token_ids)
179185
self.last_tok = tok
180186
else:
181187
# Handle the case of tool output in direct message format

0 commit comments

Comments (0)