@@ -3,6 +3,7 @@
 import json
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Union
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
@@ -67,15 +68,27 @@ def __init__(
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
-        # TODO(woosuk): Implement the following fields.
         self.num_prompt_tokens = 0
-        self.num_cached_tokens = 0
         self.num_output_tokens = 0
+        # TODO(woosuk): Implement the following fields.
+        self.num_cached_tokens = 0
         self.num_reasoning_tokens = 0
 
+    def _update_num_prompt_tokens(self, output: RequestOutput):
+        if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
+            # NOTE: with built-in tools, there may be multiple rounds
+            # in the conversation, with the full conversation being
+            # resent as a new prompt each time. Hence the sum.
+            self.num_prompt_tokens += len(output.prompt_token_ids)
+
+    def _update_num_output_tokens(self, token_ids: Sequence[int]):
+        self.num_output_tokens += len(token_ids)
+
     def append_output(self, output) -> None:
         if isinstance(output, RequestOutput):
+            self._update_num_prompt_tokens(output)
             output_token_ids = output.outputs[0].token_ids
+            self._update_num_output_tokens(output_token_ids)
             self.parser = get_streamable_parser_for_assistant()
             for token_id in output_token_ids:
                 self.parser.process(token_id)
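To make the accounting concrete, here is a minimal runnable sketch of how `num_prompt_tokens` accumulates across tool rounds. `FakeRequestOutput` and `TokenAccounting` are hypothetical stand-ins, not vLLM's real classes; only the fields the helpers touch are modeled:

```python
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class FakeRequestOutput:
    # Hypothetical stand-in for vLLM's RequestOutput; only the field
    # the prompt-accounting helper reads.
    prompt_token_ids: list[int]


class TokenAccounting:
    # Reduced sketch of the context class: just the two counters and
    # the helpers from the diff above.
    def __init__(self):
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0

    def _update_num_prompt_tokens(self, output: FakeRequestOutput):
        if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
            # Each tool round resends the full conversation as the
            # prompt, so the rounds are summed rather than overwritten.
            self.num_prompt_tokens += len(output.prompt_token_ids)

    def _update_num_output_tokens(self, token_ids: Sequence[int]):
        self.num_output_tokens += len(token_ids)


ctx = TokenAccounting()
# Round 1: initial prompt of 10 tokens, 5 output tokens.
ctx._update_num_prompt_tokens(FakeRequestOutput(list(range(10))))
ctx._update_num_output_tokens([0] * 5)
# Round 2: after a built-in tool call, the whole conversation
# (now 25 tokens) is resent as the prompt.
ctx._update_num_prompt_tokens(FakeRequestOutput(list(range(25))))
ctx._update_num_output_tokens([0] * 7)
assert ctx.num_prompt_tokens == 35  # 10 + 25, summed across rounds
assert ctx.num_output_tokens == 12
```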
@@ -158,15 +171,26 @@ def __init__(self, *args, **kwargs):
         self.parser = get_streamable_parser_for_assistant()
         self.encoding = get_encoding()
         self.last_tok = None
+        self.first_tok_of_message = True
 
     @property
     def messages(self) -> list:
         return self.parser.messages
 
     def append_output(self, output) -> None:
         if isinstance(output, RequestOutput):
+            # append_output is called once per output token in the
+            # streaming case, so add the prompt tokens only once per message.
+            if self.first_tok_of_message:
+                self._update_num_prompt_tokens(output)
+            # Reset self.first_tok_of_message if needed: if the current
+            # token is the last one of its message (finished=True), the
+            # next token processed will mark the beginning of a new
+            # message.
+            self.first_tok_of_message = output.finished
             tok = output.outputs[0].token_ids[0]
             self.parser.process(tok)
+            self._update_num_output_tokens(output.outputs[0].token_ids)
             self.last_tok = tok
         else:
             # Handle the case of tool output in direct message format
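A companion sketch of the streaming path, reusing the stand-ins above. The signature is simplified (token ids are passed explicitly rather than read from `output.outputs[0].token_ids`), but the `first_tok_of_message` / `finished` handshake matches the diff: the prompt is counted once per message, and the flag re-arms after the final token:

```python
from dataclasses import dataclass


@dataclass
class FakeStreamOutput(FakeRequestOutput):
    # Hypothetical; `finished` mirrors RequestOutput.finished.
    finished: bool = False


class StreamingAccounting(TokenAccounting):
    def __init__(self):
        super().__init__()
        self.first_tok_of_message = True

    def append_output(self, output: FakeStreamOutput, token_ids: list[int]):
        # Called once per generated token in the streaming case.
        if self.first_tok_of_message:
            # Count the prompt only on the first token of each message.
            self._update_num_prompt_tokens(output)
        # If this token finishes the message, the next call starts a
        # new message, so re-arm the flag.
        self.first_tok_of_message = output.finished
        self._update_num_output_tokens(token_ids)


stream = StreamingAccounting()
prompt = list(range(10))
# Three streamed tokens of one message; only the first call adds the prompt.
stream.append_output(FakeStreamOutput(prompt), [101])
stream.append_output(FakeStreamOutput(prompt), [102])
stream.append_output(FakeStreamOutput(prompt, finished=True), [103])
assert stream.num_prompt_tokens == 10  # counted once, not three times
assert stream.num_output_tokens == 3
```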