Skip to content
16 changes: 15 additions & 1 deletion src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ export async function tryPost(

const baseUrl =
customHost || apiConfig.getBaseURL({ providerOptions: providerOption });

const endpoint = apiConfig.getEndpoint({
providerOptions: providerOption,
fn,
Expand Down Expand Up @@ -1000,11 +1001,23 @@ export function constructConfigFromRequestHeaders(
openaiProject: requestHeaders[`x-${POWERED_BY}-openai-project`],
};

const vertexConfig = {
const vertexConfig: Record<string, any> = {
vertexProjectId: requestHeaders[`x-${POWERED_BY}-vertex-project-id`],
vertexRegion: requestHeaders[`x-${POWERED_BY}-vertex-region`],
};

let vertexServiceAccountJson =
requestHeaders[`x-${POWERED_BY}-vertex-service-account-json`];
if (vertexServiceAccountJson) {
try {
vertexConfig.vertexServiceAccountJson = JSON.parse(
vertexServiceAccountJson
);
} catch (e) {
vertexConfig.vertexServiceAccountJson = null;
}
}

if (requestHeaders[`x-${POWERED_BY}-config`]) {
let parsedConfigJson = JSON.parse(requestHeaders[`x-${POWERED_BY}-config`]);

Expand Down Expand Up @@ -1053,6 +1066,7 @@ export function constructConfigFromRequestHeaders(
return convertKeysToCamelCase(parsedConfigJson, [
'override_params',
'params',
'vertex_service_account_json',
]) as any;
}

Expand Down
10 changes: 6 additions & 4 deletions src/middlewares/requestValidator/schema/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const configSchema: any = z
// Google Vertex AI specific
vertex_project_id: z.string().optional(),
vertex_region: z.string().optional(),
vertex_service_account_json: z.object({}).catchall(z.string()).optional(),
// OpenAI specific
openai_project: z.string().optional(),
openai_organization: z.string().optional(),
Expand All @@ -75,8 +76,8 @@ export const configSchema: any = z
const isOllamaProvider = value.provider === OLLAMA;
const isVertexAIProvider =
value.provider === GOOGLE_VERTEX_AI &&
value.vertex_project_id &&
value.vertex_region;
value.vertex_region &&
(value.vertex_service_account_json || value.vertex_project_id);
const hasAWSDetails =
value.aws_access_key_id && value.aws_secret_access_key;

Expand Down Expand Up @@ -113,10 +114,11 @@ export const configSchema: any = z
(value) => {
const isGoogleVertexAIProvider = value.provider === GOOGLE_VERTEX_AI;
const hasGoogleVertexAIFields =
value.vertex_project_id && value.vertex_region;
(value.vertex_project_id && value.vertex_region) ||
(value.vertex_region && value.vertex_service_account_json);
return !(isGoogleVertexAIProvider && !hasGoogleVertexAIFields);
},
{
message: `Invalid configuration. 'vertex_project_id' and 'vertex_region' are required for '${GOOGLE_VERTEX_AI}' provider. Example: { 'provider': 'vertex-ai', 'vertex_project_id': 'my-project-id', 'vertex_region': 'us-central1', api_key: 'ya29...' }`,
message: `Invalid configuration. ('vertex_project_id' and 'vertex_region') or ('vertex_service_account_json' and 'vertex_region') are required for '${GOOGLE_VERTEX_AI}' provider. Example: { 'provider': 'vertex-ai', 'vertex_project_id': 'my-project-id', 'vertex_region': 'us-central1', api_key: 'ya29...' }`,
}
);
4 changes: 2 additions & 2 deletions src/providers/anthropic/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ export interface AnthropicErrorResponse {
error: AnthropicErrorObject;
}

interface AnthropicChatCompleteResponse {
export interface AnthropicChatCompleteResponse {
id: string;
type: string;
role: string;
Expand All @@ -166,7 +166,7 @@ interface AnthropicChatCompleteResponse {
};
}

interface AnthropicChatCompleteStreamResponse {
export interface AnthropicChatCompleteStreamResponse {
type: string;
index: number;
delta: {
Expand Down
49 changes: 37 additions & 12 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,63 @@
import { ProviderAPIConfig } from '../types';
import { getModelAndProvider } from './utils';
import { getAccessToken } from './utils';

// Good reference for using REST: https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#gemini-beginner-samples-drest
// Difference versus Studio AI: https://cloud.google.com/vertex-ai/docs/start/ai-platform-users
export const GoogleApiConfig: ProviderAPIConfig = {
getBaseURL: ({ providerOptions }) => {
const { vertexProjectId, vertexRegion } = providerOptions;
const {
vertexProjectId: inputProjectId,
vertexRegion,
vertexServiceAccountJson,
} = providerOptions;
let projectId = inputProjectId;
if (vertexServiceAccountJson) {
projectId = vertexServiceAccountJson.project_id;
}

return `https://${vertexRegion}-aiplatform.googleapis.com/v1/projects/${vertexProjectId}/locations/${vertexRegion}/publishers/google`;
return `https://${vertexRegion}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${vertexRegion}`;
},
headers: ({ providerOptions }) => {
const { apiKey } = providerOptions;
headers: async ({ providerOptions }) => {
const { apiKey, vertexServiceAccountJson } = providerOptions;
let authToken = apiKey;
if (vertexServiceAccountJson) {
authToken = await getAccessToken(vertexServiceAccountJson);
}

return {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
Authorization: `Bearer ${authToken}`,
};
},
getEndpoint: ({ fn, gatewayRequestBody }) => {
let mappedFn = fn;
const { model, stream } = gatewayRequestBody;
const { model: inputModel, stream } = gatewayRequestBody;
if (stream) {
mappedFn = `stream-${fn}`;
}
switch (mappedFn) {
case 'chatComplete': {
return `/models/${model}:generateContent`;

const { provider, model } = getModelAndProvider(inputModel as string);

switch (provider) {
case 'google': {
if (mappedFn === 'chatComplete') {
return `/publishers/${provider}/models/${model}:generateContent`;
} else if (mappedFn === 'stream-chatComplete') {
return `/publishers/${provider}/models/${model}/streamGenerateContent?alt=sse`;
}
}
case 'stream-chatComplete': {
return `/models/${model}:streamGenerateContent?alt=sse`;

case 'anthropic': {
if (mappedFn === 'chatComplete') {
return `/publishers/${provider}/models/${model}:rawPredict`;
} else if (mappedFn === 'stream-chatComplete') {
return `/publishers/${provider}/models/${model}:rawPredict`;
}
}

// Embed API is not yet implemented in the gateway
// This may be as easy as copy-paste from Google provider, but needs to be tested

default:
return '';
}
Expand Down
Loading