Merged
Changes from all commits (21 commits)
085aea3
refactor: add gemini-pro to google Models list; use defaultModels for…
danny-avila Dec 14, 2023
4cf3f05
refactor(SetKeyDialog): create useMultipleKeys hook to use for Azure,…
danny-avila Dec 14, 2023
14310f6
refactor(useUserKey): change variable names to make keyName setting m…
danny-avila Dec 14, 2023
4c56bb1
refactor(FileUpload): allow passing container className string
danny-avila Dec 14, 2023
a8f1b4a
feat(GoogleClient): Gemini support
danny-avila Dec 15, 2023
d35e917
refactor(GoogleClient): alternate stream speed for Gemini models
danny-avila Dec 15, 2023
83a38dd
feat(Gemini): styling/settings configuration for Gemini
danny-avila Dec 15, 2023
680f83d
refactor(GoogleClient): substract max response tokens from max contex…
danny-avila Dec 15, 2023
9c444e0
refactor(tokens): correct google max token counts and subtract max re…
danny-avila Dec 15, 2023
f1ec9a1
feat(google/initializeClient): handle both local and user_provided cr…
danny-avila Dec 15, 2023
d54fb8c
fix(GoogleClient): catch if credentials are undefined, handle if serv…
danny-avila Dec 15, 2023
f4c03ec
refactor(loadAsyncEndpoints/google): activate Google endpoint if eith…
danny-avila Dec 15, 2023
43bd222
docs: updated Google configuration
danny-avila Dec 15, 2023
bafe222
fix(ci): Mock import of Service Account Key JSON file (auth.json)
danny-avila Dec 15, 2023
b3951cc
Update apis_and_tokens.md
danny-avila Dec 15, 2023
74173f4
feat: increase max output tokens slider for gemini pro
danny-avila Dec 15, 2023
12370a2
refactor(GoogleSettings): handle max and default maxOutputTokens on m…
danny-avila Dec 15, 2023
f1ba578
chore: add sensitive redact regex
danny-avila Dec 15, 2023
15eb2f8
docs: add warning about data privacy
danny-avila Dec 15, 2023
7e7c729
Merge branch 'google-gemini' of https://github.com/danny-avila/LibreC…
danny-avila Dec 15, 2023
8fb7551
Update apis_and_tokens.md
danny-avila Dec 15, 2023
76 changes: 62 additions & 14 deletions api/app/clients/GoogleClient.js
@@ -1,10 +1,16 @@
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint, endpointSettings } = require('librechat-data-provider');
const {
getResponseSender,
EModelEndpoint,
endpointSettings,
AuthKeys,
} = require('librechat-data-provider');
const { getModelMaxTokens } = require('~/utils');
const { formatMessage } = require('./prompts');
const BaseClient = require('./BaseClient');
@@ -21,11 +27,24 @@ const settings = endpointSettings[EModelEndpoint.google];
class GoogleClient extends BaseClient {
constructor(credentials, options = {}) {
super('apiKey', options);
this.credentials = credentials;
this.client_email = credentials.client_email;
this.project_id = credentials.project_id;
this.private_key = credentials.private_key;
let creds = {};

if (typeof credentials === 'string') {
creds = JSON.parse(credentials);
} else if (credentials) {
creds = credentials;
}

const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
this.serviceKey =
serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {};
this.client_email = this.serviceKey.client_email;
this.private_key = this.serviceKey.private_key;
this.project_id = this.serviceKey.project_id;
this.access_token = null;

this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];

if (options.skipSetOptions) {
return;
}
Expand Down Expand Up @@ -85,7 +104,7 @@ class GoogleClient extends BaseClient {
this.options = options;
}

this.options.examples = this.options.examples
this.options.examples = (this.options.examples ?? [])
.filter((ex) => ex)
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');

@@ -103,15 +122,24 @@ class GoogleClient extends BaseClient {
// stop: modelOptions.stop // no stop method for now
};

this.isChatModel = this.modelOptions.model.includes('chat');
// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
const { isGenerativeModel } = this;
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
const { isChatModel } = this;
this.isTextModel = !isChatModel && /code|text/.test(this.modelOptions.model);
this.isTextModel =
!isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model);
const { isTextModel } = this;

this.maxContextTokens = getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;

if (this.maxContextTokens > 32000) {
this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
}

this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

@@ -134,7 +162,7 @@ class GoogleClient extends BaseClient {
this.userLabel = this.options.userLabel || 'User';
this.modelLabel = this.options.modelLabel || 'Assistant';

if (isChatModel) {
if (isChatModel || isGenerativeModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
@@ -189,6 +217,16 @@ class GoogleClient extends BaseClient {
}

buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
);
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
throw new Error(
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
);
}

if (this.isTextModel) {
return this.buildMessagesPrompt(messages, parentMessageId);
}
@@ -398,6 +436,16 @@ class GoogleClient extends BaseClient {
return res.data;
}

createLLM(clientOptions) {
if (this.isGenerativeModel) {
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}

return this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
}

async getCompletion(_payload, options = {}) {
const { onProgress, abortController } = options;
const { parameters, instances } = _payload;
@@ -408,7 +456,7 @@ class GoogleClient extends BaseClient {
let clientOptions = {
authOptions: {
credentials: {
...this.credentials,
...this.serviceKey,
},
projectId: this.project_id,
},
@@ -436,9 +484,7 @@ class GoogleClient extends BaseClient {
clientOptions.examples = examples;
}

const model = this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
const model = this.createLLM(clientOptions);

let reply = '';
const messages = this.isTextModel
@@ -457,7 +503,9 @@ class GoogleClient extends BaseClient {
});

for await (const chunk of stream) {
await this.generateTextStream(chunk?.content ?? chunk, onProgress, { delay: 7 });
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
delay: this.isGenerativeModel ? 12 : 8,
});
reply += chunk?.content ?? chunk;
}

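Taken together, the constructor, buildMessages, and createLLM changes let one GoogleClient accept either credential type and route Gemini models to the new @langchain/google-genai client. A minimal usage sketch — option names outside the diff (the modelOptions shape, the auth.json path) are assumptions, not verified against the full file:

const { AuthKeys } = require('librechat-data-provider');
const { GoogleClient } = require('~/app');

// Either key may be omitted: Vertex AI (PaLM 2 / Codey) needs the Service Account
// JSON Key, Gemini needs the Generative Language API key.
const credentials = {
  [AuthKeys.GOOGLE_SERVICE_KEY]: require('~/data/auth.json'),
  [AuthKeys.GOOGLE_API_KEY]: process.env.GOOGLE_KEY,
};

const client = new GoogleClient(credentials, {
  modelOptions: { model: 'gemini-pro', maxOutputTokens: 1024 },
});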
2 changes: 1 addition & 1 deletion api/config/parsers.js
@@ -3,7 +3,7 @@ const winston = require('winston');
const traverse = require('traverse');
const { klona } = require('klona/full');

const sensitiveKeys = [/^sk-\w+$/];
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];

/**
* Determines if a given key string is sensitive.
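A rough illustration of what the added pattern is meant to catch (the helper below is a sketch; only the two regexes come from the diff):

const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];
const isSensitive = (value) => sensitiveKeys.some((re) => re.test(value));

isSensitive('sk-abc123');            // true  — OpenAI-style secret key
isSensitive('Bearer eyJhbGciOiJI');  // true  — bearer token in a logged header
isSensitive('hello world');          // false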
1 change: 1 addition & 0 deletions api/jest.config.js
@@ -10,5 +10,6 @@ module.exports = {
],
moduleNameMapper: {
'~/(.*)': '<rootDir>/$1',
'~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
},
};
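The new mapping redirects the ~/data/auth.json import to a committed mock (api/__mocks__/auth.mock.json, per the rootDir-relative path) so CI can run without a real Service Account key. A plausible shape for that mock — the contents below are assumed; only the path mapping is in the diff, and the field names follow what GoogleClient reads:

{
  "client_email": "[email protected]",
  "project_id": "mock-project",
  "private_key": "-----BEGIN PRIVATE KEY-----\nMOCK\n-----END PRIVATE KEY-----\n"
}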
1 change: 1 addition & 0 deletions api/package.json
@@ -31,6 +31,7 @@
"@azure/search-documents": "^12.0.0",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.0",
"@langchain/google-genai": "^0.0.2",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
6 changes: 3 additions & 3 deletions api/server/services/Config/loadAsyncEndpoints.js
@@ -8,9 +8,9 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
*/
async function loadAsyncEndpoints() {
let i = 0;
let key, googleUserProvides;
let serviceKey, googleUserProvides;
try {
key = require('~/data/auth.json');
serviceKey = require('~/data/auth.json');
} catch (e) {
if (i === 0) {
i++;
@@ -33,7 +33,7 @@
}
const plugins = transformToolsToMap(tools);

const google = key || googleUserProvides ? { userProvide: googleUserProvides } : false;
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;

const gptPlugins =
openAIApiKey || azureOpenAIApiKey
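In effect, the Google endpoint is now advertised when either credential source exists. A restatement of the condition — googleUserProvides is assumed to follow the existing user_provided convention, since only the final line appears in the diff:

// serviceKey: parsed ~/data/auth.json if present; googleKey: the GOOGLE_KEY env value
const googleUserProvides = googleKey === 'user_provided'; // assumed, per EndpointService convention
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;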
17 changes: 3 additions & 14 deletions api/server/services/Config/loadDefaultModels.js
@@ -1,10 +1,10 @@
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;
const {
getOpenAIModels,
getChatGPTBrowserModels,
getAnthropicModels,
} = require('~/server/services/ModelService');
const { EModelEndpoint } = require('librechat-data-provider');
const { useAzurePlugins } = require('~/server/services/Config/EndpointService').config;

const fitlerAssistantModels = (str) => {
return /gpt-4|gpt-3\\.5/i.test(str) && !/vision|instruct/i.test(str);
@@ -21,18 +21,7 @@ async function loadDefaultModels() {
[EModelEndpoint.openAI]: openAI,
[EModelEndpoint.azureOpenAI]: azureOpenAI,
[EModelEndpoint.assistant]: openAI.filter(fitlerAssistantModels),
[EModelEndpoint.google]: [
'chat-bison',
'chat-bison-32k',
'codechat-bison',
'codechat-bison-32k',
'text-bison',
'text-bison-32k',
'text-unicorn',
'code-gecko',
'code-bison',
'code-bison-32k',
],
[EModelEndpoint.google]: defaultModels[EModelEndpoint.google],
[EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
[EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
[EModelEndpoint.gptPlugins]: gptPlugins,
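The hard-coded list moves into librechat-data-provider. Per the first commit in this PR ("add gemini-pro to google Models list; use defaultModels for…"), the shared list presumably now includes gemini-pro alongside the PaLM 2 and Codey models; the exact contents are an assumption:

// librechat-data-provider (assumed shape of the shared list)
const defaultModels = {
  [EModelEndpoint.google]: [
    'gemini-pro',
    'chat-bison',
    'chat-bison-32k',
    'codechat-bison',
    // ...remaining PaLM 2 / Codey models from the previous hard-coded list
  ],
};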
22 changes: 17 additions & 5 deletions api/server/services/Endpoints/google/initializeClient.js
@@ -1,5 +1,5 @@
const { GoogleClient } = require('~/app');
const { EModelEndpoint } = require('librechat-data-provider');
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');

const initializeClient = async ({ req, res, endpointOption }) => {
@@ -11,14 +11,26 @@ const initializeClient = async ({ req, res, endpointOption }) => {
if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
expiresAt,
'Your Google key has expired. Please provide your JSON credentials again.',
'Your Google Credentials have expired. Please provide your Service Account JSON Key or Generative Language API Key again.',
);
userKey = await getUserKey({ userId: req.user.id, name: EModelEndpoint.google });
}

const apiKey = isUserProvided ? userKey : require('~/data/auth.json');
let serviceKey = {};
try {
serviceKey = require('~/data/auth.json');
} catch (e) {
// Do nothing
}

const credentials = isUserProvided
? userKey
: {
[AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey,
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
};

const client = new GoogleClient(apiKey, {
const client = new GoogleClient(credentials, {
req,
res,
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
@@ -28,7 +40,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {

return {
client,
apiKey,
credentials,
};
};

84 changes: 84 additions & 0 deletions api/server/services/Endpoints/google/initializeClient.spec.js
@@ -0,0 +1,84 @@
const initializeClient = require('./initializeClient');
const { GoogleClient } = require('~/app');
const { checkUserKeyExpiry, getUserKey } = require('../../UserService');

jest.mock('../../UserService', () => ({
checkUserKeyExpiry: jest.fn().mockImplementation((expiresAt, errorMessage) => {
if (new Date(expiresAt) < new Date()) {
throw new Error(errorMessage);
}
}),
getUserKey: jest.fn().mockImplementation(() => ({})),
}));

describe('google/initializeClient', () => {
afterEach(() => {
jest.clearAllMocks();
});

test('should initialize GoogleClient with user-provided credentials', async () => {
process.env.GOOGLE_KEY = 'user_provided';
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';

const expiresAt = new Date(Date.now() + 60000).toISOString();

const req = {
body: { key: expiresAt },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };

const { client, credentials } = await initializeClient({ req, res, endpointOption });

expect(getUserKey).toHaveBeenCalledWith({ userId: '123', name: 'google' });
expect(client).toBeInstanceOf(GoogleClient);
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
expect(credentials).toEqual({});
});

test('should initialize GoogleClient with service key credentials', async () => {
process.env.GOOGLE_KEY = 'service_key';
process.env.GOOGLE_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';

const req = {
body: { key: null },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };

const { client, credentials } = await initializeClient({ req, res, endpointOption });

expect(client).toBeInstanceOf(GoogleClient);
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
expect(credentials).toEqual({
GOOGLE_SERVICE_KEY: {},
GOOGLE_API_KEY: 'service_key',
});
});

test('should handle expired user-provided key', async () => {
process.env.GOOGLE_KEY = 'user_provided';

const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
const req = {
body: { key: expiresAt },
user: { id: '123' },
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };

checkUserKeyExpiry.mockImplementation((expiresAt, errorMessage) => {
throw new Error(errorMessage);
});

await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/Your Google Credentials have expired/,
);
});
});
11 changes: 2 additions & 9 deletions api/server/services/ModelService.js
@@ -1,6 +1,7 @@
const Keyv = require('keyv');
const axios = require('axios');
const HttpsProxyAgent = require('https-proxy-agent');
const { EModelEndpoint, defaultModels } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { extractBaseURL } = require('~/utils');
@@ -117,15 +118,7 @@ const getChatGPTBrowserModels = () => {
};

const getAnthropicModels = () => {
let models = [
'claude-2.1',
'claude-2',
'claude-1.2',
'claude-1',
'claude-1-100k',
'claude-instant-1',
'claude-instant-1-100k',
];
let models = defaultModels[EModelEndpoint.anthropic];
if (ANTHROPIC_MODELS) {
models = String(ANTHROPIC_MODELS).split(',');
}
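As before, the ANTHROPIC_MODELS environment variable (read at startup) overrides the shared default; only the fallback source changed. Illustrative values:

// ANTHROPIC_MODELS unset:
//   getAnthropicModels() -> defaultModels[EModelEndpoint.anthropic]
// ANTHROPIC_MODELS=claude-2.1,claude-instant-1 (comma-separated, values illustrative):
//   getAnthropicModels() -> ['claude-2.1', 'claude-instant-1']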