3
3
import json
4
4
import logging
5
5
from abc import ABC , abstractmethod
6
+ from collections .abc import Sequence
6
7
from typing import TYPE_CHECKING , Union
7
8
8
9
from openai_harmony import Author , Message , Role , StreamState , TextContent
@@ -66,24 +67,30 @@ def __init__(
66
67
self .tool_sessions = tool_sessions
67
68
68
69
self .parser = get_streamable_parser_for_assistant ()
70
+ self .encoding = get_encoding ()
69
71
self .num_init_messages = len (messages )
70
72
self .num_prompt_tokens = 0
71
73
self .num_output_tokens = 0
72
74
# TODO(woosuk): Implement the following fields.
73
75
self .num_cached_tokens = 0
74
76
self .num_reasoning_tokens = 0
75
77
76
- def _update_prompt_tokens (self , output : RequestOutput ):
77
- if output .prompt_token_ids and len (
78
- output .prompt_token_ids ) > 0 and self .num_prompt_tokens == 0 :
79
- self .num_prompt_tokens = len (output .prompt_token_ids )
78
+ def _update_num_prompt_tokens (self , output : RequestOutput ):
79
+ if output .prompt_token_ids and len (output .prompt_token_ids ) > 0 :
80
+ # NOTE: with built-in tools, there might be multiple rounds in
81
+ # the conversation, with the full conversation being resent
82
+ # as new prompt each time. Hence the sum.
83
+ self .num_prompt_tokens += len (output .prompt_token_ids )
84
+
85
+ def _update_num_output_tokens (self , token_ids : Sequence [int ]):
86
+ self .num_output_tokens += len (token_ids )
80
87
81
88
def append_output (self , output ) -> None :
82
89
if isinstance (output , RequestOutput ):
83
- self ._update_prompt_tokens (output )
90
+ self ._update_num_prompt_tokens (output )
84
91
output_token_ids = output .outputs [0 ].token_ids
92
+ self ._update_num_output_tokens (output_token_ids )
85
93
self .parser = get_streamable_parser_for_assistant ()
86
- self .num_output_tokens += len (output_token_ids )
87
94
for token_id in output_token_ids :
88
95
self .parser .process (token_id )
89
96
output_msgs = self .parser .messages
@@ -163,7 +170,6 @@ def __init__(self, *args, **kwargs):
163
170
self .last_output = None
164
171
165
172
self .parser = get_streamable_parser_for_assistant ()
166
- self .encoding = get_encoding ()
167
173
self .last_tok = None
168
174
169
175
@property
@@ -172,10 +178,10 @@ def messages(self) -> list:
172
178
173
179
def append_output (self , output ) -> None :
174
180
if isinstance (output , RequestOutput ):
175
- self ._update_prompt_tokens (output )
181
+ self ._update_num_prompt_tokens (output )
176
182
tok = output .outputs [0 ].token_ids [0 ]
177
183
self .parser .process (tok )
178
- self .num_output_tokens += 1
184
+ self ._update_num_output_tokens ( output . outputs [ 0 ]. token_ids )
179
185
self .last_tok = tok
180
186
else :
181
187
# Handle the case of tool output in direct message format
0 commit comments