Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ export const GoogleApiConfig: ProviderAPIConfig = {
return googleUrlMap.get(mappedFn) || `${projectRoute}`;
}

case 'mistralai':
case 'anthropic': {
if (mappedFn === 'chatComplete' || mappedFn === 'messages') {
return `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`;
Expand Down
16 changes: 16 additions & 0 deletions src/providers/google-vertex-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ import {
VertexAnthropicMessagesConfig,
VertexAnthropicMessagesResponseTransform,
} from './messages';
import {
GetMistralAIChatCompleteResponseTransform,
GetMistralAIChatCompleteStreamChunkTransform,
MistralAIChatCompleteConfig,
} from '../mistral-ai/chatComplete';

const VertexConfig: ProviderConfigs = {
api: VertexApiConfig,
Expand Down Expand Up @@ -162,6 +167,17 @@ const VertexConfig: ProviderConfigs = {
...responseTransforms,
},
};
case 'mistralai':
return {
chatComplete: MistralAIChatCompleteConfig,
api: GoogleApiConfig,
responseTransforms: {
chatComplete:
GetMistralAIChatCompleteResponseTransform(GOOGLE_VERTEX_AI),
'stream-chatComplete':
GetMistralAIChatCompleteStreamChunkTransform(GOOGLE_VERTEX_AI),
},
};
default:
return baseConfig;
}
Expand Down
4 changes: 3 additions & 1 deletion src/providers/google-vertex-ai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ export const getModelAndProvider = (modelString: string) => {
const modelStringParts = modelString.split('.');
if (
modelStringParts.length > 1 &&
['google', 'anthropic', 'meta', 'endpoints'].includes(modelStringParts[0])
['google', 'anthropic', 'meta', 'endpoints', 'mistralai'].includes(
modelStringParts[0]
)
) {
provider = modelStringParts[0];
model = modelStringParts.slice(1).join('.');
Expand Down
183 changes: 89 additions & 94 deletions src/providers/mistral-ai/chatComplete.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { MISTRAL_AI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
Expand All @@ -17,6 +16,9 @@ export const MistralAIChatCompleteConfig: ProviderConfig = {
param: 'model',
required: true,
default: 'mistral-tiny',
transform: (params: Params) => {
return params.model?.replace('mistralai.', '');
},
},
messages: {
param: 'messages',
Expand Down Expand Up @@ -152,104 +154,97 @@ interface MistralAIStreamChunk {
};
}

export const MistralAIChatCompleteResponseTransform: (
response: MistralAIChatCompleteResponse | MistralAIErrorResponse,
responseStatus: number,
responseHeaders: Headers,
strictOpenAiCompliance: boolean,
gatewayRequestUrl: string,
gatewayRequest: Params
) => ChatCompletionResponse | ErrorResponse = (
response,
responseStatus,
responseHeaders,
strictOpenAiCompliance,
gatewayRequestUrl,
gatewayRequest
) => {
if ('message' in response && responseStatus !== 200) {
return generateErrorResponse(
{
message: response.message,
type: response.type,
param: response.param,
code: response.code,
},
MISTRAL_AI
);
}
export const GetMistralAIChatCompleteResponseTransform = (provider: string) => {
return (
response: MistralAIChatCompleteResponse | MistralAIErrorResponse,
responseStatus: number,
_responseHeaders: Headers,
strictOpenAiCompliance: boolean,
_gatewayRequestUrl: string,
_gatewayRequest: Params
): ChatCompletionResponse | ErrorResponse => {
if ('message' in response && responseStatus !== 200) {
return generateErrorResponse(
{
message: response.message,
type: response.type,
param: response.param,
code: response.code,
},
provider
);
}

if ('choices' in response) {
return {
id: response.id,
object: response.object,
created: response.created,
model: response.model,
provider: MISTRAL_AI,
choices: response.choices.map((c) => ({
index: c.index,
message: {
role: c.message.role,
content: c.message.content,
tool_calls: c.message.tool_calls,
if ('choices' in response) {
return {
id: response.id,
object: response.object,
created: response.created,
model: response.model,
provider: provider,
choices: response.choices.map((c) => ({
index: c.index,
message: {
role: c.message.role,
content: c.message.content,
tool_calls: c.message.tool_calls,
},
finish_reason: transformFinishReason(
c.finish_reason as MISTRAL_AI_FINISH_REASON,
strictOpenAiCompliance
),
})),
usage: {
prompt_tokens: response.usage?.prompt_tokens,
completion_tokens: response.usage?.completion_tokens,
total_tokens: response.usage?.total_tokens,
},
finish_reason: transformFinishReason(
c.finish_reason as MISTRAL_AI_FINISH_REASON,
strictOpenAiCompliance
),
})),
usage: {
prompt_tokens: response.usage?.prompt_tokens,
completion_tokens: response.usage?.completion_tokens,
total_tokens: response.usage?.total_tokens,
},
};
}
};
}

return generateInvalidProviderResponseError(response, MISTRAL_AI);
return generateInvalidProviderResponseError(response, provider);
};
};

export const MistralAIChatCompleteStreamChunkTransform: (
response: string,
fallbackId: string,
streamState: any,
strictOpenAiCompliance: boolean,
gatewayRequest: Params
) => string | string[] = (
responseChunk,
fallbackId,
_streamState,
strictOpenAiCompliance,
_gatewayRequest
export const GetMistralAIChatCompleteStreamChunkTransform = (
provider: string
) => {
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
if (chunk === '[DONE]') {
return `data: ${chunk}\n\n`;
}
const parsedChunk: MistralAIStreamChunk = JSON.parse(chunk);
const finishReason = parsedChunk.choices[0].finish_reason
? transformFinishReason(
parsedChunk.choices[0].finish_reason as MISTRAL_AI_FINISH_REASON,
strictOpenAiCompliance
)
: null;
return (
`data: ${JSON.stringify({
id: parsedChunk.id,
object: parsedChunk.object,
created: parsedChunk.created,
model: parsedChunk.model,
provider: MISTRAL_AI,
choices: [
{
index: parsedChunk.choices[0].index,
delta: parsedChunk.choices[0].delta,
finish_reason: finishReason,
},
],
...(parsedChunk.usage ? { usage: parsedChunk.usage } : {}),
})}` + '\n\n'
);
responseChunk: string,
fallbackId: string,
_streamState: any,
strictOpenAiCompliance: boolean,
_gatewayRequest: Params
) => {
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
if (chunk === '[DONE]') {
return `data: ${chunk}\n\n`;
}
const parsedChunk: MistralAIStreamChunk = JSON.parse(chunk);
const finishReason = parsedChunk.choices[0].finish_reason
? transformFinishReason(
parsedChunk.choices[0].finish_reason as MISTRAL_AI_FINISH_REASON,
strictOpenAiCompliance
)
: null;
return (
`data: ${JSON.stringify({
id: parsedChunk.id,
object: parsedChunk.object,
created: parsedChunk.created,
model: parsedChunk.model,
provider: provider,
choices: [
{
index: parsedChunk.choices[0].index,
delta: parsedChunk.choices[0].delta,
finish_reason: finishReason,
},
],
...(parsedChunk.usage ? { usage: parsedChunk.usage } : {}),
})}` + '\n\n'
);
};
};
10 changes: 6 additions & 4 deletions src/providers/mistral-ai/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { MISTRAL_AI } from '../../globals';
import { ProviderConfigs } from '../types';
import MistralAIAPIConfig from './api';
import {
GetMistralAIChatCompleteResponseTransform,
GetMistralAIChatCompleteStreamChunkTransform,
MistralAIChatCompleteConfig,
MistralAIChatCompleteResponseTransform,
MistralAIChatCompleteStreamChunkTransform,
} from './chatComplete';
import { MistralAIEmbedConfig, MistralAIEmbedResponseTransform } from './embed';

Expand All @@ -12,8 +13,9 @@ const MistralAIConfig: ProviderConfigs = {
embed: MistralAIEmbedConfig,
api: MistralAIAPIConfig,
responseTransforms: {
chatComplete: MistralAIChatCompleteResponseTransform,
'stream-chatComplete': MistralAIChatCompleteStreamChunkTransform,
chatComplete: GetMistralAIChatCompleteResponseTransform(MISTRAL_AI),
'stream-chatComplete':
GetMistralAIChatCompleteStreamChunkTransform(MISTRAL_AI),
embed: MistralAIEmbedResponseTransform,
},
};
Expand Down