Commit 798e876

👓 feat: Vision Support for Assistants (#2195)

* refactor(assistants/chat): use promises to speed up initialization, initialize shared variables, pass `attachedFileIds` to `StreamRunManager`
* chore: additional typedefs
* fix(OpenAIClient): handle edge case where attachments promise is resolved
* feat: createVisionPrompt
* feat: Vision Support for Assistants

1 parent 1f0fb49 · commit 798e876

File tree

16 files changed (+371 −95 lines)

api/app/clients/OpenAIClient.js

Lines changed: 5 additions & 1 deletion

```diff
@@ -92,7 +92,11 @@ class OpenAIClient extends BaseClient {
     }
 
     this.defaultVisionModel = this.options.visionModel ?? 'gpt-4-vision-preview';
-    this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
+    if (typeof this.options.attachments?.then === 'function') {
+      this.options.attachments.then((attachments) => this.checkVisionRequest(attachments));
+    } else {
+      this.checkVisionRequest(this.options.attachments);
+    }
 
     const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
     if (OPENROUTER_API_KEY && !this.azure) {
```
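
This is the "attachments promise is resolved" edge case from the commit message: optional chaining only guards against `attachments` being nullish, so `attachments?.then(cb)` still throws a TypeError when `attachments` is already a plain array (`[].then` is `undefined`). A minimal sketch of the duck-typing check, with a hypothetical `checkVisionRequestSafely` helper standing in for the client method:

```js
// Sketch only: `checkVisionRequestSafely` is illustrative, not part of the codebase.
function checkVisionRequestSafely(client, attachments) {
  // `typeof x?.then === 'function'` is the standard duck-typing test for a
  // thenable; it is false for arrays, plain objects, and undefined alike.
  if (typeof attachments?.then === 'function') {
    // Still a pending promise: defer the vision check until it resolves.
    attachments.then((resolved) => client.checkVisionRequest(resolved));
  } else {
    // Already-resolved value (or undefined): check synchronously.
    client.checkVisionRequest(attachments);
  }
}
```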
api/app/clients/prompts/createVisionPrompt.js (new file)

Lines changed: 34 additions & 0 deletions

```diff
@@ -0,0 +1,34 @@
+/**
+ * Generates a prompt instructing the user to describe an image in detail, tailored to different types of visual content.
+ * @param {boolean} pluralized - Whether to pluralize the prompt for multiple images.
+ * @returns {string} - The generated vision prompt.
+ */
+const createVisionPrompt = (pluralized = false) => {
+  return `Please describe the image${
+    pluralized ? 's' : ''
+  } in detail, covering relevant aspects such as:
+
+For photographs, illustrations, or artwork:
+- The main subject(s) and their appearance, positioning, and actions
+- The setting, background, and any notable objects or elements
+- Colors, lighting, and overall mood or atmosphere
+- Any interesting details, textures, or patterns
+- The style, technique, or medium used (if discernible)
+
+For screenshots or images containing text:
+- The content and purpose of the text
+- The layout, formatting, and organization of the information
+- Any notable visual elements, such as logos, icons, or graphics
+- The overall context or message conveyed by the screenshot
+
+For graphs, charts, or data visualizations:
+- The type of graph or chart (e.g., bar graph, line chart, pie chart)
+- The variables being compared or analyzed
+- Any trends, patterns, or outliers in the data
+- The axis labels, scales, and units of measurement
+- The title, legend, and any additional context provided
+
+Be as specific and descriptive as possible while maintaining clarity and concision.`;
+};
+
+module.exports = createVisionPrompt;
```
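
For reference, a quick usage sketch of the new helper (local require path assumed):

```js
const createVisionPrompt = require('./createVisionPrompt');

// Singular form: "Please describe the image in detail, ..."
const single = createVisionPrompt();

// Pluralized form: "Please describe the images in detail, ..."
const multiple = createVisionPrompt(true);

console.log(single.startsWith('Please describe the image in detail')); // true
console.log(multiple.startsWith('Please describe the images in detail')); // true
```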

api/app/clients/prompts/index.js

Lines changed: 2 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@ const handleInputs = require('./handleInputs');
 const instructions = require('./instructions');
 const titlePrompts = require('./titlePrompts');
 const truncateText = require('./truncateText');
+const createVisionPrompt = require('./createVisionPrompt');
 const createContextHandlers = require('./createContextHandlers');
 
 module.exports = {
@@ -13,5 +14,6 @@ module.exports = {
   ...instructions,
   ...titlePrompts,
   truncateText,
+  createVisionPrompt,
   createContextHandlers,
 };
```
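
With the barrel export in place, consumers can pull the helper from the prompts module, exactly as the chat.js diff below does:

```js
// Import via the prompts barrel, as used in api/server/routes/assistants/chat.js:
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
```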

api/server/routes/assistants/chat.js

Lines changed: 159 additions & 74 deletions

```diff
@@ -4,9 +4,11 @@ const {
   Constants,
   RunStatus,
   CacheKeys,
+  FileSources,
   ContentTypes,
   EModelEndpoint,
   ViolationTypes,
+  ImageVisionTool,
   AssistantStreamEvents,
 } = require('librechat-data-provider');
 const {
@@ -17,9 +19,10 @@ const {
   addThreadMetadata,
   saveAssistantMessage,
 } = require('~/server/services/Threads');
+const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
 const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
-const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
+const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
 const { createRun, StreamRunManager } = require('~/server/services/Runs');
 const { getTransactions } = require('~/models/Transaction');
 const checkBalance = require('~/models/checkBalance');
@@ -100,6 +103,16 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
   let parentMessageId = _parentId;
   /** @type {TMessage[]} */
   let previousMessages = [];
+  /** @type {import('librechat-data-provider').TConversation | null} */
+  let conversation = null;
+  /** @type {string[]} */
+  let file_ids = [];
+  /** @type {Set<string>} */
+  let attachedFileIds = new Set();
+  /** @type {TMessage | null} */
+  let requestMessage = null;
+  /** @type {undefined | Promise<ChatCompletion>} */
+  let visionPromise;
 
   const userMessageId = v4();
   const responseMessageId = v4();
@@ -258,7 +271,10 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       throw new Error('Missing assistant_id');
     }
 
-    if (isEnabled(process.env.CHECK_BALANCE)) {
+    const checkBalanceBeforeRun = async () => {
+      if (!isEnabled(process.env.CHECK_BALANCE)) {
+        return;
+      }
       const transactions =
         (await getTransactions({
           user: req.user.id,
@@ -288,7 +304,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
           amount: promptTokens,
         },
       });
-    }
+    };
 
     /** @type {{ openai: OpenAIClient }} */
     const { openai: _openai, client } = await initializeClient({
@@ -300,103 +316,168 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
 
     openai = _openai;
 
-    // if (thread_id) {
-    //   previousMessages = await checkMessageGaps({ openai, thread_id, conversationId });
-    // }
-
     if (previousMessages.length) {
       parentMessageId = previousMessages[previousMessages.length - 1].messageId;
     }
 
-    const userMessage = {
+    let userMessage = {
       role: 'user',
       content: text,
       metadata: {
         messageId: userMessageId,
       },
     };
 
-    let thread_file_ids = [];
-    if (convoId) {
-      const convo = await getConvo(req.user.id, convoId);
-      if (convo && convo.file_ids) {
-        thread_file_ids = convo.file_ids;
-      }
+    /** @type {CreateRunBody | undefined} */
+    const body = {
+      assistant_id,
+      model,
+    };
+
+    if (promptPrefix) {
+      body.additional_instructions = promptPrefix;
     }
 
-    const file_ids = files.map(({ file_id }) => file_id);
-    if (file_ids.length || thread_file_ids.length) {
-      userMessage.file_ids = file_ids;
-      openai.attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
+    if (instructions) {
+      body.instructions = instructions;
     }
 
-    // TODO: may allow multiple messages to be created beforehand in a future update
-    const initThreadBody = {
-      messages: [userMessage],
-      metadata: {
-        user: req.user.id,
-        conversationId,
-      },
+    const getRequestFileIds = async () => {
+      let thread_file_ids = [];
+      if (convoId) {
+        const convo = await getConvo(req.user.id, convoId);
+        if (convo && convo.file_ids) {
+          thread_file_ids = convo.file_ids;
+        }
+      }
+
+      file_ids = files.map(({ file_id }) => file_id);
+      if (file_ids.length || thread_file_ids.length) {
+        userMessage.file_ids = file_ids;
+        attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
+      }
     };
 
-    const result = await initThread({ openai, body: initThreadBody, thread_id });
-    thread_id = result.thread_id;
+    const addVisionPrompt = async () => {
+      if (!req.body.endpointOption.attachments) {
+        return;
+      }
 
-    createOnTextProgress({
-      openai,
-      conversationId,
-      userMessageId,
-      messageId: responseMessageId,
-      thread_id,
-    });
+      const assistant = await openai.beta.assistants.retrieve(assistant_id);
+      const visionToolIndex = assistant.tools.findIndex(
+        (tool) => tool.function.name === ImageVisionTool.function.name,
+      );
 
-    const requestMessage = {
-      user: req.user.id,
-      text,
-      messageId: userMessageId,
-      parentMessageId,
-      // TODO: make sure client sends correct format for `files`, use zod
-      files,
-      file_ids,
-      conversationId,
-      isCreatedByUser: true,
-      assistant_id,
-      thread_id,
-      model: assistant_id,
-    };
+      if (visionToolIndex === -1) {
+        return;
+      }
 
-    previousMessages.push(requestMessage);
+      const attachments = await req.body.endpointOption.attachments;
+      let visionMessage = {
+        role: 'user',
+        content: '',
+      };
+      const files = await client.addImageURLs(visionMessage, attachments);
+      if (!visionMessage.image_urls?.length) {
+        return;
+      }
 
-    await saveUserMessage({ ...requestMessage, model });
+      const imageCount = visionMessage.image_urls.length;
+      const plural = imageCount > 1;
+      visionMessage.content = createVisionPrompt(plural);
+      visionMessage = formatMessage({ message: visionMessage, endpoint: EModelEndpoint.openAI });
 
-    const conversation = {
-      conversationId,
-      // TODO: title feature
-      title: 'New Chat',
-      endpoint: EModelEndpoint.assistants,
-      promptPrefix: promptPrefix,
-      instructions: instructions,
-      assistant_id,
-      // model,
-    };
+      visionPromise = openai.chat.completions.create({
+        model: 'gpt-4-vision-preview',
+        messages: [visionMessage],
+        max_tokens: 4000,
+      });
 
-    if (file_ids.length) {
-      conversation.file_ids = file_ids;
-    }
+      const pluralized = plural ? 's' : '';
+      body.additional_instructions = `${
+        body.additional_instructions ? `${body.additional_instructions}\n` : ''
+      }The user has uploaded ${imageCount} image${pluralized}.
+      Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${
+        plural ? '' : 'a '
+      }detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
 
-    /** @type {CreateRunBody} */
-    const body = {
-      assistant_id,
-      model,
+      return files;
     };
 
-    if (promptPrefix) {
-      body.additional_instructions = promptPrefix;
-    }
+    const initializeThread = async () => {
+      /** @type {[ undefined | MongoFile[]]} */
+      const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
+      // TODO: may allow multiple messages to be created beforehand in a future update
+      const initThreadBody = {
+        messages: [userMessage],
+        metadata: {
+          user: req.user.id,
+          conversationId,
+        },
+      };
 
-    if (instructions) {
-      body.instructions = instructions;
-    }
+      if (processedFiles) {
+        for (const file of processedFiles) {
+          if (file.source !== FileSources.openai) {
+            attachedFileIds.delete(file.file_id);
+            const index = file_ids.indexOf(file.file_id);
+            if (index > -1) {
+              file_ids.splice(index, 1);
+            }
+          }
+        }
+
+        userMessage.file_ids = file_ids;
+      }
+
+      const result = await initThread({ openai, body: initThreadBody, thread_id });
+      thread_id = result.thread_id;
+
+      createOnTextProgress({
+        openai,
+        conversationId,
+        userMessageId,
+        messageId: responseMessageId,
+        thread_id,
+      });
+
+      requestMessage = {
+        user: req.user.id,
+        text,
+        messageId: userMessageId,
+        parentMessageId,
+        // TODO: make sure client sends correct format for `files`, use zod
+        files,
+        file_ids,
+        conversationId,
+        isCreatedByUser: true,
+        assistant_id,
+        thread_id,
+        model: assistant_id,
+      };
+
+      previousMessages.push(requestMessage);
+
+      /* asynchronous */
+      saveUserMessage({ ...requestMessage, model });
+
+      conversation = {
+        conversationId,
+        title: 'New Chat',
+        endpoint: EModelEndpoint.assistants,
+        promptPrefix: promptPrefix,
+        instructions: instructions,
+        assistant_id,
+        // model,
+      };
+
+      if (file_ids.length) {
+        conversation.file_ids = file_ids;
+      }
+    };
+
+    const promises = [initializeThread(), checkBalanceBeforeRun()];
+    await Promise.all(promises);
 
     const sendInitialResponse = () => {
       sendMessage(res, {
@@ -421,6 +502,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
 
     const processRun = async (retry = false) => {
       if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
+        openai.attachedFileIds = attachedFileIds;
+        openai.visionPromise = visionPromise;
         if (retry) {
          response = await runAssistant({
            openai,
@@ -463,9 +546,11 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
         req,
         res,
         openai,
+        handlers,
         thread_id,
+        visionPromise,
+        attachedFileIds,
         responseMessage: openai.responseMessage,
-        handlers,
         // streamOptions: {
 
         // },
```
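
The net effect of the "use promises to speed up initialization" refactor is that independent async work overlaps instead of running serially: vision-prompt setup and file-id collection run concurrently inside `initializeThread`, and thread initialization itself overlaps with the balance check. A self-contained sketch of the pattern (the stub functions and timings below are illustrative, not LibreChat APIs):

```js
// Stubs simulating the independent steps from the route handler above.
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

const addVisionPrompt = async () => {
  await sleep(50); // e.g. retrieve the assistant, kick off the vision request
  return [{ file_id: 'file_abc', source: 'openai' }];
};

const getRequestFileIds = async () => {
  await sleep(30); // e.g. load conversation file ids from the database
};

const checkBalanceBeforeRun = async () => {
  await sleep(40); // e.g. count tokens and verify transactions
};

const initializeThread = async () => {
  // The two prep steps are independent, so they run concurrently;
  // thread creation proceeds only after both settle.
  const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
  return processedFiles;
};

// Thread initialization overlaps with the balance check, so total latency is
// roughly the slower branch rather than the sum of all steps.
Promise.all([initializeThread(), checkBalanceBeforeRun()]).then(([files]) => {
  console.log('initialized with', files.length, 'processed file(s)');
});
```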

api/server/services/Runs/StreamRunManager.js

Lines changed: 4 additions & 0 deletions

```diff
@@ -59,6 +59,10 @@ class StreamRunManager {
     this.messages = [];
     /** @type {string} */
     this.text = '';
+    /** @type {Set<string>} */
+    this.attachedFileIds = fields.attachedFileIds;
+    /** @type {undefined | Promise<ChatCompletion>} */
+    this.visionPromise = fields.visionPromise;
 
     /**
      * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
```
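
These fields arrive via the constructor's `fields` argument. A sketch of the presumed call site, assuming the field list in the chat.js hunk above belongs to a `new StreamRunManager(...)` call (variable names taken from that route handler):

```js
// Assumed call-site shape; mirrors the fields added in the chat.js diff.
const streamRunManager = new StreamRunManager({
  req,
  res,
  openai,
  handlers,
  thread_id,
  visionPromise, // undefined | Promise<ChatCompletion>, produced by addVisionPrompt()
  attachedFileIds, // Set<string> of file ids attached to the request and thread
  responseMessage: openai.responseMessage,
});
```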
