Commit 4ee33ee

refactor: enhance GoogleClient with token usage tracking and context handling improvements
1 parent 918bb89 commit 4ee33ee

File tree

2 files changed: +147 −19 lines


api/app/clients/GoogleClient.js

Lines changed: 141 additions & 19 deletions
@@ -1,4 +1,5 @@
 const { google } = require('googleapis');
+const { concat } = require('@langchain/core/utils/stream');
 const { ChatVertexAI } = require('@langchain/google-vertexai');
 const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
 const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
@@ -10,11 +11,13 @@ const {
   endpointSettings,
   EModelEndpoint,
   VisionModes,
+  ErrorTypes,
   Constants,
   AuthKeys,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images');
 const Tokenizer = require('~/server/services/Tokenizer');
+const { spendTokens } = require('~/models/spendTokens');
 const { getModelMaxTokens } = require('~/utils');
 const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
@@ -59,6 +62,15 @@ class GoogleClient extends BaseClient {
 
     this.authHeader = options.authHeader;
 
+    /** @type {UsageMetadata | undefined} */
+    this.usage;
+    /** The key for the usage object's input tokens
+     * @type {string} */
+    this.inputTokensKey = 'input_tokens';
+    /** The key for the usage object's output tokens
+     * @type {string} */
+    this.outputTokensKey = 'output_tokens';
+
     if (options.skipSetOptions) {
       return;
     }
@@ -170,6 +182,11 @@ class GoogleClient extends BaseClient {
       this.completionsUrl = this.constructUrl();
     }
 
+    let promptPrefix = (this.options.promptPrefix ?? '').trim();
+    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
+      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
+    }
+    this.options.promptPrefix = promptPrefix;
     this.initializeClient();
     return this;
   }
@@ -322,19 +339,56 @@ class GoogleClient extends BaseClient {
    * @param {TMessage[]} [messages=[]]
    * @param {string} [parentMessageId]
    */
-  async buildMessages(messages = [], parentMessageId) {
+  async buildMessages(_messages = [], parentMessageId) {
     if (!this.isGenerativeModel && !this.project_id) {
       throw new Error(
         '[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
       );
     }
 
+    if (this.options.promptPrefix) {
+      const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix);
+
+      this.maxContextTokens = this.maxContextTokens - instructionsTokenCount;
+      if (this.maxContextTokens < 0) {
+        const info = `${instructionsTokenCount} / ${this.maxContextTokens}`;
+        const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
+        logger.warn(`Instructions token count exceeds max context (${info}).`);
+        throw new Error(errorMessage);
+      }
+    }
+
+    for (let i = 0; i < _messages.length; i++) {
+      const message = _messages[i];
+      if (!message.tokenCount) {
+        _messages[i].tokenCount = this.getTokenCountForMessage({
+          role: message.isCreatedByUser ? 'user' : 'assistant',
+          content: message.content ?? message.text,
+        });
+      }
+    }
+
+    const {
+      payload: messages,
+      tokenCountMap,
+      promptTokens,
+    } = await this.handleContextStrategy({
+      orderedMessages: _messages,
+      formattedMessages: _messages,
+    });
+
     if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) {
-      return await this.buildGenerativeMessages(messages);
+      const result = await this.buildGenerativeMessages(messages);
+      result.tokenCountMap = tokenCountMap;
+      result.promptTokens = promptTokens;
+      return result;
     }
 
     if (this.options.attachments && this.isGenerativeModel) {
-      return this.buildVisionMessages(messages, parentMessageId);
+      const result = this.buildVisionMessages(messages, parentMessageId);
+      result.tokenCountMap = tokenCountMap;
+      result.promptTokens = promptTokens;
+      return result;
     }
 
     if (this.isTextModel) {
@@ -352,17 +406,12 @@ class GoogleClient extends BaseClient {
       ],
     };
 
-    let promptPrefix = (this.options.promptPrefix ?? '').trim();
-    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-    }
-
-    if (promptPrefix) {
-      payload.instances[0].context = promptPrefix;
+    if (this.options.promptPrefix) {
+      payload.instances[0].context = this.options.promptPrefix;
     }
 
     logger.debug('[GoogleClient] buildMessages', payload);
-    return { prompt: payload };
+    return { prompt: payload, tokenCountMap, promptTokens };
   }
 
   async buildMessagesPrompt(messages, parentMessageId) {
@@ -405,9 +454,6 @@ class GoogleClient extends BaseClient {
     }
 
     let promptPrefix = (this.options.promptPrefix ?? '').trim();
-    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-    }
 
     if (identityPrefix) {
       promptPrefix = `${identityPrefix}${promptPrefix}`;
@@ -586,11 +632,7 @@ class GoogleClient extends BaseClient {
         generationConfig: googleGenConfigSchema.parse(this.modelOptions),
       };
 
-      let promptPrefix = (this.options.promptPrefix ?? '').trim();
-      if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-        promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-      }
-
+      const promptPrefix = (this.options.promptPrefix ?? '').trim();
       if (promptPrefix.length) {
         requestOptions.systemInstruction = {
           parts: [
@@ -602,20 +644,35 @@ class GoogleClient extends BaseClient {
       }
 
       const delay = modelName.includes('flash') ? 8 : 15;
+      /** @type {GenAIUsageMetadata} */
+      let usageMetadata;
      const result = await client.generateContentStream(requestOptions);
      for await (const chunk of result.stream) {
+        usageMetadata = !usageMetadata
+          ? chunk?.usageMetadata
+          : Object.assign(usageMetadata, chunk?.usageMetadata);
        const chunkText = chunk.text();
        await this.generateTextStream(chunkText, onProgress, {
          delay,
        });
        reply += chunkText;
        await sleep(streamRate);
      }
+
+      if (usageMetadata) {
+        this.usage = {
+          input_tokens: usageMetadata.promptTokenCount,
+          output_tokens: usageMetadata.candidatesTokenCount,
+        };
+      }
      return reply;
    }
 
+    /** @type {import('@langchain/core/messages').AIMessageChunk['usage_metadata']} */
+    let usageMetadata;
    const stream = await this.client.stream(messages, {
      signal: abortController.signal,
+      streamUsage: true,
      safetySettings,
    });
 
@@ -631,16 +688,80 @@ class GoogleClient extends BaseClient {
    }
 
    for await (const chunk of stream) {
+      usageMetadata = !usageMetadata
+        ? chunk?.usage_metadata
+        : concat(usageMetadata, chunk?.usage_metadata);
      const chunkText = chunk?.content ?? chunk;
      await this.generateTextStream(chunkText, onProgress, {
        delay,
      });
      reply += chunkText;
    }
 
+    if (usageMetadata) {
+      this.usage = usageMetadata;
+    }
    return reply;
  }
 
+  /**
+   * Get stream usage as returned by this client's API response.
+   * @returns {UsageMetadata} The stream usage object.
+   */
+  getStreamUsage() {
+    return this.usage;
+  }
+
+  /**
+   * Calculates the correct token count for the current user message based on the token count map and API usage.
+   * Edge case: If the calculation results in a negative value, it returns the original estimate.
+   * If revisiting a conversation with a chat history entirely composed of token estimates,
+   * the cumulative token count going forward should become more accurate as the conversation progresses.
+   * @param {Object} params - The parameters for the calculation.
+   * @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
+   * @param {string} params.currentMessageId - The ID of the current message to calculate.
+   * @param {UsageMetadata} params.usage - The usage object returned by the API.
+   * @returns {number} The correct token count for the current user message.
+   */
+  calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
+    const originalEstimate = tokenCountMap[currentMessageId] || 0;
+
+    if (!usage || typeof usage.input_tokens !== 'number') {
+      return originalEstimate;
+    }
+
+    tokenCountMap[currentMessageId] = 0;
+    const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
+      const numCount = Number(count);
+      return sum + (isNaN(numCount) ? 0 : numCount);
+    }, 0);
+    const totalInputTokens = usage.input_tokens ?? 0;
+    const currentMessageTokens = totalInputTokens - totalTokensFromMap;
+    return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
+  }
+
+  /**
+   * @param {object} params
+   * @param {number} params.promptTokens
+   * @param {number} params.completionTokens
+   * @param {UsageMetadata} [params.usage]
+   * @param {string} [params.model]
+   * @param {string} [params.context='message']
+   * @returns {Promise<void>}
+   */
+  async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
+    await spendTokens(
+      {
+        context,
+        user: this.user,
+        conversationId: this.conversationId,
+        model: model ?? this.modelOptions.model,
+        endpointTokenConfig: this.options.endpointTokenConfig,
+      },
+      { promptTokens, completionTokens },
+    );
+  }
+
  /**
   * Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
   */
@@ -726,6 +847,7 @@ class GoogleClient extends BaseClient {
      endpointType: null,
      artifacts: this.options.artifacts,
      promptPrefix: this.options.promptPrefix,
+      maxContextTokens: this.options.maxContextTokens,
      modelLabel: this.options.modelLabel,
      iconURL: this.options.iconURL,
      greeting: this.options.greeting,
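
The GoogleClient.js changes above reconcile estimated prompt token counts with the usage metadata reported by the API. A minimal sketch of how the new calculateCurrentTokenCount method behaves, assuming an existing GoogleClient instance named client; the message IDs and token values below are hypothetical and not part of the commit:

// Hypothetical conversation state: two prior messages estimated at 40 and 60 tokens,
// and the current user message 'msg-3' estimated at 25 tokens.
const tokenCountMap = { 'msg-1': 40, 'msg-2': 60, 'msg-3': 25 };
// Usage as reported by the API, surfaced through getStreamUsage() after streaming.
const usage = { input_tokens: 130, output_tokens: 42 };

// The method zeroes the current message's estimate, sums the rest of the map (100),
// and attributes the remaining reported input tokens to the current message: 130 - 100 = 30.
// A negative result would fall back to the original estimate (25).
const corrected = client.calculateCurrentTokenCount({
  tokenCountMap,
  currentMessageId: 'msg-3',
  usage,
});
// corrected === 30; recordTokenUsage({ promptTokens, completionTokens }) then persists
// the final counts through spendTokens.
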

api/typedefs.js

Lines changed: 6 additions & 0 deletions
@@ -161,6 +161,12 @@
  * @memberof typedefs
  */
 
+/**
+ * @exports GenAIUsageMetadata
+ * @typedef {import('@google/generative-ai').UsageMetadata} GenAIUsageMetadata
+ * @memberof typedefs
+ */
+
 /**
  * @exports AssistantStreamEvent
  * @typedef {import('openai').default.Beta.AssistantStreamEvent} AssistantStreamEvent
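
The new GenAIUsageMetadata typedef aliases the UsageMetadata type from @google/generative-ai, which is the shape the GenAI stream chunks carry and which GoogleClient.js maps onto its own input_tokens/output_tokens keys. An illustrative value (the numbers are hypothetical):

/** @type {GenAIUsageMetadata} */
const usageMetadata = {
  promptTokenCount: 130, // mapped to this.usage.input_tokens
  candidatesTokenCount: 42, // mapped to this.usage.output_tokens
  totalTokenCount: 172,
};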
