71 changes: 38 additions & 33 deletions src/providers/bedrock/chatComplete.ts
@@ -21,6 +21,11 @@ import {
BedrocMistralStreamChunk,
} from './complete';
import { BedrockErrorResponse } from './embed';
import {
transformMessagesForLLama2Prompt,
transformMessagesForLLama3Prompt,
transformMessagesForMistralPrompt,
} from './utils';

interface AnthropicTool {
name: string;
@@ -373,28 +378,13 @@ export const BedrockCohereChatCompleteConfig: ProviderConfig = {
},
};

export const BedrockLLamaChatCompleteConfig: ProviderConfig = {
export const BedrockLlama2ChatCompleteConfig: ProviderConfig = {
messages: {
param: 'prompt',
required: true,
transform: (params: Params) => {
let prompt: string = '';
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
prompt += `system: ${msg.content}\n`;
} else if (msg.role == 'user') {
prompt += `user: ${msg.content}\n`;
} else if (msg.role == 'assistant') {
prompt += `assistant: ${msg.content}\n`;
} else {
prompt += `${msg.role}: ${msg.content}\n`;
}
});
prompt += 'Assistant:';
}
return prompt;
if (!params.messages) return '';
return transformMessagesForLLama2Prompt(params.messages);
},
},
max_tokens: {
@@ -423,27 +413,42 @@ export const BedrockLLamaChatCompleteConfig: ProviderConfig = {
},
};

export const BedrockLlama3ChatCompleteConfig: ProviderConfig = {
messages: {
param: 'prompt',
required: true,
transform: (params: Params) => {
if (!params.messages) return '';
return transformMessagesForLLama3Prompt(params.messages);
},
},
max_tokens: {
param: 'max_gen_len',
default: 512,
min: 1,
},
temperature: {
param: 'temperature',
default: 0.5,
min: 0,
max: 1,
},
top_p: {
param: 'top_p',
default: 0.9,
min: 0,
max: 1,
},
};
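// For illustration (not part of the config): under the mapping above, an
// OpenAI-style request such as { max_tokens: 256, temperature: 0.7 } becomes
// a Bedrock Llama body along the lines of
// { prompt: '<|begin_of_text|>...', max_gen_len: 256, temperature: 0.7, top_p: 0.9 },
// with top_p filled from the default above when the caller omits it.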

export const BedrockMistralChatCompleteConfig: ProviderConfig = {
messages: {
param: 'prompt',
required: true,
transform: (params: Params) => {
let prompt: string = '';
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
prompt += `system: ${msg.content}\n`;
} else if (msg.role == 'user') {
prompt += `user: ${msg.content}\n`;
} else if (msg.role == 'assistant') {
prompt += `assistant: ${msg.content}\n`;
} else {
prompt += `${msg.role}: ${msg.content}\n`;
}
});
prompt += 'Assistant:';
}
if (!!params.messages)
prompt = transformMessagesForMistralPrompt(params.messages);
return prompt;
},
},
36 changes: 36 additions & 0 deletions src/providers/bedrock/constants.ts
@@ -0,0 +1,36 @@
export const LLAMA_2_SPECIAL_TOKENS = {
BEGINNING_OF_SENTENCE: '<s>',
END_OF_SENTENCE: '</s>',
CONVERSATION_TURN_START: '[INST]',
CONVERSATION_TURN_END: '[/INST]',
SYSTEM_MESSAGE_START: '<<SYS>>\n',
SYSTEM_MESSAGE_END: '\n<</SYS>>\n\n',
};
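// Taken together these compose the canonical Llama 2 chat template; one fully
// formed turn looks roughly like this (braced placeholders are illustrative):
// <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant} </s>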

export const LLAMA_3_SPECIAL_TOKENS = {
PROMPT_START: '<|begin_of_text|>',
PROMPT_END: '<|end_of_text|>',
PADDING: '<|finetune_right_pad_id|>',
ROLE_START: '<|start_header_id|>',
ROLE_END: '<|end_header_id|>',
END_OF_MESSAGE: '<|eom_id|>',
END_OF_TURN: '<|eot_id|>',
TOOL_CALL: '<|python_tag|>',
};

export const MISTRAL_CONTROL_TOKENS = {
UNKNOWN: '<unk>',
BEGINNING_OF_SENTENCE: '<s>',
END_OF_SENTENCE: '</s>',
CONVERSATION_TURN_START: '[INST]',
CONVERSATION_TURN_END: '[/INST]',
AVAILABLE_TOOLS_START: '[AVAILABLE_TOOLS]',
AVAILABLE_TOOLS_END: '[/AVAILABLE_TOOLS]',
TOOL_RESULTS_START: '[TOOL_RESULTS]',
TOOL_RESULTS_END: '[/TOOL_RESULTS]',
TOOL_CALLS_START: '[TOOL_CALLS]',
PADDING: '<pad>',
PREFIX: '[PREFIX]',
MIDDLE: '[MIDDLE]',
SUFFIX: '[SUFFIX]',
};
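// Likewise, the Mistral instruct template composed from these tokens wraps each
// user message in [INST]...[/INST] and closes each assistant reply, roughly:
// <s>[INST] {user} [/INST] {assistant} </s>[INST] {user} [/INST]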
10 changes: 8 additions & 2 deletions src/providers/bedrock/index.ts
@@ -11,7 +11,6 @@ import {
BedrockCohereChatCompleteConfig,
BedrockCohereChatCompleteResponseTransform,
BedrockCohereChatCompleteStreamChunkTransform,
BedrockLLamaChatCompleteConfig,
BedrockLlamaChatCompleteResponseTransform,
BedrockLlamaChatCompleteStreamChunkTransform,
BedrockTitanChatCompleteResponseTransform,
@@ -20,6 +19,8 @@ import {
BedrockMistralChatCompleteConfig,
BedrockMistralChatCompleteResponseTransform,
BedrockMistralChatCompleteStreamChunkTransform,
BedrockLlama3ChatCompleteConfig,
BedrockLlama2ChatCompleteConfig,
} from './chatComplete';
import {
BedrockAI21CompleteConfig,
@@ -56,6 +57,7 @@ const BedrockConfig: ProviderConfigs = {
getConfig: (params: Params) => {
const providerModel = params.model;
const provider = providerModel?.split('.')[0];
const model = providerModel?.split('.')[1];
switch (provider) {
case ANTHROPIC:
return {
@@ -86,9 +88,13 @@
},
};
case 'meta':
const chatCompleteConfig =
model?.search('llama3') === -1
? BedrockLlama2ChatCompleteConfig
: BedrockLlama3ChatCompleteConfig;
return {
complete: BedrockLLamaCompleteConfig,
chatComplete: BedrockLLamaChatCompleteConfig,
chatComplete: chatCompleteConfig,
api: BedrockAPIConfig,
responseTransforms: {
'stream-complete': BedrockLlamaCompleteStreamChunkTransform,
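The routing above keys off the model segment of the Bedrock model ID: any meta.llama3* model gets the Llama 3 prompt format, and every other meta.* model falls back to Llama 2. A minimal sketch of that check, using hypothetical model IDs:

// Hypothetical Bedrock model IDs, for illustration only.
const ids = ['meta.llama3-8b-instruct-v1:0', 'meta.llama2-13b-chat-v1'];
for (const providerModel of ids) {
  const model = providerModel.split('.')[1];
  // String.prototype.search returns -1 when the pattern is absent.
  const config =
    model.search('llama3') === -1 ? 'Llama 2 config' : 'Llama 3 config';
  console.log(`${providerModel} -> ${config}`);
}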
103 changes: 103 additions & 0 deletions src/providers/bedrock/utils.ts
@@ -1,5 +1,11 @@
import { SignatureV4 } from '@smithy/signature-v4';
import { Sha256 } from '@aws-crypto/sha256-js';
import { ContentType, Message, MESSAGE_ROLES } from '../../types/requestBody';
import {
LLAMA_2_SPECIAL_TOKENS,
LLAMA_3_SPECIAL_TOKENS,
MISTRAL_CONTROL_TOKENS,
} from './constants';

export const generateAWSHeaders = async (
body: Record<string, any>,
@@ -38,3 +44,100 @@
const signed = await signer.sign(request);
return signed.headers;
};

/*
Helper function to use inside reduce to convert ContentType array to string
*/
const convertContentTypesToString = (acc: string, curr: ContentType) => {
if (curr.type !== 'text') return acc;
acc += curr.text + '\n';
return acc;
};

/*
Handle messages of both string and ContentType array
*/
const getMessageContent = (message: Message) => {
if (message === undefined) return '';
if (typeof message.content === 'object') {
return message.content.reduce(convertContentTypesToString, '');
}
return message.content || '';
};
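// For illustration: a multi-part text message is flattened with trailing
// newlines, e.g.
// getMessageContent({ role: 'user', content: [{ type: 'text', text: 'a' }, { type: 'text', text: 'b' }] })
// => 'a\nb\n'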

/*
This function transforms the messages for the Llama 3.1 prompt.
It adds the special tokens to the beginning and end of the prompt.
refer: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1
NOTE: Portkey does not restrict messages to alternating user and assistant roles; this supports more flexible use cases.
*/
export const transformMessagesForLLama3Prompt = (messages: Message[]) => {
let prompt: string = '';
prompt += LLAMA_3_SPECIAL_TOKENS.PROMPT_START + '\n';
messages.forEach((msg, index) => {
prompt +=
LLAMA_3_SPECIAL_TOKENS.ROLE_START +
msg.role +
LLAMA_3_SPECIAL_TOKENS.ROLE_END +
'\n';
prompt += getMessageContent(msg) + LLAMA_3_SPECIAL_TOKENS.END_OF_TURN;
});
prompt +=
LLAMA_3_SPECIAL_TOKENS.ROLE_START +
MESSAGE_ROLES.ASSISTANT +
LLAMA_3_SPECIAL_TOKENS.ROLE_END +
'\n';
return prompt;
};
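// For illustration, a short conversation run through this transform yields
// (escape sequences shown literally; assumes plain string content):
// transformMessagesForLLama3Prompt([
//   { role: 'system', content: 'Be terse.' },
//   { role: 'user', content: 'Hi' },
// ])
// => '<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\nBe terse.<|eot_id|><|start_header_id|>user<|end_header_id|>\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'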

/*
This function transforms the messages for the Llama 2 prompt.
It combines the system message with the first user message,
and then attaches the message pairs.
Finally, it adds the last message to the prompt.
refer: https://github.com/meta-llama/llama/blob/main/llama/generation.py#L284-L395
*/
export const transformMessagesForLLama2Prompt = (messages: Message[]) => {
let finalPrompt: string = '';
// combine system message with first user message
if (messages.length > 0 && messages[0].role === MESSAGE_ROLES.SYSTEM) {
messages[0].content =
LLAMA_2_SPECIAL_TOKENS.SYSTEM_MESSAGE_START +
getMessageContent(messages[0]) +
LLAMA_2_SPECIAL_TOKENS.SYSTEM_MESSAGE_END +
getMessageContent(messages[1]);
    // drop the first user message, which is now folded into messages[0]
    messages = [messages[0], ...messages.slice(2)];
  }
// attach message pairs
for (let i = 1; i < messages.length; i += 2) {
let prompt = getMessageContent(messages[i - 1]);
let answer = getMessageContent(messages[i]);
finalPrompt += `${LLAMA_2_SPECIAL_TOKENS.BEGINNING_OF_SENTENCE}${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_START} ${prompt} ${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_END} ${answer} ${LLAMA_2_SPECIAL_TOKENS.END_OF_SENTENCE}`;
}
if (messages.length % 2 === 1) {
finalPrompt += `${LLAMA_2_SPECIAL_TOKENS.BEGINNING_OF_SENTENCE}${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_START} ${getMessageContent(messages[messages.length - 1])} ${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_END}`;
}
return finalPrompt;
};
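// For illustration: with a system message present, the transform folds it into
// the first user turn and leaves the final [INST] block open for the model:
// transformMessagesForLLama2Prompt([
//   { role: 'system', content: 'Be terse.' },
//   { role: 'user', content: 'Hi' },
// ])
// => '<s>[INST] <<SYS>>\nBe terse.\n<</SYS>>\n\nHi [/INST]'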

/*
refer: https://docs.mistral.ai/guides/tokenization/
refer: https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
*/
export const transformMessagesForMistralPrompt = (messages: Message[]) => {
let finalPrompt: string = `${MISTRAL_CONTROL_TOKENS.BEGINNING_OF_SENTENCE}`;
// Mistral does not support system messages. (ref: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/discussions/14)
if (messages.length > 0 && messages[0].role === MESSAGE_ROLES.SYSTEM) {
messages[0].content =
getMessageContent(messages[0]) + '\n' + getMessageContent(messages[1]);
messages[0].role = MESSAGE_ROLES.USER;
    // drop the first user message, which is now folded into messages[0]
    messages = [messages[0], ...messages.slice(2)];
  }
for (const message of messages) {
    if (message.role === MESSAGE_ROLES.USER) {
      finalPrompt += `${MISTRAL_CONTROL_TOKENS.CONVERSATION_TURN_START} ${getMessageContent(message)} ${MISTRAL_CONTROL_TOKENS.CONVERSATION_TURN_END}`;
    } else {
      // use getMessageContent so ContentType arrays are handled, as above
      finalPrompt += ` ${getMessageContent(message)} ${MISTRAL_CONTROL_TOKENS.END_OF_SENTENCE}`;
    }
}
return finalPrompt;
};
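// For illustration, a user/assistant exchange followed by a fresh user message
// produces:
// transformMessagesForMistralPrompt([
//   { role: 'user', content: 'Hi' },
//   { role: 'assistant', content: 'Hello!' },
//   { role: 'user', content: 'Bye' },
// ])
// => '<s>[INST] Hi [/INST] Hello! </s>[INST] Bye [/INST]'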
8 changes: 8 additions & 0 deletions src/types/requestBody.ts
@@ -177,6 +177,14 @@ export interface ToolCall {
};
}

export enum MESSAGE_ROLES {
SYSTEM = 'system',
USER = 'user',
ASSISTANT = 'assistant',
FUNCTION = 'function',
TOOL = 'tool',
}

export type OpenAIMessageRole =
| 'system'
| 'user'