Skip to content

Commit ab302a4

Browse files
committed
♻️ fix: Prevent Instructions from Removal when nearing Max Context (#5516)
* refactor: getMessagesWithinTokenLimit to accept params object
* refactor: always include instructions in payload if provided
* ci: remove obsolete test
* refactor: update logoutUser to accept request object and handle session destruction
* test: enhance getMessagesWithinTokenLimit tests for instruction handling
1 parent 679c6e6 commit ab302a4

File tree

6 files changed

+185
-83
lines changed

6 files changed

+185
-83
lines changed

api/app/clients/AnthropicClient.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ class AnthropicClient extends BaseClient {
416416
}
417417

418418
let { context: messagesInWindow, remainingContextTokens } =
419-
await this.getMessagesWithinTokenLimit(formattedMessages);
419+
await this.getMessagesWithinTokenLimit({ messages: formattedMessages });
420420

421421
const tokenCountMap = orderedMessages
422422
.slice(orderedMessages.length - messagesInWindow.length)

api/app/clients/BaseClient.js

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -347,25 +347,38 @@ class BaseClient {
347347
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
348348
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
349349
*
350-
* @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
351-
* @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
352-
* @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
350+
* @param {Object} params
351+
* @param {TMessage[]} params.messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
352+
* @param {number} [params.maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
353+
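* @param {{ role: 'system', content: string, tokenCount: number }} [params.instructions] - The instructions message, expected to already be present at index 0 of `params.messages`; when provided it is always kept in the returned context and its `tokenCount` is deducted from the available token budget.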
354+
* @returns {Promise<{
355+
* context: TMessage[],
356+
* remainingContextTokens: number,
357+
* messagesToRefine: TMessage[],
358+
* summaryIndex: number,
359+
* }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
353360
* `context` is an array of messages that fit within the token limit.
354361
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
355362
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
356363
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
357364
*/
358-
async getMessagesWithinTokenLimit(_messages, maxContextTokens) {
365+
async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
359366
// Every reply is primed with <|start|>assistant<|message|>, so we
360367
// start with 3 tokens for the label after all messages have been counted.
361-
let currentTokenCount = 3;
362368
let summaryIndex = -1;
363-
let remainingContextTokens = maxContextTokens ?? this.maxContextTokens;
369+
let currentTokenCount = 3;
370+
const instructionsTokenCount = instructions?.tokenCount ?? 0;
371+
let remainingContextTokens =
372+
(maxContextTokens ?? this.maxContextTokens) - instructionsTokenCount;
364373
const messages = [..._messages];
365374

366375
const context = [];
376+
367377
if (currentTokenCount < remainingContextTokens) {
368378
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
379+
if (messages.length === 1 && instructions) {
380+
break;
381+
}
369382
const poppedMessage = messages.pop();
370383
const { tokenCount } = poppedMessage;
371384

@@ -379,6 +392,11 @@ class BaseClient {
379392
}
380393
}
381394

395+
if (instructions) {
396+
context.push(_messages[0]);
397+
messages.shift();
398+
}
399+
382400
const prunedMemory = messages;
383401
summaryIndex = prunedMemory.length - 1;
384402
remainingContextTokens -= currentTokenCount;
@@ -403,27 +421,38 @@ class BaseClient {
403421
if (instructions) {
404422
({ tokenCount, ..._instructions } = instructions);
405423
}
424+
406425
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
407-
let payload = this.addInstructions(formattedMessages, _instructions);
408-
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
426+
if (tokenCount && tokenCount > this.maxContextTokens) {
427+
const info = `${tokenCount} / ${this.maxContextTokens}`;
428+
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
429+
logger.warn(`Instructions token count exceeds max token count (${info}).`);
430+
throw new Error(errorMessage);
431+
}
432+
409433
if (this.clientName === EModelEndpoint.agents) {
410434
const { dbMessages, editedIndices } = truncateToolCallOutputs(
411-
orderedWithInstructions,
435+
orderedMessages,
412436
this.maxContextTokens,
413437
this.getTokenCountForMessage.bind(this),
414438
);
415439

416440
if (editedIndices.length > 0) {
417441
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
418442
for (const index of editedIndices) {
419-
payload[index].content = dbMessages[index].content;
443+
formattedMessages[index].content = dbMessages[index].content;
420444
}
421-
orderedWithInstructions = dbMessages;
445+
orderedMessages = dbMessages;
422446
}
423447
}
424448

449+
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
450+
425451
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
426-
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
452+
await this.getMessagesWithinTokenLimit({
453+
messages: orderedWithInstructions,
454+
instructions,
455+
});
427456

428457
logger.debug('[BaseClient] Context Count (1/2)', {
429458
remainingContextTokens,
@@ -435,7 +464,9 @@ class BaseClient {
435464
let { shouldSummarize } = this;
436465

437466
// Calculate the difference in length to determine how many messages were discarded if any
438-
const { length } = payload;
467+
let payload;
468+
let { length } = formattedMessages;
469+
length += instructions != null ? 1 : 0;
439470
const diff = length - context.length;
440471
const firstMessage = orderedWithInstructions[0];
441472
const usePrevSummary =
@@ -445,18 +476,31 @@ class BaseClient {
445476
this.previous_summary.messageId === firstMessage.messageId;
446477

447478
if (diff > 0) {
448-
payload = payload.slice(diff);
479+
payload = formattedMessages.slice(diff);
449480
logger.debug(
450481
`[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`,
451482
);
452483
}
453484

485+
payload = this.addInstructions(payload ?? formattedMessages, _instructions);
486+
454487
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
455488
if (payload.length === 0 && !shouldSummarize && latestMessage) {
456489
const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`;
457490
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
458491
logger.warn(`Prompt token count exceeds max token count (${info}).`);
459492
throw new Error(errorMessage);
493+
} else if (
494+
_instructions &&
495+
payload.length === 1 &&
496+
payload[0].content === _instructions.content
497+
) {
498+
const info = `${tokenCount + 3} / ${this.maxContextTokens}`;
499+
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
500+
logger.warn(
501+
`Including instructions, the prompt token count exceeds remaining max token count (${info}).`,
502+
);
503+
throw new Error(errorMessage);
460504
}
461505

462506
if (usePrevSummary) {

api/app/clients/OpenAIClient.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,10 @@ ${convo}
931931
);
932932

933933
if (excessTokenCount > maxContextTokens) {
934-
({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens));
934+
({ context } = await this.getMessagesWithinTokenLimit({
935+
messages: context,
936+
maxContextTokens,
937+
}));
935938
}
936939

937940
if (context.length === 0) {

api/app/clients/specs/BaseClient.test.js

Lines changed: 110 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ describe('BaseClient', () => {
159159
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
160160
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
161161

162-
const result = await TestClient.getMessagesWithinTokenLimit(messages);
162+
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
163163

164164
expect(result.context).toEqual(expectedContext);
165165
expect(result.summaryIndex).toEqual(expectedIndex);
@@ -195,74 +195,14 @@ describe('BaseClient', () => {
195195
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
196196
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
197197

198-
const result = await TestClient.getMessagesWithinTokenLimit(messages);
198+
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
199199

200200
expect(result.context).toEqual(expectedContext);
201201
expect(result.summaryIndex).toEqual(expectedIndex);
202202
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
203203
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
204204
});
205205

206-
test('handles context strategy correctly in handleContextStrategy()', async () => {
207-
TestClient.addInstructions = jest
208-
.fn()
209-
.mockReturnValue([
210-
{ content: 'Hello' },
211-
{ content: 'How can I help you?' },
212-
{ content: 'Please provide more details.' },
213-
{ content: 'I can assist you with that.' },
214-
]);
215-
TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
216-
context: [
217-
{ content: 'How can I help you?' },
218-
{ content: 'Please provide more details.' },
219-
{ content: 'I can assist you with that.' },
220-
],
221-
remainingContextTokens: 80,
222-
messagesToRefine: [{ content: 'Hello' }],
223-
summaryIndex: 3,
224-
});
225-
226-
TestClient.getTokenCount = jest.fn().mockReturnValue(40);
227-
228-
const instructions = { content: 'Please provide more details.' };
229-
const orderedMessages = [
230-
{ content: 'Hello' },
231-
{ content: 'How can I help you?' },
232-
{ content: 'Please provide more details.' },
233-
{ content: 'I can assist you with that.' },
234-
];
235-
const formattedMessages = [
236-
{ content: 'Hello' },
237-
{ content: 'How can I help you?' },
238-
{ content: 'Please provide more details.' },
239-
{ content: 'I can assist you with that.' },
240-
];
241-
const expectedResult = {
242-
payload: [
243-
{
244-
role: 'system',
245-
content: 'Refined answer',
246-
},
247-
{ content: 'How can I help you?' },
248-
{ content: 'Please provide more details.' },
249-
{ content: 'I can assist you with that.' },
250-
],
251-
promptTokens: expect.any(Number),
252-
tokenCountMap: {},
253-
messages: expect.any(Array),
254-
};
255-
256-
TestClient.shouldSummarize = true;
257-
const result = await TestClient.handleContextStrategy({
258-
instructions,
259-
orderedMessages,
260-
formattedMessages,
261-
});
262-
263-
expect(result).toEqual(expectedResult);
264-
});
265-
266206
describe('getMessagesForConversation', () => {
267207
it('should return an empty array if the parentMessageId does not exist', () => {
268208
const result = TestClient.constructor.getMessagesForConversation({
@@ -674,4 +614,112 @@ describe('BaseClient', () => {
674614
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
675615
});
676616
});
617+
618+
describe('getMessagesWithinTokenLimit with instructions', () => {
619+
test('should always include instructions when present', async () => {
620+
TestClient.maxContextTokens = 50;
621+
const instructions = {
622+
role: 'system',
623+
content: 'System instructions',
624+
tokenCount: 20,
625+
};
626+
627+
const messages = [
628+
instructions,
629+
{ role: 'user', content: 'Hello', tokenCount: 10 },
630+
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
631+
];
632+
633+
const result = await TestClient.getMessagesWithinTokenLimit({
634+
messages,
635+
instructions,
636+
});
637+
638+
expect(result.context[0]).toBe(instructions);
639+
expect(result.remainingContextTokens).toBe(2);
640+
});
641+
642+
test('should handle case when messages exceed limit but instructions must be preserved', async () => {
643+
TestClient.maxContextTokens = 30;
644+
const instructions = {
645+
role: 'system',
646+
content: 'System instructions',
647+
tokenCount: 20,
648+
};
649+
650+
const messages = [
651+
instructions,
652+
{ role: 'user', content: 'Hello', tokenCount: 10 },
653+
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
654+
];
655+
656+
const result = await TestClient.getMessagesWithinTokenLimit({
657+
messages,
658+
instructions,
659+
});
660+
661+
// No message fits in the budget left after the instructions, so only the instructions are kept
662+
expect(result.context).toHaveLength(1);
663+
expect(result.context[0].content).toBe(instructions.content);
664+
expect(result.messagesToRefine).toHaveLength(2);
665+
expect(result.remainingContextTokens).toBe(7); // 30 - 20 - 3 (assistant label)
666+
});
667+
668+
test('should work correctly without instructions (1/2)', async () => {
669+
TestClient.maxContextTokens = 50;
670+
const messages = [
671+
{ role: 'user', content: 'Hello', tokenCount: 10 },
672+
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
673+
];
674+
675+
const result = await TestClient.getMessagesWithinTokenLimit({
676+
messages,
677+
});
678+
679+
expect(result.context).toHaveLength(2);
680+
expect(result.remainingContextTokens).toBe(22); // 50 - 10 - 15 - 3(assistant label)
681+
expect(result.messagesToRefine).toHaveLength(0);
682+
});
683+
684+
test('should work correctly without instructions (2/2)', async () => {
685+
TestClient.maxContextTokens = 30;
686+
const messages = [
687+
{ role: 'user', content: 'Hello', tokenCount: 10 },
688+
{ role: 'assistant', content: 'Hi there', tokenCount: 20 },
689+
];
690+
691+
const result = await TestClient.getMessagesWithinTokenLimit({
692+
messages,
693+
});
694+
695+
expect(result.context).toHaveLength(1);
696+
expect(result.remainingContextTokens).toBe(7);
697+
expect(result.messagesToRefine).toHaveLength(1);
698+
});
699+
700+
test('should handle case when only instructions fit within limit', async () => {
701+
TestClient.maxContextTokens = 25;
702+
const instructions = {
703+
role: 'system',
704+
content: 'System instructions',
705+
tokenCount: 20,
706+
};
707+
708+
const messages = [
709+
instructions,
710+
{ role: 'user', content: 'Hello', tokenCount: 10 },
711+
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
712+
];
713+
714+
const result = await TestClient.getMessagesWithinTokenLimit({
715+
messages,
716+
instructions,
717+
});
718+
719+
expect(result.context).toHaveLength(1);
720+
expect(result.context[0]).toBe(instructions);
721+
expect(result.messagesToRefine).toHaveLength(2);
722+
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
723+
});
724+
});
677725
});

api/server/controllers/auth/LogoutController.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const { logger } = require('~/config');
55
const logoutController = async (req, res) => {
66
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
77
try {
8-
const logout = await logoutUser(req.user._id, refreshToken);
8+
const logout = await logoutUser(req, refreshToken);
99
const { status, message } = logout;
1010
res.clearCookie('refreshToken');
1111
return res.status(status).send({ message });

0 commit comments

Comments
 (0)