
Commit b466b42

💄 style: add reasoning tokens and token usage statistics for Google Gemini (lobehub#7501)
1 parent db967ff · commit b466b42
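At a glance, this commit threads an `inputStartAt` timestamp through the Google/Vertex AI runtime and extends the stream's `usage` event with reasoning-token fields. A minimal sketch of the resulting usage payload, with values taken from the new test case further down (illustrative only):

```ts
// Shape of the `usage` SSE event data after this change (values from the test below).
// `outputReasoningTokens` carries Gemini's `thoughtsTokenCount`;
// `totalOutputTokens` = `outputTextTokens` + `outputReasoningTokens`.
const usageEventData = {
  inputTextTokens: 19,
  outputReasoningTokens: 100,
  outputTextTokens: 11,
  totalInputTokens: 19,
  totalOutputTokens: 111,
  totalTokens: 131,
};
```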

7 files changed: +372 −86 lines changed

src/config/aiModels/vertexai.ts

Lines changed: 6 additions & 6 deletions
@@ -98,9 +98,9 @@ const vertexaiChatModels: AIChatModelCard[] = [
     type: 'chat',
   },
   {
-    abilities: {
-      functionCall: true,
-      vision: true
+    abilities: {
+      functionCall: true,
+      vision: true
     },
     contextWindowTokens: 1_000_000 + 8192,
     description: 'Gemini 1.5 Flash 002 是一款高效的多模态模型,支持广泛应用的扩展。',
@@ -115,9 +115,9 @@ const vertexaiChatModels: AIChatModelCard[] = [
     type: 'chat',
   },
   {
-    abilities: {
-      functionCall: true,
-      vision: true
+    abilities: {
+      functionCall: true,
+      vision: true
     },
     contextWindowTokens: 2_000_000 + 8192,
     description:

src/libs/agent-runtime/google/index.ts

Lines changed: 2 additions & 1 deletion
@@ -106,6 +106,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
 
     const contents = await this.buildGoogleMessages(payload.messages);
 
+    const inputStartAt = Date.now();
     const geminiStreamResult = await this.client
       .getGenerativeModel(
         {
@@ -161,7 +162,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
 
     // Convert the response into a friendly text-stream
     const Stream = this.isVertexAi ? VertexAIStream : GoogleGenerativeAIStream;
-    const stream = Stream(prod, options?.callback);
+    const stream = Stream(prod, { callbacks: options?.callback, inputStartAt });
 
     // Respond with the stream
     return StreamingResponse(stream, { headers: options?.headers });
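Usage note: the stream factories now take an options object instead of a bare callbacks argument. A rough sketch of the new call shape (the surrounding setup is assumed context, not part of the diff):

```ts
// Sketch of the new call signature; `geminiStreamResult` and `options` are assumed context.
const inputStartAt = Date.now(); // captured right before the Gemini request is issued

const stream = GoogleGenerativeAIStream(geminiStreamResult.stream, {
  callbacks: options?.callback, // previously passed as the second positional argument
  inputStartAt, // lets the token-speed calculator derive time-to-first-token and TPS
});
```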

src/libs/agent-runtime/utils/streams/google-ai.test.ts

Lines changed: 101 additions & 6 deletions
@@ -34,10 +34,12 @@ describe('GoogleGenerativeAIStream', () => {
     const onCompletionMock = vi.fn();
 
     const protocolStream = GoogleGenerativeAIStream(mockGoogleStream, {
-      onStart: onStartMock,
-      onText: onTextMock,
-      onToolsCalling: onToolCallMock,
-      onCompletion: onCompletionMock,
+      callbacks: {
+        onStart: onStartMock,
+        onText: onTextMock,
+        onToolsCalling: onToolCallMock,
+        onCompletion: onCompletionMock,
+      },
     });
 
     const decoder = new TextDecoder();
@@ -187,7 +189,7 @@ describe('GoogleGenerativeAIStream', () => {
       // usage
       'id: chat_1\n',
      'event: usage\n',
-      `data: {"inputImageTokens":258,"inputTextTokens":8,"totalInputTokens":266,"totalTokens":266}\n\n`,
+      `data: {"inputImageTokens":258,"inputTextTokens":8,"outputTextTokens":0,"totalInputTokens":266,"totalOutputTokens":0,"totalTokens":266}\n\n`,
     ]);
   });
 
@@ -276,7 +278,100 @@ describe('GoogleGenerativeAIStream', () => {
         // usage
         'id: chat_1',
         'event: usage',
-        `data: {"inputTextTokens":19,"totalInputTokens":19,"totalOutputTokens":11,"totalTokens":30}\n`,
+        `data: {"inputTextTokens":19,"outputTextTokens":11,"totalInputTokens":19,"totalOutputTokens":11,"totalTokens":30}\n`,
+      ].map((i) => i + '\n'),
+    );
+  });
+
+  it('should handle stop with content and thought', async () => {
+    vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('1');
+
+    const data = [
+      {
+        candidates: [
+          {
+            content: { parts: [{ text: '234' }], role: 'model' },
+            safetyRatings: [
+              { category: 'HARM_CATEGORY_HATE_SPEECH', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_HARASSMENT', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', probability: 'NEGLIGIBLE' },
+            ],
+          },
+        ],
+        text: () => '234',
+        usageMetadata: {
+          promptTokenCount: 19,
+          candidatesTokenCount: 3,
+          totalTokenCount: 122,
+          promptTokensDetails: [{ modality: 'TEXT', tokenCount: 19 }],
+          thoughtsTokenCount: 100,
+        },
+        modelVersion: 'gemini-2.0-flash-exp-image-generation',
+      },
+      {
+        text: () => '567890\n',
+        candidates: [
+          {
+            content: { parts: [{ text: '567890\n' }], role: 'model' },
+            finishReason: 'STOP',
+            safetyRatings: [
+              { category: 'HARM_CATEGORY_HATE_SPEECH', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_HARASSMENT', probability: 'NEGLIGIBLE' },
+              { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', probability: 'NEGLIGIBLE' },
+            ],
+          },
+        ],
+        usageMetadata: {
+          promptTokenCount: 19,
+          candidatesTokenCount: 11,
+          totalTokenCount: 131,
+          promptTokensDetails: [{ modality: 'TEXT', tokenCount: 19 }],
+          candidatesTokensDetails: [{ modality: 'TEXT', tokenCount: 11 }],
+          thoughtsTokenCount: 100,
+        },
+        modelVersion: 'gemini-2.0-flash-exp-image-generation',
+      },
+    ];
+
+    const mockGoogleStream = new ReadableStream({
+      start(controller) {
+        data.forEach((item) => {
+          controller.enqueue(item);
+        });
+
+        controller.close();
+      },
+    });
+
+    const protocolStream = GoogleGenerativeAIStream(mockGoogleStream);
+
+    const decoder = new TextDecoder();
+    const chunks = [];
+
+    // @ts-ignore
+    for await (const chunk of protocolStream) {
+      chunks.push(decoder.decode(chunk, { stream: true }));
+    }
+
+    expect(chunks).toEqual(
+      [
+        'id: chat_1',
+        'event: text',
+        'data: "234"\n',
+
+        'id: chat_1',
+        'event: text',
+        `data: "567890\\n"\n`,
+        // stop
+        'id: chat_1',
+        'event: stop',
+        `data: "STOP"\n`,
+        // usage
+        'id: chat_1',
+        'event: usage',
+        `data: {"inputTextTokens":19,"outputReasoningTokens":100,"outputTextTokens":11,"totalInputTokens":19,"totalOutputTokens":111,"totalTokens":131}\n`,
       ].map((i) => i + '\n'),
     );
   });

src/libs/agent-runtime/utils/streams/google-ai.ts

Lines changed: 62 additions & 38 deletions
@@ -11,6 +11,7 @@ import {
   StreamToolCallChunkData,
   createCallbacksTransformer,
   createSSEProtocolTransformer,
+  createTokenSpeedCalculator,
   generateToolCallId,
 } from './protocol';
 
@@ -19,31 +20,62 @@ const transformGoogleGenerativeAIStream = (
   context: StreamContext,
 ): StreamProtocolChunk | StreamProtocolChunk[] => {
   // maybe need another structure to add support for multiple choices
+  const candidate = chunk.candidates?.[0];
+  const usage = chunk.usageMetadata;
+  const usageChunks: StreamProtocolChunk[] = [];
+  if (candidate?.finishReason && usage) {
+    const outputReasoningTokens = (usage as any).thoughtsTokenCount || undefined;
+    const totalOutputTokens = (usage.candidatesTokenCount ?? 0) + (outputReasoningTokens ?? 0);
+
+    usageChunks.push(
+      { data: candidate.finishReason, id: context?.id, type: 'stop' },
+      {
+        data: {
+          // TODO: Google SDK 0.24.0 don't have promptTokensDetails types
+          inputImageTokens: (usage as any).promptTokensDetails?.find(
+            (i: any) => i.modality === 'IMAGE',
+          )?.tokenCount,
+          inputTextTokens: (usage as any).promptTokensDetails?.find(
+            (i: any) => i.modality === 'TEXT',
+          )?.tokenCount,
+          outputReasoningTokens,
+          outputTextTokens: totalOutputTokens - (outputReasoningTokens ?? 0),
+          totalInputTokens: usage.promptTokenCount,
+          totalOutputTokens,
+          totalTokens: usage.totalTokenCount,
+        } as ModelTokensUsage,
+        id: context?.id,
+        type: 'usage',
+      },
+    );
+  }
+
   const functionCalls = chunk.functionCalls?.();
 
   if (functionCalls) {
-    return {
-      data: functionCalls.map(
-        (value, index): StreamToolCallChunkData => ({
-          function: {
-            arguments: JSON.stringify(value.args),
-            name: value.name,
-          },
-          id: generateToolCallId(index, value.name),
-          index: index,
-          type: 'function',
-        }),
-      ),
-      id: context.id,
-      type: 'tool_calls',
-    };
+    return [
+      {
+        data: functionCalls.map(
+          (value, index): StreamToolCallChunkData => ({
+            function: {
+              arguments: JSON.stringify(value.args),
+              name: value.name,
+            },
+            id: generateToolCallId(index, value.name),
+            index: index,
+            type: 'function',
+          }),
+        ),
+        id: context.id,
+        type: 'tool_calls',
+      },
+      ...usageChunks,
+    ];
   }
 
   const text = chunk.text?.();
 
-  if (chunk.candidates) {
-    const candidate = chunk.candidates[0];
-
+  if (candidate) {
     // return the grounding
     if (candidate.groundingMetadata) {
       const { webSearchQueries, groundingChunks } = candidate.groundingMetadata;
@@ -64,31 +96,15 @@ const transformGoogleGenerativeAIStream = (
         id: context.id,
         type: 'grounding',
       },
+      ...usageChunks,
     ];
   }
 
   if (candidate.finishReason) {
     if (chunk.usageMetadata) {
-      const usage = chunk.usageMetadata;
       return [
        !!text ? { data: text, id: context?.id, type: 'text' } : undefined,
-        { data: candidate.finishReason, id: context?.id, type: 'stop' },
-        {
-          data: {
-            // TODO: Google SDK 0.24.0 don't have promptTokensDetails types
-            inputImageTokens: (usage as any).promptTokensDetails?.find(
-              (i: any) => i.modality === 'IMAGE',
-            )?.tokenCount,
-            inputTextTokens: (usage as any).promptTokensDetails?.find(
-              (i: any) => i.modality === 'TEXT',
-            )?.tokenCount,
-            totalInputTokens: usage.promptTokenCount,
-            totalOutputTokens: usage.candidatesTokenCount,
-            totalTokens: usage.totalTokenCount,
-          } as ModelTokensUsage,
-          id: context?.id,
-          type: 'usage',
-        },
+        ...usageChunks,
       ].filter(Boolean) as StreamProtocolChunk[];
     }
     return { data: candidate.finishReason, id: context?.id, type: 'stop' };
@@ -117,13 +133,21 @@ const transformGoogleGenerativeAIStream = (
   };
 };
 
+export interface GoogleAIStreamOptions {
+  callbacks?: ChatStreamCallbacks;
+  inputStartAt?: number;
+}
+
 export const GoogleGenerativeAIStream = (
   rawStream: ReadableStream<EnhancedGenerateContentResponse>,
-  callbacks?: ChatStreamCallbacks,
+  { callbacks, inputStartAt }: GoogleAIStreamOptions = {},
 ) => {
   const streamStack: StreamContext = { id: 'chat_' + nanoid() };
 
   return rawStream
-    .pipeThrough(createSSEProtocolTransformer(transformGoogleGenerativeAIStream, streamStack))
+    .pipeThrough(
+      createTokenSpeedCalculator(transformGoogleGenerativeAIStream, { inputStartAt, streamStack }),
+    )
+    .pipeThrough(createSSEProtocolTransformer((c) => c, streamStack))
     .pipeThrough(createCallbacksTransformer(callbacks));
 };
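The heart of the change is the mapping from Gemini's `usageMetadata` to the protocol's usage fields: `thoughtsTokenCount` becomes `outputReasoningTokens`, total output adds reasoning tokens to `candidatesTokenCount`, and `promptTokensDetails` is split by modality. A standalone sketch of that arithmetic, with simplified types; the helper name is illustrative, not from the repo:

```ts
// Simplified sketch of the usage mapping introduced above; field names follow the diff.
interface GeminiUsageMetadata {
  promptTokenCount?: number;
  candidatesTokenCount?: number;
  totalTokenCount?: number;
  thoughtsTokenCount?: number; // reasoning tokens, absent on non-thinking models
  promptTokensDetails?: { modality: string; tokenCount: number }[];
}

const toModelTokensUsage = (usage: GeminiUsageMetadata) => {
  const outputReasoningTokens = usage.thoughtsTokenCount || undefined;
  const totalOutputTokens = (usage.candidatesTokenCount ?? 0) + (outputReasoningTokens ?? 0);

  return {
    inputImageTokens: usage.promptTokensDetails?.find((i) => i.modality === 'IMAGE')?.tokenCount,
    inputTextTokens: usage.promptTokensDetails?.find((i) => i.modality === 'TEXT')?.tokenCount,
    outputReasoningTokens,
    outputTextTokens: totalOutputTokens - (outputReasoningTokens ?? 0),
    totalInputTokens: usage.promptTokenCount,
    totalOutputTokens,
    totalTokens: usage.totalTokenCount,
  };
};
```

With the second chunk from the test above (candidatesTokenCount 11, thoughtsTokenCount 100), this yields outputTextTokens 11, outputReasoningTokens 100, totalOutputTokens 111, matching the expected usage event.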

src/libs/agent-runtime/utils/streams/protocol.ts

Lines changed: 24 additions & 4 deletions
@@ -298,17 +298,37 @@ export const TOKEN_SPEED_CHUNK_ID = 'output_speed';
  */
 export const createTokenSpeedCalculator = (
   transformer: (chunk: any, stack: StreamContext) => StreamProtocolChunk | StreamProtocolChunk[],
-  { streamStack, inputStartAt }: { inputStartAt?: number; streamStack?: StreamContext } = {},
+  { inputStartAt, streamStack }: { inputStartAt?: number; streamStack?: StreamContext } = {},
 ) => {
   let outputStartAt: number | undefined;
+  let outputThinking: boolean | undefined;
 
   const process = (chunk: StreamProtocolChunk) => {
     let result = [chunk];
-    // if the chunk is the first text chunk, set as output start
-    if (!outputStartAt && chunk.type === 'text') outputStartAt = Date.now();
+    // if the chunk is the first text or reasoning chunk, set as output start
+    if (!outputStartAt && (chunk.type === 'text' || chunk.type === 'reasoning')) {
+      outputStartAt = Date.now();
+    }
+
+    /**
+     * 部分 provider 在正式输出 reasoning 前,可能会先输出 content 为空字符串的 chunk,
+     * 其中 reasoning 可能为 null,会导致判断是否输出思考内容错误,所以过滤掉 null 或者空字符串。
+     * 也可能是某些特殊 token,所以不修改 outputStartAt 的逻辑。
+     */
+    if (
+      outputThinking === undefined &&
+      (chunk.type === 'text' || chunk.type === 'reasoning') &&
+      typeof chunk.data === 'string' &&
+      chunk.data.length > 0
+    ) {
+      outputThinking = chunk.type === 'reasoning';
+    }
     // if the chunk is the stop chunk, set as output finish
     if (inputStartAt && outputStartAt && chunk.type === 'usage') {
-      const outputTokens = chunk.data?.totalOutputTokens || chunk.data?.outputTextTokens;
+      const totalOutputTokens = chunk.data?.totalOutputTokens || chunk.data?.outputTextTokens;
+      const reasoningTokens = chunk.data?.outputReasoningTokens || 0;
+      const outputTokens =
+        (outputThinking ?? false) ? totalOutputTokens : totalOutputTokens - reasoningTokens;
       result.push({
         data: {
           tps: (outputTokens / (Date.now() - outputStartAt)) * 1000,