Skip to content

Commit 7cb9f85

Browse files
committed
🏗️ fix: Agents Token Spend Race Conditions, Add Auto-refill Tx, Add Relevant Tests (#6480)
* 🏗️ refactor: Improve spendTokens logic to handle zero completion tokens and enhance test coverage * 🏗️ test: Add tests to ensure balance does not go below zero when spending tokens * 🏗️ fix: Ensure proper continuation in AgentClient when handling errors * fix: spend token race conditions * 🏗️ test: Add test for handling multiple concurrent transactions with high balance * fix: Handle Omni models prompt prefix handling for user messages with array content in OpenAIClient * refactor: Update checkBalance import paths to use new balanceMethods module * refactor: Update checkBalance imports and implement updateBalance function for atomic balance updates * fix: import from replace method * feat: Add createAutoRefillTransaction method to handle non-balance updating transactions * refactor: Move auto-refill logic to balanceMethods and enhance checkBalance functionality * feat: Implement logging for auto-refill transactions in balance checks * refactor: Remove logRefill calls from multiple client and handler files * refactor: Move balance checking and auto-refill logic to balanceMethods for improved structure * refactor: Simplify balance check calls by removing unnecessary balanceRecord assignments * fix: Prevent negative rawAmount in spendTokens when promptTokens is zero * fix: Update balanceMethods to use Balance model for findOneAndUpdate * chore: import order * refactor: remove unused txMethods file to streamline codebase * feat: enhance updateBalance and createAutoRefillTransaction methods to support additional parameters for improved balance management
1 parent 6bf6d53 commit 7cb9f85

File tree

13 files changed

+805
-277
lines changed

13 files changed

+805
-277
lines changed

api/app/clients/BaseClient.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ const {
1111
Constants,
1212
} = require('librechat-data-provider');
1313
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
14+
const { checkBalance } = require('~/models/balanceMethods');
1415
const { truncateToolCallOutputs } = require('./prompts');
1516
const { addSpaceIfNeeded } = require('~/server/utils');
16-
const checkBalance = require('~/models/checkBalance');
1717
const { getFiles } = require('~/models/File');
1818
const TextStream = require('./TextStream');
1919
const { logger } = require('~/config');

api/app/clients/OpenAIClient.js

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const { SplitStreamHandler, GraphEvents } = require('@librechat/agents');
55
const {
66
Constants,
77
ImageDetail,
8+
ContentTypes,
89
EModelEndpoint,
910
resolveHeaders,
1011
KnownEndpoints,
@@ -505,8 +506,24 @@ class OpenAIClient extends BaseClient {
505506
if (promptPrefix && this.isOmni === true) {
506507
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
507508
if (lastUserMessageIndex !== -1) {
508-
payload[lastUserMessageIndex].content =
509-
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
509+
if (Array.isArray(payload[lastUserMessageIndex].content)) {
510+
const firstTextPartIndex = payload[lastUserMessageIndex].content.findIndex(
511+
(part) => part.type === ContentTypes.TEXT,
512+
);
513+
if (firstTextPartIndex !== -1) {
514+
const firstTextPart = payload[lastUserMessageIndex].content[firstTextPartIndex];
515+
payload[lastUserMessageIndex].content[firstTextPartIndex].text =
516+
`${promptPrefix}\n${firstTextPart.text}`;
517+
} else {
518+
payload[lastUserMessageIndex].content.unshift({
519+
type: ContentTypes.TEXT,
520+
text: promptPrefix,
521+
});
522+
}
523+
} else {
524+
payload[lastUserMessageIndex].content =
525+
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
526+
}
510527
}
511528
}
512529

api/app/clients/PluginsClient.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_pars
55
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
66
const { processFileURL } = require('~/server/services/Files/process');
77
const { EModelEndpoint } = require('librechat-data-provider');
8+
const { checkBalance } = require('~/models/balanceMethods');
89
const { formatLangChainMessages } = require('./prompts');
9-
const checkBalance = require('~/models/checkBalance');
1010
const { extractBaseURL } = require('~/utils');
1111
const { loadTools } = require('./tools/util');
1212
const { logger } = require('~/config');

api/app/clients/callbacks/createStartHandler.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ const { promptTokensEstimate } = require('openai-chat-tokens');
22
const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider');
33
const { formatFromLangChain } = require('~/app/clients/prompts');
44
const { getBalanceConfig } = require('~/server/services/Config');
5-
const checkBalance = require('~/models/checkBalance');
5+
const { checkBalance } = require('~/models/balanceMethods');
66
const { logger } = require('~/config');
77

88
const createStartHandler = ({

api/models/Balance.js

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,4 @@
11
const mongoose = require('mongoose');
22
const { balanceSchema } = require('@librechat/data-schemas');
3-
const { getMultiplier } = require('./tx');
4-
const { logger } = require('~/config');
5-
6-
/**
7-
* Adds a time interval to a given date.
8-
* @param {Date} date - The starting date.
9-
* @param {number} value - The numeric value of the interval.
10-
* @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time.
11-
* @returns {Date} A new Date representing the starting date plus the interval.
12-
*/
13-
const addIntervalToDate = (date, value, unit) => {
14-
const result = new Date(date);
15-
switch (unit) {
16-
case 'seconds':
17-
result.setSeconds(result.getSeconds() + value);
18-
break;
19-
case 'minutes':
20-
result.setMinutes(result.getMinutes() + value);
21-
break;
22-
case 'hours':
23-
result.setHours(result.getHours() + value);
24-
break;
25-
case 'days':
26-
result.setDate(result.getDate() + value);
27-
break;
28-
case 'weeks':
29-
result.setDate(result.getDate() + value * 7);
30-
break;
31-
case 'months':
32-
result.setMonth(result.getMonth() + value);
33-
break;
34-
default:
35-
break;
36-
}
37-
return result;
38-
};
39-
40-
balanceSchema.statics.check = async function ({
41-
user,
42-
model,
43-
endpoint,
44-
valueKey,
45-
tokenType,
46-
amount,
47-
endpointTokenConfig,
48-
}) {
49-
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
50-
const tokenCost = amount * multiplier;
51-
52-
// Retrieve the complete balance record
53-
let record = await this.findOne({ user }).lean();
54-
if (!record) {
55-
logger.debug('[Balance.check] No balance record found for user', { user });
56-
return {
57-
canSpend: false,
58-
balance: 0,
59-
tokenCost,
60-
};
61-
}
62-
let balance = record.tokenCredits;
63-
64-
logger.debug('[Balance.check] Initial state', {
65-
user,
66-
model,
67-
endpoint,
68-
valueKey,
69-
tokenType,
70-
amount,
71-
balance,
72-
multiplier,
73-
endpointTokenConfig: !!endpointTokenConfig,
74-
});
75-
76-
// Only perform auto-refill if spending would bring the balance to 0 or below
77-
if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
78-
const lastRefillDate = new Date(record.lastRefill);
79-
const nextRefillDate = addIntervalToDate(
80-
lastRefillDate,
81-
record.refillIntervalValue,
82-
record.refillIntervalUnit,
83-
);
84-
const now = new Date();
85-
86-
if (now >= nextRefillDate) {
87-
record = await this.findOneAndUpdate(
88-
{ user },
89-
{
90-
$inc: { tokenCredits: record.refillAmount },
91-
$set: { lastRefill: new Date() },
92-
},
93-
{ new: true },
94-
).lean();
95-
balance = record.tokenCredits;
96-
logger.debug('[Balance.check] Auto-refill performed', { balance });
97-
}
98-
}
99-
100-
logger.debug('[Balance.check] Token cost', { tokenCost });
101-
102-
return { canSpend: balance >= tokenCost, balance, tokenCost };
103-
};
1043

1054
module.exports = mongoose.model('Balance', balanceSchema);

api/models/Transaction.js

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,45 @@ const { getBalanceConfig } = require('~/server/services/Config');
44
const { getMultiplier, getCacheMultiplier } = require('./tx');
55
const { logger } = require('~/config');
66
const Balance = require('./Balance');
7+
78
const cancelRate = 1.15;
89

10+
/**
11+
* Updates a user's token balance based on a transaction.
12+
*
13+
* @async
14+
* @function
15+
* @param {Object} params - The function parameters.
16+
* @param {string} params.user - The user ID.
17+
* @param {number} params.incrementValue - The value to increment the balance by (can be negative).
18+
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} params.setValues
19+
* @returns {Promise<Object>} Returns the updated balance response.
20+
*/
21+
const updateBalance = async ({ user, incrementValue, setValues }) => {
22+
// Use findOneAndUpdate with a conditional update to make the balance update atomic
23+
// This prevents race conditions when multiple transactions are processed concurrently
24+
const balanceResponse = await Balance.findOneAndUpdate(
25+
{ user },
26+
[
27+
{
28+
$set: {
29+
tokenCredits: {
30+
$cond: {
31+
if: { $lt: [{ $add: ['$tokenCredits', incrementValue] }, 0] },
32+
then: 0,
33+
else: { $add: ['$tokenCredits', incrementValue] },
34+
},
35+
},
36+
...setValues,
37+
},
38+
},
39+
],
40+
{ upsert: true, new: true },
41+
).lean();
42+
43+
return balanceResponse;
44+
};
45+
946
/** Method to calculate and set the tokenValue for a transaction */
1047
transactionSchema.methods.calculateTokenValue = function () {
1148
if (!this.valueKey || !this.tokenType) {
@@ -21,6 +58,39 @@ transactionSchema.methods.calculateTokenValue = function () {
2158
}
2259
};
2360

61+
/**
62+
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
63+
* @param {object} txData - Transaction data.
64+
* @param {string} txData.user - The user ID.
65+
* @param {string} txData.tokenType - The type of token.
66+
* @param {string} txData.context - The context of the transaction.
67+
* @param {number} txData.rawAmount - The raw amount of tokens.
68+
* @returns {Promise<object>} - The created transaction.
69+
*/
70+
transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
71+
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
72+
return;
73+
}
74+
const transaction = new this(txData);
75+
transaction.endpointTokenConfig = txData.endpointTokenConfig;
76+
transaction.calculateTokenValue();
77+
await transaction.save();
78+
79+
const balanceResponse = await updateBalance({
80+
user: transaction.user,
81+
incrementValue: txData.rawAmount,
82+
setValues: { lastRefill: new Date() },
83+
});
84+
const result = {
85+
rate: transaction.rate,
86+
user: transaction.user.toString(),
87+
balance: balanceResponse.tokenCredits,
88+
};
89+
logger.debug('[Balance.check] Auto-refill performed', result);
90+
result.transaction = transaction;
91+
return result;
92+
};
93+
2494
/**
2595
* Static method to create a transaction and update the balance
2696
* @param {txData} txData - Transaction data.
@@ -42,18 +112,12 @@ transactionSchema.statics.create = async function (txData) {
42112
return;
43113
}
44114

45-
let balanceResponse = await Balance.findOne({ user: transaction.user }).lean();
46115
let incrementValue = transaction.tokenValue;
47116

48-
if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) {
49-
incrementValue = -balanceResponse.tokenCredits;
50-
}
51-
52-
balanceResponse = await Balance.findOneAndUpdate(
53-
{ user: transaction.user },
54-
{ $inc: { tokenCredits: incrementValue } },
55-
{ upsert: true, new: true },
56-
).lean();
117+
const balanceResponse = await updateBalance({
118+
user: transaction.user,
119+
incrementValue,
120+
});
57121

58122
return {
59123
rate: transaction.rate,
@@ -84,18 +148,12 @@ transactionSchema.statics.createStructured = async function (txData) {
84148
return;
85149
}
86150

87-
let balanceResponse = await Balance.findOne({ user: transaction.user }).lean();
88151
let incrementValue = transaction.tokenValue;
89152

90-
if (balanceResponse && balanceResponse.tokenCredits + incrementValue < 0) {
91-
incrementValue = -balanceResponse.tokenCredits;
92-
}
93-
94-
balanceResponse = await Balance.findOneAndUpdate(
95-
{ user: transaction.user },
96-
{ $inc: { tokenCredits: incrementValue } },
97-
{ upsert: true, new: true },
98-
).lean();
153+
const balanceResponse = await updateBalance({
154+
user: transaction.user,
155+
incrementValue,
156+
});
99157

100158
return {
101159
rate: transaction.rate,

0 commit comments

Comments
 (0)