@@ -19,10 +19,8 @@ import (
19
19
type chatResponse interface {
20
20
// checkFunctionCall checks if the response is a function call response
21
21
checkFunctionCall () (bool , error )
22
- // getToolCalls returns the tool calls in the response
23
- getToolCalls () []openai.ToolCall
24
- // getUsage returns the usage in the response
25
- getUsage () openai.Usage
22
+ // getToolCalls returns the tool calls in the response, together with the usage reported by that response
23
+ getToolCalls () ([]openai.ToolCall , openai.Usage )
26
24
// writeResponse defines how to write the response to the response writer
27
25
writeResponse (w EventResponseWriter , chatCtx * chatContext ) error
28
26
}
@@ -66,7 +64,7 @@ func createChatCompletions(ctx context.Context, p provider.LLMProvider, req open
66
64
type chatContext struct {
67
65
// callTimes is the number of times for calling chat completions
68
66
callTimes int
69
- // totalUsage is the total usage of all chat completions
67
+ // totalUsage is the total usage of all chat completions (including tool calls and the final response)
70
68
totalUsage openai.Usage
71
69
req openai.ChatCompletionRequest
72
70
}
@@ -131,7 +129,7 @@ func multiTurnFunctionCalling(
131
129
}
132
130
if isFunctionCall {
133
131
callCtx , callSpan := tracer .Start (ctx , fmt .Sprintf ("call_functions(#%d)" , chatCtx .callTimes + 1 ))
134
- toolCalls := resp .getToolCalls ()
132
+ toolCalls , usage := resp .getToolCalls ()
135
133
136
134
// add toolCallID if toolCallID is empty
137
135
for i , call := range toolCalls {
@@ -168,7 +166,7 @@ func multiTurnFunctionCalling(
168
166
})
169
167
}
170
168
171
- updateTotalUsage (chatCtx , resp )
169
+ updateCtxUsage (chatCtx , usage )
172
170
173
171
chatCtx .callTimes ++
174
172
continue
@@ -180,8 +178,6 @@ func multiTurnFunctionCalling(
180
178
}
181
179
182
180
func endCall (ctx context.Context , chatCtx * chatContext , resp chatResponse , w EventResponseWriter , tracer trace.Tracer ) error {
183
- updateTotalUsage (chatCtx , resp )
184
-
185
181
_ , respSpan := tracer .Start (ctx , "chat_completions_response" )
186
182
defer respSpan .End ()
187
183
@@ -193,8 +189,7 @@ func endCall(ctx context.Context, chatCtx *chatContext, resp chatResponse, w Eve
193
189
return nil
194
190
}
195
191
196
- func updateTotalUsage (chatCtx * chatContext , resp chatResponse ) {
197
- usage := resp .getUsage ()
192
+ func updateCtxUsage (chatCtx * chatContext , usage openai.Usage ) {
198
193
chatCtx .totalUsage .PromptTokens += usage .PromptTokens
199
194
chatCtx .totalUsage .CompletionTokens += usage .CompletionTokens
200
195
chatCtx .totalUsage .TotalTokens += usage .TotalTokens
@@ -232,16 +227,16 @@ func (c *chatResp) checkFunctionCall() (bool, error) {
232
227
return isFunctionCall , nil
233
228
}
234
229
235
- func (c * chatResp ) getToolCalls () []openai.ToolCall {
236
- return c .resp .Choices [0 ].Message .ToolCalls
237
- }
230
+ func (c * chatResp ) getToolCalls () ([]openai.ToolCall , openai.Usage ) {
231
+ originalToolCalls := c .resp .Choices [0 ].Message .ToolCalls
232
+ copiedToolCalls := make ([]openai.ToolCall , len (originalToolCalls ))
233
+ copy (copiedToolCalls , originalToolCalls )
238
234
239
- func (c * chatResp ) getUsage () openai.Usage {
240
- return c .resp .Usage
235
+ return copiedToolCalls , c .resp .Usage
241
236
}
242
237
243
238
func (c * chatResp ) writeResponse (w EventResponseWriter , chatCtx * chatContext ) error {
244
- // accumulate usage before responding, so use `=` rather than `+=`
239
+ updateCtxUsage ( chatCtx , c . resp . Usage )
245
240
c .resp .Usage = chatCtx .totalUsage
246
241
247
242
w .Header ().Set ("Content-Type" , "application/json" )
@@ -253,7 +248,6 @@ type streamChatResp struct {
253
248
// buffer is the response be buffered before check if it is a function call
254
249
buffer []openai.ChatCompletionStreamResponse
255
250
recver provider.ResponseRecver
256
- usage openai.Usage
257
251
toolCallsMap map [int ]openai.ToolCall
258
252
}
259
253
@@ -293,21 +287,22 @@ func (resp *streamChatResp) checkFunctionCall() (bool, error) {
293
287
}
294
288
}
295
289
296
- func (r * streamChatResp ) getToolCalls () []openai.ToolCall {
290
+ func (r * streamChatResp ) getToolCalls () ( []openai.ToolCall , openai. Usage ) {
297
291
for _ , resp := range r .buffer {
298
292
if len (resp .Choices ) > 0 {
299
293
r .accuamulate (resp .Choices [0 ].Delta .ToolCalls )
300
294
}
301
295
}
302
296
297
+ usage := openai.Usage {}
303
298
for {
304
299
resp , err := r .recver .Recv ()
305
300
if err != nil {
306
301
break
307
302
}
308
303
309
304
if resp .Usage != nil {
310
- r . usage = * resp .Usage
305
+ usage = * resp .Usage
311
306
}
312
307
313
308
if len (resp .Choices ) > 0 {
@@ -319,11 +314,7 @@ func (r *streamChatResp) getToolCalls() []openai.ToolCall {
319
314
for k , v := range r .toolCallsMap {
320
315
toolCalls [k ] = v
321
316
}
322
- return toolCalls
323
- }
324
-
325
- func (r * streamChatResp ) getUsage () openai.Usage {
326
- return r .usage
317
+ return toolCalls , usage
327
318
}
328
319
329
320
func (s * streamChatResp ) writeResponse (w EventResponseWriter , chatCtx * chatContext ) error {
@@ -347,6 +338,7 @@ func (s *streamChatResp) writeResponse(w EventResponseWriter, chatCtx *chatConte
347
338
}
348
339
// response total usage when usage is not nil
349
340
if resp .Usage != nil {
341
+ updateCtxUsage (chatCtx , * resp .Usage )
350
342
resp .Usage = & chatCtx .totalUsage
351
343
}
352
344
if err := w .WriteStreamEvent (resp ); err != nil {
@@ -389,8 +381,9 @@ type invokeResp struct {
389
381
var _ chatResponse = (* invokeResp )(nil )
390
382
391
383
func (i * invokeResp ) checkFunctionCall () (bool , error ) { return i .underlying .checkFunctionCall () }
392
- func (i * invokeResp ) getToolCalls () []openai.ToolCall { return i .underlying .getToolCalls () }
393
- func (i * invokeResp ) getUsage () openai.Usage { return i .underlying .getUsage () }
384
+ func (i * invokeResp ) getToolCalls () ([]openai.ToolCall , openai.Usage ) {
385
+ return i .underlying .getToolCalls ()
386
+ }
394
387
395
388
func newInvokeResp (resp openai.ChatCompletionResponse , includeCallStack bool ) * invokeResp {
396
389
return & invokeResp {
0 commit comments