@@ -1,4 +1,5 @@
 const { google } = require('googleapis');
+const { concat } = require('@langchain/core/utils/stream');
 const { ChatVertexAI } = require('@langchain/google-vertexai');
 const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
 const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
@@ -10,11 +11,13 @@ const {
   endpointSettings,
   EModelEndpoint,
   VisionModes,
+  ErrorTypes,
   Constants,
   AuthKeys,
 } = require('librechat-data-provider');
 const { encodeAndFormat } = require('~/server/services/Files/images');
 const Tokenizer = require('~/server/services/Tokenizer');
+const { spendTokens } = require('~/models/spendTokens');
 const { getModelMaxTokens } = require('~/utils');
 const { sleep } = require('~/server/utils');
 const { logger } = require('~/config');
@@ -59,6 +62,15 @@ class GoogleClient extends BaseClient {
 
     this.authHeader = options.authHeader;
 
+    /** @type {UsageMetadata | undefined} */
+    this.usage;
+    /** The key for the usage object's input tokens
+     * @type {string} */
+    this.inputTokensKey = 'input_tokens';
+    /** The key for the usage object's output tokens
+     * @type {string} */
+    this.outputTokensKey = 'output_tokens';
+
     if (options.skipSetOptions) {
       return;
     }
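
Note (a sketch, not part of the diff): the two key names let endpoint-agnostic code read this client's usage object without hard-coding Google's field names. The getUsageValue helper below is hypothetical, written only to illustrate that lookup.

    // Hypothetical consumer of the new fields; not in the PR.
    function getUsageValue(client, which) {
      const key = which === 'input' ? client.inputTokensKey : client.outputTokensKey;
      const usage = client.usage ?? {};
      return Number(usage[key]) || 0;
    }
    // After a streamed completion:
    // getUsageValue(client, 'input')  -> prompt tokens
    // getUsageValue(client, 'output') -> completion tokens
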
@@ -170,6 +182,11 @@ class GoogleClient extends BaseClient {
       this.completionsUrl = this.constructUrl();
     }
 
+    let promptPrefix = (this.options.promptPrefix ?? '').trim();
+    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
+      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
+    }
+    this.options.promptPrefix = promptPrefix;
     this.initializeClient();
     return this;
   }
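
With this consolidation, the artifacts prompt is appended to the trimmed prefix once in setOptions, so the later code paths can rely on this.options.promptPrefix alone. A standalone sketch with illustrative strings:

    const options = { promptPrefix: '  You are helpful.  ', artifactsPrompt: 'Wrap artifacts in ::: blocks.' };
    let promptPrefix = (options.promptPrefix ?? '').trim();
    if (typeof options.artifactsPrompt === 'string' && options.artifactsPrompt) {
      promptPrefix = `${promptPrefix ?? ''}\n${options.artifactsPrompt}`.trim();
    }
    console.log(promptPrefix); // 'You are helpful.\nWrap artifacts in ::: blocks.'
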
@@ -322,19 +339,56 @@ class GoogleClient extends BaseClient {
   * @param {TMessage[]} [messages=[]]
   * @param {string} [parentMessageId]
   */
-  async buildMessages(messages = [], parentMessageId) {
+  async buildMessages(_messages = [], parentMessageId) {
     if (!this.isGenerativeModel && !this.project_id) {
       throw new Error(
         '[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
       );
     }
 
+    if (this.options.promptPrefix) {
+      const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix);
+
+      this.maxContextTokens = this.maxContextTokens - instructionsTokenCount;
+      if (this.maxContextTokens < 0) {
+        const info = `${instructionsTokenCount} / ${this.maxContextTokens}`;
+        const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
+        logger.warn(`Instructions token count exceeds max context (${info}).`);
+        throw new Error(errorMessage);
+      }
+    }
+
+    for (let i = 0; i < _messages.length; i++) {
+      const message = _messages[i];
+      if (!message.tokenCount) {
+        _messages[i].tokenCount = this.getTokenCountForMessage({
+          role: message.isCreatedByUser ? 'user' : 'assistant',
+          content: message.content ?? message.text,
+        });
+      }
+    }
+
+    const {
+      payload: messages,
+      tokenCountMap,
+      promptTokens,
+    } = await this.handleContextStrategy({
+      orderedMessages: _messages,
+      formattedMessages: _messages,
+    });
+
     if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) {
-      return await this.buildGenerativeMessages(messages);
+      const result = await this.buildGenerativeMessages(messages);
+      result.tokenCountMap = tokenCountMap;
+      result.promptTokens = promptTokens;
+      return result;
     }
 
     if (this.options.attachments && this.isGenerativeModel) {
-      return this.buildVisionMessages(messages, parentMessageId);
+      const result = this.buildVisionMessages(messages, parentMessageId);
+      result.tokenCountMap = tokenCountMap;
+      result.promptTokens = promptTokens;
+      return result;
     }
 
     if (this.isTextModel) {
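
The new guard reserves the instruction tokens out of the context budget before handleContextStrategy runs, and only errors when the instructions alone overflow the window. A standalone sketch of the budget math with made-up numbers (the literal 'INPUT_LENGTH' stands in for ErrorTypes.INPUT_LENGTH):

    let maxContextTokens = 8192;
    const instructionsTokenCount = 300;
    maxContextTokens -= instructionsTokenCount; // 7892 tokens left for conversation messages
    if (maxContextTokens < 0) {
      // Mirrors the INPUT_LENGTH error thrown in the diff above.
      throw new Error(JSON.stringify({ type: 'INPUT_LENGTH', info: `${instructionsTokenCount} / ${maxContextTokens}` }));
    }
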
@@ -352,17 +406,12 @@ class GoogleClient extends BaseClient {
       ],
     };
 
-    let promptPrefix = (this.options.promptPrefix ?? '').trim();
-    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-    }
-
-    if (promptPrefix) {
-      payload.instances[0].context = promptPrefix;
+    if (this.options.promptPrefix) {
+      payload.instances[0].context = this.options.promptPrefix;
     }
 
     logger.debug('[GoogleClient] buildMessages', payload);
-    return { prompt: payload };
+    return { prompt: payload, tokenCountMap, promptTokens };
   }
 
   async buildMessagesPrompt(messages, parentMessageId) {
@@ -405,9 +454,6 @@ class GoogleClient extends BaseClient {
     }
 
     let promptPrefix = (this.options.promptPrefix ?? '').trim();
-    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-    }
 
     if (identityPrefix) {
       promptPrefix = `${identityPrefix}${promptPrefix}`;
@@ -586,11 +632,7 @@ class GoogleClient extends BaseClient {
       generationConfig: googleGenConfigSchema.parse(this.modelOptions),
     };
 
-    let promptPrefix = (this.options.promptPrefix ?? '').trim();
-    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
-      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
-    }
-
+    const promptPrefix = (this.options.promptPrefix ?? '').trim();
     if (promptPrefix.length) {
       requestOptions.systemInstruction = {
         parts: [
@@ -602,20 +644,35 @@ class GoogleClient extends BaseClient {
       }
 
       const delay = modelName.includes('flash') ? 8 : 15;
+      /** @type {GenAIUsageMetadata} */
+      let usageMetadata;
       const result = await client.generateContentStream(requestOptions);
       for await (const chunk of result.stream) {
+        usageMetadata = !usageMetadata
+          ? chunk?.usageMetadata
+          : Object.assign(usageMetadata, chunk?.usageMetadata);
         const chunkText = chunk.text();
         await this.generateTextStream(chunkText, onProgress, {
           delay,
         });
         reply += chunkText;
         await sleep(streamRate);
       }
+
+      if (usageMetadata) {
+        this.usage = {
+          input_tokens: usageMetadata.promptTokenCount,
+          output_tokens: usageMetadata.candidatesTokenCount,
+        };
+      }
       return reply;
     }
 
+    /** @type {import('@langchain/core/messages').AIMessageChunk['usage_metadata']} */
+    let usageMetadata;
     const stream = await this.client.stream(messages, {
       signal: abortController.signal,
+      streamUsage: true,
       safetySettings,
     });
 
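
One detail of the GenAI branch above: the Object.assign merge simply keeps the most recent value for each field, so the counts reported by the last chunk that includes usageMetadata win. A standalone sketch with illustrative values:

    let usageMetadata;
    const chunks = [
      { usageMetadata: { promptTokenCount: 42, candidatesTokenCount: 5 } },
      { usageMetadata: { promptTokenCount: 42, candidatesTokenCount: 57 } },
    ];
    for (const chunk of chunks) {
      usageMetadata = !usageMetadata
        ? chunk?.usageMetadata
        : Object.assign(usageMetadata, chunk?.usageMetadata);
    }
    console.log(usageMetadata); // { promptTokenCount: 42, candidatesTokenCount: 57 }
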
@@ -631,16 +688,80 @@ class GoogleClient extends BaseClient {
     }
 
     for await (const chunk of stream) {
+      usageMetadata = !usageMetadata
+        ? chunk?.usage_metadata
+        : concat(usageMetadata, chunk?.usage_metadata);
       const chunkText = chunk?.content ?? chunk;
       await this.generateTextStream(chunkText, onProgress, {
         delay,
       });
       reply += chunkText;
     }
 
+    if (usageMetadata) {
+      this.usage = usageMetadata;
+    }
     return reply;
   }
 
+  /**
+   * Get stream usage as returned by this client's API response.
+   * @returns {UsageMetadata} The stream usage object.
+   */
+  getStreamUsage() {
+    return this.usage;
+  }
+
+  /**
+   * Calculates the correct token count for the current user message based on the token count map and API usage.
+   * Edge case: If the calculation results in a negative value, it returns the original estimate.
+   * If revisiting a conversation with a chat history entirely composed of token estimates,
+   * the cumulative token count going forward should become more accurate as the conversation progresses.
+   * @param {Object} params - The parameters for the calculation.
+   * @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
+   * @param {string} params.currentMessageId - The ID of the current message to calculate.
+   * @param {UsageMetadata} params.usage - The usage object returned by the API.
+   * @returns {number} The correct token count for the current user message.
+   */
+  calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
+    const originalEstimate = tokenCountMap[currentMessageId] || 0;
+
+    if (!usage || typeof usage.input_tokens !== 'number') {
+      return originalEstimate;
+    }
+
+    tokenCountMap[currentMessageId] = 0;
+    const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
+      const numCount = Number(count);
+      return sum + (isNaN(numCount) ? 0 : numCount);
+    }, 0);
+    const totalInputTokens = usage.input_tokens ?? 0;
+    const currentMessageTokens = totalInputTokens - totalTokensFromMap;
+    return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
+  }
+
+  /**
+   * @param {object} params
+   * @param {number} params.promptTokens
+   * @param {number} params.completionTokens
+   * @param {UsageMetadata} [params.usage]
+   * @param {string} [params.model]
+   * @param {string} [params.context='message']
+   * @returns {Promise<void>}
+   */
+  async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
+    await spendTokens(
+      {
+        context,
+        user: this.user,
+        conversationId: this.conversationId,
+        model: model ?? this.modelOptions.model,
+        endpointTokenConfig: this.options.endpointTokenConfig,
+      },
+      { promptTokens, completionTokens },
+    );
+  }
+
   /**
    * Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
    */
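
A worked example of calculateCurrentTokenCount with made-up numbers: if the API reports 1200 input tokens and the map's other messages account for 1100, the current message is counted as 100 instead of its rough local estimate; a non-positive difference falls back to that estimate.

    const tokenCountMap = { msg1: 600, msg2: 500, current: 140 }; // 140 is a rough local estimate
    const usage = { input_tokens: 1200 };
    // calculateCurrentTokenCount({ tokenCountMap, currentMessageId: 'current', usage })
    // -> zeroes the 'current' entry, sums the rest (1100), and returns 1200 - 1100 = 100.
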
@@ -726,6 +847,7 @@ class GoogleClient extends BaseClient {
       endpointType: null,
       artifacts: this.options.artifacts,
       promptPrefix: this.options.promptPrefix,
+      maxContextTokens: this.options.maxContextTokens,
       modelLabel: this.options.modelLabel,
       iconURL: this.options.iconURL,
       greeting: this.options.greeting,