Skip to content

feat(db & e2e): Enhance DB Schemas/Controllers and Improve E2E Tests #966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 18, 2023
2 changes: 1 addition & 1 deletion api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ class BaseClient {
}

async saveMessageToDatabase(message, endpointOptions, user = null) {
await saveMessage({ ...message, unfinished: false, cancelled: false });
await saveMessage({ ...message, user, unfinished: false, cancelled: false });
await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
Expand Down
17 changes: 17 additions & 0 deletions api/models/Conversation.js
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ module.exports = {
return { message: 'Error getting conversation title' };
}
},
/**
* Asynchronously deletes conversations and associated messages for a given user and filter.
*
* @async
* @function
* @param {string|ObjectId} user - The user's ID.
* @param {Object} filter - Additional filter criteria for the conversations to be deleted.
* @returns {Promise<{ n: number, ok: number, deletedCount: number, messages: { n: number, ok: number, deletedCount: number } }>}
* An object containing the count of deleted conversations and associated messages.
* @throws {Error} Throws an error if there's an issue with the database operations.
*
* @example
* const user = 'someUserId';
* const filter = { someField: 'someValue' };
* const result = await deleteConvos(user, filter);
* console.log(result); // { n: 5, ok: 1, deletedCount: 5, messages: { n: 10, ok: 1, deletedCount: 10 } }
*/
deleteConvos: async (user, filter) => {
let toRemove = await Conversation.find({ ...filter, user }).select('conversationId');
const ids = toRemove.map((instance) => instance.conversationId);
Expand Down
2 changes: 2 additions & 0 deletions api/models/Message.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module.exports = {
Message,

async saveMessage({
user,
messageId,
newMessageId,
conversationId,
Expand All @@ -33,6 +34,7 @@ module.exports = {
await Message.findOneAndUpdate(
{ messageId },
{
user,
messageId: newMessageId || messageId,
conversationId,
parentMessageId,
Expand Down
1 change: 1 addition & 0 deletions api/models/schema/convoSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const convoSchema = mongoose.Schema(
},
user: {
type: String,
index: true,
default: null,
},
messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }],
Expand Down
5 changes: 5 additions & 0 deletions api/models/schema/messageSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ const messageSchema = mongoose.Schema(
required: true,
meiliIndex: true,
},
user: {
type: String,
index: true,
default: null,
},
model: {
type: String,
},
Expand Down
3 changes: 2 additions & 1 deletion api/server/middleware/abortMiddleware.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const createAbortController = (req, res, getAbortData) => {
isCreatedByUser: false,
};

saveMessage(responseMessage);
saveMessage({ ...responseMessage, user: req.user.id });

return {
title: await getConvoTitle(req.user.id, conversationId),
Expand All @@ -80,6 +80,7 @@ const handleAbortError = async (res, req, error, data) => {
parentMessageId,
text: error.message,
shouldSaveMessage: true,
user: req.user.id,
};
const callback = async () => {
if (abortControllers.has(conversationId)) {
Expand Down
3 changes: 2 additions & 1 deletion api/server/middleware/denyRequest.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const denyRequest = async (req, res, errorMessage) => {
_convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000';

if (shouldSaveMessage) {
await saveMessage(userMessage);
await saveMessage({ ...userMessage, user: req.user.id });
}

return await sendError(res, {
Expand All @@ -52,6 +52,7 @@ const denyRequest = async (req, res, errorMessage) => {
parentMessageId: userMessage.messageId,
text: responseText,
shouldSaveMessage,
user: req.user.id,
});
};

Expand Down
12 changes: 7 additions & 5 deletions api/server/routes/ask/anthropic.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
let responseMessageId;
let lastSavedTimestamp = 0;
let saveDelay = 100;
const user = req.user.id;

const getIds = (data) => {
userMessage = data.userMessage;
Expand All @@ -55,6 +56,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
unfinished: true,
cancelled: false,
error: false,
user,
});
}

Expand All @@ -80,7 +82,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
let response = await client.sendMessage(text, {
getIds,
// debug: true,
user: req.user.id,
user,
conversationId,
parentMessageId,
overrideParentMessageId,
Expand All @@ -98,18 +100,18 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
response.parentMessageId = overrideParentMessageId;
}

await saveConvo(req.user.id, {
await saveConvo(user, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});

await saveMessage(response);
await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});
Expand Down
20 changes: 11 additions & 9 deletions api/server/routes/ask/askChatGPTBrowser.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ router.post('/', setHeaders, async (req, res) => {
});

if (!overrideParentMessageId) {
await saveMessage(userMessage);
await saveMessage({ ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
Expand Down Expand Up @@ -80,7 +80,7 @@ const ask = async ({
res,
}) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
const userId = req.user.id;
const user = req.user.id;
let responseMessageId = crypto.randomUUID();
let getPartialMessage = null;
try {
Expand All @@ -100,6 +100,7 @@ const ask = async ({
cancelled: false,
error: false,
isCreatedByUser: false,
user,
});
}
},
Expand All @@ -114,7 +115,7 @@ const ask = async ({
conversationId,
...endpointOption,
abortController,
userId,
userId: user,
onProgress: progressCallback.call(null, { res, text }),
onEventMessage: (eventMessage) => {
let data = null;
Expand Down Expand Up @@ -157,7 +158,7 @@ const ask = async ({
isCreatedByUser: false,
};

await saveMessage(responseMessage);
await saveMessage({ ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;

// STEP2 update the conversation
Expand All @@ -181,7 +182,7 @@ const ask = async ({
}
}

await saveConvo(req.user.id, conversationUpdate);
await saveConvo(user, conversationUpdate);
conversationId = newConversationId;

// STEP3 update the user message
Expand All @@ -192,16 +193,17 @@ const ask = async ({
if (!overrideParentMessageId) {
await saveMessage({
...userMessage,
user,
messageId: userMessageId,
newMessageId: newUserMassageId,
});
}
userMessageId = newUserMassageId;

sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
});
Expand All @@ -210,7 +212,7 @@ const ask = async ({
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
// const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
const title = await response.details.title;
await saveConvo(req.user.id, {
await saveConvo(user, {
conversationId: conversationId,
title,
});
Expand All @@ -227,7 +229,7 @@ const ask = async ({
isCreatedByUser: false,
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
};
await saveMessage(errorMessage);
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
};
Expand Down
27 changes: 15 additions & 12 deletions api/server/routes/ask/bingAI.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ router.post('/', setHeaders, async (req, res) => {
});

if (!overrideParentMessageId) {
await saveMessage(userMessage);
await saveMessage({ ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
Expand Down Expand Up @@ -100,6 +100,7 @@ const ask = async ({
res,
}) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
const user = req.user.id;

let responseMessageId = crypto.randomUUID();
const model = endpointOption?.jailbreak ? 'Sydney' : 'BingAI';
Expand All @@ -125,21 +126,22 @@ const ask = async ({
cancelled: false,
error: false,
isCreatedByUser: false,
user,
});
}
},
});
const abortController = new AbortController();
let bingConversationId = null;
if (!isNewConversation) {
const convo = await getConvo(req.user.id, conversationId);
const convo = await getConvo(user, conversationId);
bingConversationId = convo.bingConversationId;
}

try {
let response = await askBing({
text,
userId: req.user.id,
userId: user,
parentMessageId: userParentMessageId,
conversationId: bingConversationId ?? conversationId,
...endpointOption,
Expand Down Expand Up @@ -194,7 +196,7 @@ const ask = async ({
isCreatedByUser: false,
};

await saveMessage(responseMessage);
await saveMessage({ ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;

let conversationUpdate = {
Expand All @@ -213,23 +215,24 @@ const ask = async ({
conversationUpdate.invocationId = response.invocationId;
}

await saveConvo(req.user.id, conversationUpdate);
await saveConvo(user, conversationUpdate);
userMessage.messageId = newUserMessageId;

// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
if (!overrideParentMessageId) {
await saveMessage({
...userMessage,
user,
messageId: userMessageId,
newMessageId: newUserMessageId,
});
}
userMessageId = newUserMessageId;

sendMessage(res, {
title: await getConvoTitle(req.user.id, conversationId),
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
});
Expand All @@ -241,7 +244,7 @@ const ask = async ({
response: responseMessage,
});

await saveConvo(req.user.id, {
await saveConvo(user, {
conversationId: conversationId,
title,
});
Expand All @@ -263,12 +266,12 @@ const ask = async ({
isCreatedByUser: false,
};

saveMessage(responseMessage);
saveMessage({ ...responseMessage, user });

return {
title: await getConvoTitle(req.user.id, conversationId),
title: await getConvoTitle(user, conversationId),
final: true,
conversation: await getConvo(req.user.id, conversationId),
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};
Expand All @@ -286,7 +289,7 @@ const ask = async ({
model,
isCreatedByUser: false,
};
await saveMessage(errorMessage);
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
}
Expand Down
Loading