Skip to content

Commit f9bc635

Browse files
authored
fix(llm-bridge): update token usage bug (#1125)
# Description fix update token usage bug
1 parent fe4244d commit f9bc635

File tree

2 files changed

+21
-28
lines changed

2 files changed

+21
-28
lines changed

pkg/bridge/ai/chat.go

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ import (
1919
type chatResponse interface {
2020
// checkFunctionCall checks if the response is a function call response
2121
checkFunctionCall() (bool, error)
22-
// getToolCalls returns the tool calls in the response
23-
getToolCalls() []openai.ToolCall
24-
// getUsage returns the usage in the response
25-
getUsage() openai.Usage
22+
// getToolCalls returns the tool calls and its usage in the response
23+
getToolCalls() ([]openai.ToolCall, openai.Usage)
2624
// writeResponse defines how to write the response to the response writer
2725
writeResponse(w EventResponseWriter, chatCtx *chatContext) error
2826
}
@@ -66,7 +64,7 @@ func createChatCompletions(ctx context.Context, p provider.LLMProvider, req open
6664
type chatContext struct {
6765
// callTimes is the number of times for calling chat completions
6866
callTimes int
69-
// totalUsage is the total usage of all chat completions
67+
// totalUsage is the total usage of all chat completions (include tool calls and final response)
7068
totalUsage openai.Usage
7169
req openai.ChatCompletionRequest
7270
}
@@ -131,7 +129,7 @@ func multiTurnFunctionCalling(
131129
}
132130
if isFunctionCall {
133131
callCtx, callSpan := tracer.Start(ctx, fmt.Sprintf("call_functions(#%d)", chatCtx.callTimes+1))
134-
toolCalls := resp.getToolCalls()
132+
toolCalls, usage := resp.getToolCalls()
135133

136134
// add toolCallID if toolCallID is empty
137135
for i, call := range toolCalls {
@@ -168,7 +166,7 @@ func multiTurnFunctionCalling(
168166
})
169167
}
170168

171-
updateTotalUsage(chatCtx, resp)
169+
updateCtxUsage(chatCtx, usage)
172170

173171
chatCtx.callTimes++
174172
continue
@@ -180,8 +178,6 @@ func multiTurnFunctionCalling(
180178
}
181179

182180
func endCall(ctx context.Context, chatCtx *chatContext, resp chatResponse, w EventResponseWriter, tracer trace.Tracer) error {
183-
updateTotalUsage(chatCtx, resp)
184-
185181
_, respSpan := tracer.Start(ctx, "chat_completions_response")
186182
defer respSpan.End()
187183

@@ -193,8 +189,7 @@ func endCall(ctx context.Context, chatCtx *chatContext, resp chatResponse, w Eve
193189
return nil
194190
}
195191

196-
func updateTotalUsage(chatCtx *chatContext, resp chatResponse) {
197-
usage := resp.getUsage()
192+
func updateCtxUsage(chatCtx *chatContext, usage openai.Usage) {
198193
chatCtx.totalUsage.PromptTokens += usage.PromptTokens
199194
chatCtx.totalUsage.CompletionTokens += usage.CompletionTokens
200195
chatCtx.totalUsage.TotalTokens += usage.TotalTokens
@@ -232,16 +227,16 @@ func (c *chatResp) checkFunctionCall() (bool, error) {
232227
return isFunctionCall, nil
233228
}
234229

235-
func (c *chatResp) getToolCalls() []openai.ToolCall {
236-
return c.resp.Choices[0].Message.ToolCalls
237-
}
230+
func (c *chatResp) getToolCalls() ([]openai.ToolCall, openai.Usage) {
231+
originalToolCalls := c.resp.Choices[0].Message.ToolCalls
232+
copiedToolCalls := make([]openai.ToolCall, len(originalToolCalls))
233+
copy(copiedToolCalls, originalToolCalls)
238234

239-
func (c *chatResp) getUsage() openai.Usage {
240-
return c.resp.Usage
235+
return copiedToolCalls, c.resp.Usage
241236
}
242237

243238
func (c *chatResp) writeResponse(w EventResponseWriter, chatCtx *chatContext) error {
244-
// accumulate usage before responding, so use `=` rather than `+=`
239+
updateCtxUsage(chatCtx, c.resp.Usage)
245240
c.resp.Usage = chatCtx.totalUsage
246241

247242
w.Header().Set("Content-Type", "application/json")
@@ -253,7 +248,6 @@ type streamChatResp struct {
253248
// buffer is the response be buffered before check if it is a function call
254249
buffer []openai.ChatCompletionStreamResponse
255250
recver provider.ResponseRecver
256-
usage openai.Usage
257251
toolCallsMap map[int]openai.ToolCall
258252
}
259253

@@ -293,21 +287,22 @@ func (resp *streamChatResp) checkFunctionCall() (bool, error) {
293287
}
294288
}
295289

296-
func (r *streamChatResp) getToolCalls() []openai.ToolCall {
290+
func (r *streamChatResp) getToolCalls() ([]openai.ToolCall, openai.Usage) {
297291
for _, resp := range r.buffer {
298292
if len(resp.Choices) > 0 {
299293
r.accuamulate(resp.Choices[0].Delta.ToolCalls)
300294
}
301295
}
302296

297+
usage := openai.Usage{}
303298
for {
304299
resp, err := r.recver.Recv()
305300
if err != nil {
306301
break
307302
}
308303

309304
if resp.Usage != nil {
310-
r.usage = *resp.Usage
305+
usage = *resp.Usage
311306
}
312307

313308
if len(resp.Choices) > 0 {
@@ -319,11 +314,7 @@ func (r *streamChatResp) getToolCalls() []openai.ToolCall {
319314
for k, v := range r.toolCallsMap {
320315
toolCalls[k] = v
321316
}
322-
return toolCalls
323-
}
324-
325-
func (r *streamChatResp) getUsage() openai.Usage {
326-
return r.usage
317+
return toolCalls, usage
327318
}
328319

329320
func (s *streamChatResp) writeResponse(w EventResponseWriter, chatCtx *chatContext) error {
@@ -347,6 +338,7 @@ func (s *streamChatResp) writeResponse(w EventResponseWriter, chatCtx *chatConte
347338
}
348339
// response total usage when usage is not nil
349340
if resp.Usage != nil {
341+
updateCtxUsage(chatCtx, *resp.Usage)
350342
resp.Usage = &chatCtx.totalUsage
351343
}
352344
if err := w.WriteStreamEvent(resp); err != nil {
@@ -389,8 +381,9 @@ type invokeResp struct {
389381
var _ chatResponse = (*invokeResp)(nil)
390382

391383
func (i *invokeResp) checkFunctionCall() (bool, error) { return i.underlying.checkFunctionCall() }
392-
func (i *invokeResp) getToolCalls() []openai.ToolCall { return i.underlying.getToolCalls() }
393-
func (i *invokeResp) getUsage() openai.Usage { return i.underlying.getUsage() }
384+
func (i *invokeResp) getToolCalls() ([]openai.ToolCall, openai.Usage) {
385+
return i.underlying.getToolCalls()
386+
}
394387

395388
func newInvokeResp(resp openai.ChatCompletionResponse, includeCallStack bool) *invokeResp {
396389
return &invokeResp{

pkg/bridge/ai/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context
5454
return context.WithValue(ctx, tracerContextKey{}, tracer)
5555
}
5656

57-
// FromTransIDContext returns the transID from the request context
57+
// FromTracerContext returns the tracer from the request context
5858
func FromTracerContext(ctx context.Context) trace.Tracer {
5959
val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer)
6060
if !ok {

0 commit comments

Comments
 (0)