Skip to content

Commit 8d6c1f6

Browse files
danny-avila authored and Rasukarusan committed
feat: Google Gemini ❇️ (danny-avila#1355)
* refactor: add gemini-pro to google Models list; use defaultModels for central model listing * refactor(SetKeyDialog): create useMultipleKeys hook to use for Azure, export `isJson` from utils, use EModelEndpoint * refactor(useUserKey): change variable names to make keyName setting more clear * refactor(FileUpload): allow passing container className string * feat(GoogleClient): Gemini support * refactor(GoogleClient): alternate stream speed for Gemini models * feat(Gemini): styling/settings configuration for Gemini * refactor(GoogleClient): subtract max response tokens from max context tokens if context is above 32k (I/O max is combined between the two) * refactor(tokens): correct google max token counts and subtract max response tokens when input/output count are combined towards max context count * feat(google/initializeClient): handle both local and user_provided credentials and write tests * fix(GoogleClient): catch if credentials are undefined, handle if serviceKey is string or object correctly, handle no examples passed, throw error if not a Generative Language model and no service account JSON key is provided, throw error if it is a Generative model, but no Google API key was provided * refactor(loadAsyncEndpoints/google): activate Google endpoint if either the service key JSON file is provided in /api/data, or a GOOGLE_KEY is defined. * docs: updated Google configuration * fix(ci): Mock import of Service Account Key JSON file (auth.json) * Update apis_and_tokens.md * feat: increase max output tokens slider for gemini pro * refactor(GoogleSettings): handle max and default maxOutputTokens on model change * chore: add sensitive redact regex * docs: add warning about data privacy * Update apis_and_tokens.md
1 parent 1c557a2 commit 8d6c1f6

File tree

37 files changed

+702
-219
lines changed

37 files changed

+702
-219
lines changed

api/app/clients/GoogleClient.js

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
const { google } = require('googleapis');
22
const { Agent, ProxyAgent } = require('undici');
33
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
4+
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
45
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
56
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
67
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
7-
const { getResponseSender, EModelEndpoint, endpointSettings } = require('librechat-data-provider');
8+
const {
9+
getResponseSender,
10+
EModelEndpoint,
11+
endpointSettings,
12+
AuthKeys,
13+
} = require('librechat-data-provider');
814
const { getModelMaxTokens } = require('~/utils');
915
const { formatMessage } = require('./prompts');
1016
const BaseClient = require('./BaseClient');
@@ -21,11 +27,24 @@ const settings = endpointSettings[EModelEndpoint.google];
2127
class GoogleClient extends BaseClient {
2228
constructor(credentials, options = {}) {
2329
super('apiKey', options);
24-
this.credentials = credentials;
25-
this.client_email = credentials.client_email;
26-
this.project_id = credentials.project_id;
27-
this.private_key = credentials.private_key;
30+
let creds = {};
31+
32+
if (typeof credentials === 'string') {
33+
creds = JSON.parse(credentials);
34+
} else if (credentials) {
35+
creds = credentials;
36+
}
37+
38+
const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
39+
this.serviceKey =
40+
serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
41+
this.client_email = this.serviceKey.client_email;
42+
this.private_key = this.serviceKey.private_key;
43+
this.project_id = this.serviceKey.project_id;
2844
this.access_token = null;
45+
46+
this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];
47+
2948
if (options.skipSetOptions) {
3049
return;
3150
}
@@ -85,7 +104,7 @@ class GoogleClient extends BaseClient {
85104
this.options = options;
86105
}
87106

88-
this.options.examples = this.options.examples
107+
this.options.examples = (this.options.examples ?? [])
89108
.filter((ex) => ex)
90109
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');
91110

@@ -103,15 +122,24 @@ class GoogleClient extends BaseClient {
103122
// stop: modelOptions.stop // no stop method for now
104123
};
105124

106-
this.isChatModel = this.modelOptions.model.includes('chat');
125+
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
126+
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
127+
const { isGenerativeModel } = this;
128+
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
107129
const { isChatModel } = this;
108-
this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
130+
this.isTextModel =
131+
!isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model);
109132
const { isTextModel } = this;
110133

111134
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
112135
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
113136
// Earlier messages will be dropped until the prompt is within the limit.
114137
this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;
138+
139+
if (this.maxContextTokens > 32000) {
140+
this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
141+
}
142+
115143
this.maxPromptTokens =
116144
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
117145

@@ -134,7 +162,7 @@ class GoogleClient extends BaseClient {
134162
this.userLabel = this.options.userLabel || 'User';
135163
this.modelLabel = this.options.modelLabel || 'Assistant';
136164

137-
if (isChatModel) {
165+
if (isChatModel || isGenerativeModel) {
138166
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
139167
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
140168
// without tripping the stop sequences, so I'm using "||>" instead.
@@ -189,6 +217,16 @@ class GoogleClient extends BaseClient {
189217
}
190218

191219
buildMessages(messages = [], parentMessageId) {
220+
if (!this.isGenerativeModel && !this.project_id) {
221+
throw new Error(
222+
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
223+
);
224+
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
225+
throw new Error(
226+
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
227+
);
228+
}
229+
192230
if (this.isTextModel) {
193231
return this.buildMessagesPrompt(messages, parentMessageId);
194232
}
@@ -398,6 +436,16 @@ class GoogleClient extends BaseClient {
398436
return res.data;
399437
}
400438

439+
createLLM(clientOptions) {
440+
if (this.isGenerativeModel) {
441+
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
442+
}
443+
444+
return this.isTextModel
445+
? new GoogleVertexAI(clientOptions)
446+
: new ChatGoogleVertexAI(clientOptions);
447+
}
448+
401449
async getCompletion(_payload, options = {}) {
402450
const { onProgress, abortController } = options;
403451
const { parameters, instances } = _payload;
@@ -408,7 +456,7 @@ class GoogleClient extends BaseClient {
408456
let clientOptions = {
409457
authOptions: {
410458
credentials: {
411-
...this.credentials,
459+
...this.serviceKey,
412460
},
413461
projectId: this.project_id,
414462
},
@@ -436,9 +484,7 @@ class GoogleClient extends BaseClient {
436484
clientOptions.examples = examples;
437485
}
438486

439-
const model = this.isTextModel
440-
? new GoogleVertexAI(clientOptions)
441-
: new ChatGoogleVertexAI(clientOptions);
487+
const model = this.createLLM(clientOptions);
442488

443489
let reply = '';
444490
const messages = this.isTextModel
@@ -457,7 +503,9 @@ class GoogleClient extends BaseClient {
457503
});
458504

459505
for await (const chunk of stream) {
460-
await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
506+
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
507+
delay: this.isGenerativeModel ? 12 : 8,
508+
});
461509
reply += chunk?.content ?? chunk;
462510
}
463511

api/config/parsers.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ const winston = require('winston');
33
const traverse = require('traverse');
44
const { klona } = require('klona/full');
55

6-
const sensitiveKeys = [/^sk-\w+$/];
6+
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];
77

88
/**
99
* Determines if a given key string is sensitive.

api/jest.config.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ module.exports = {
1010
],
1111
moduleNameMapper: {
1212
'~/(.*)': '<rootDir>/$1',
13+
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
1314
},
1415
};

api/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"@azure/search-documents": "^12.0.0",
3232
"@keyv/mongo": "^2.1.8",
3333
"@keyv/redis": "^2.8.0",
34+
"@langchain/google-genai": "^0.0.2",
3435
"axios": "^1.3.4",
3536
"bcryptjs": "^2.4.3",
3637
"cheerio": "^1.0.0-rc.12",

api/server/services/Config/loadAsyncEndpoints.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
88
*/
99
async function loadAsyncEndpoints() {
1010
let i = 0;
11-
let key, googleUserProvides;
11+
let serviceKey, googleUserProvides;
1212
try {
13-
key = require('~/data/auth.json');
13+
serviceKey = require('~/data/auth.json');
1414
} catch (e) {
1515
if (i === 0) {
1616
i++;
@@ -33,7 +33,7 @@ async function loadAsyncEndpoints() {
3333
}
3434
const plugins = transformToolsToMap(tools);
3535

36-
const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
36+
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
3737

3838
const gptPlugins =
3939
openAIApiKey || azureOpenAIApiKey

api/server/services/Config/loadDefaultModels.js

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
2+
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
13
const {
24
getOpenAIModels,
35
getChatGPTBrowserModels,
46
getAnthropicModels,
57
} = require('~/server/services/ModelService');
6-
const { EModelEndpoint } = require('librechat-data-provider');
7-
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
88

99
const fitlerAssistantModels = (str) => {
1010
return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str);
@@ -21,18 +21,7 @@ async function loadDefaultModels() {
2121
[EModelEndpoint.openAI]: openAI,
2222
[EModelEndpoint.azureOpenAI]: azureOpenAI,
2323
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
24-
[EModelEndpoint.google]: [
25-
'chat-bison',
26-
'chat-bison-32k',
27-
'codechat-bison',
28-
'codechat-bison-32k',
29-
'text-bison',
30-
'text-bison-32k',
31-
'text-unicorn',
32-
'code-gecko',
33-
'code-bison',
34-
'code-bison-32k',
35-
],
24+
[EModelEndpoint.google]: defaultModels[EModelEndpoint.google],
3625
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
3726
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
3827
[EModelEndpoint.gptPlugins]: gptPlugins,

api/server/services/Endpoints/google/initializeClient.js

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
const { GoogleClient } = require('~/app');
2-
const { EModelEndpoint } = require('librechat-data-provider');
2+
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
33
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
44

55
const initializeClient = async ({ req, res, endpointOption }) => {
@@ -11,14 +11,26 @@ const initializeClient = async ({ req, res, endpointOption }) => {
1111
if (expiresAt && isUserProvided) {
1212
checkUserKeyExpiry(
1313
expiresAt,
14-
'Your Google key has expired. Please provide your JSON credentials again.',
14+
'Your Google Credentials have expired. Please provide your Service Account JSON Key or Generative Language API Key again.',
1515
);
1616
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
1717
}
1818

19-
const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
19+
let serviceKey = {};
20+
try {
21+
serviceKey = require('~/data/auth.json');
22+
} catch (e) {
23+
// Do nothing
24+
}
25+
26+
const credentials = isUserProvided
27+
? userKey
28+
: {
29+
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
30+
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
31+
};
2032

21-
const client = new GoogleClient(apiKey, {
33+
const client = new GoogleClient(credentials, {
2234
req,
2335
res,
2436
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
@@ -28,7 +40,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
2840

2941
return {
3042
client,
31-
apiKey,
43+
credentials,
3244
};
3345
};
3446

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
const initializeClient = require('./initializeClient');
2+
const { GoogleClient } = require('~/app');
3+
const { checkUserKeyExpiry, getUserKey } = require('../../UserService');
4+
5+
jest.mock('../../UserService', () => ({
6+
checkUserKeyExpiry: jest.fn().mockImplementation((expiresAt, errorMessage) => {
7+
if (new Date(expiresAt) < new Date()) {
8+
throw new Error(errorMessage);
9+
}
10+
}),
11+
getUserKey: jest.fn().mockImplementation(() => ({})),
12+
}));
13+
14+
describe('google/initializeClient', () => {
15+
afterEach(() => {
16+
jest.clearAllMocks();
17+
});
18+
19+
test('should initialize GoogleClient with user-provided credentials', async () => {
20+
process.env.GOOGLE_KEY = 'user_provided';
21+
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
22+
process.env.PROXY = 'http://proxy';
23+
24+
const expiresAt = new Date(Date.now() + 60000).toISOString();
25+
26+
const req = {
27+
body: { key: expiresAt },
28+
user: { id: '123' },
29+
};
30+
const res = {};
31+
const endpointOption = { modelOptions: { model: 'default-model' } };
32+
33+
const { client, credentials } = await initializeClient({ req, res, endpointOption });
34+
35+
expect(getUserKey).toHaveBeenCalledWith({ userId: '123', name: 'google' });
36+
expect(client).toBeInstanceOf(GoogleClient);
37+
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
38+
expect(client.options.proxy).toBe('http://proxy');
39+
expect(credentials).toEqual({});
40+
});
41+
42+
test('should initialize GoogleClient with service key credentials', async () => {
43+
process.env.GOOGLE_KEY = 'service_key';
44+
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
45+
process.env.PROXY = 'http://proxy';
46+
47+
const req = {
48+
body: { key: null },
49+
user: { id: '123' },
50+
};
51+
const res = {};
52+
const endpointOption = { modelOptions: { model: 'default-model' } };
53+
54+
const { client, credentials } = await initializeClient({ req, res, endpointOption });
55+
56+
expect(client).toBeInstanceOf(GoogleClient);
57+
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
58+
expect(client.options.proxy).toBe('http://proxy');
59+
expect(credentials).toEqual({
60+
GOOGLE_SERVICE_KEY: {},
61+
GOOGLE_API_KEY: 'service_key',
62+
});
63+
});
64+
65+
test('should handle expired user-provided key', async () => {
66+
process.env.GOOGLE_KEY = 'user_provided';
67+
68+
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
69+
const req = {
70+
body: { key: expiresAt },
71+
user: { id: '123' },
72+
};
73+
const res = {};
74+
const endpointOption = { modelOptions: { model: 'default-model' } };
75+
76+
checkUserKeyExpiry.mockImplementation((expiresAt, errorMessage) => {
77+
throw new Error(errorMessage);
78+
});
79+
80+
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
81+
/Your Google Credentials have expired/,
82+
);
83+
});
84+
});

api/server/services/ModelService.js

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
const Keyv = require('keyv');
22
const axios = require('axios');
33
const HttpsProxyAgent = require('https-proxy-agent');
4+
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
45
const { isEnabled } = require('~/server/utils');
56
const keyvRedis = require('~/cache/keyvRedis');
67
const { extractBaseURL } = require('~/utils');
@@ -117,15 +118,7 @@ const getChatGPTBrowserModels = () => {
117118
};
118119

119120
const getAnthropicModels = () => {
120-
let models = [
121-
'claude-2.1',
122-
'claude-2',
123-
'claude-1.2',
124-
'claude-1',
125-
'claude-1-100k',
126-
'claude-instant-1',
127-
'claude-instant-1-100k',
128-
];
121+
let models = defaultModels[EModelEndpoint.anthropic];
129122
if (ANTHROPIC_MODELS) {
130123
models = String(ANTHROPIC_MODELS).split(',');
131124
}

0 commit comments

Comments
 (0)