Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/handlers/responseHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import { anthropicMessagesJsonToStreamGenerator } from '../providers/anthropic-b
export async function responseHandler(
response: Response,
streamingMode: boolean,
provider: string | Options,
providerOptions: Options,
responseTransformer: string | undefined,
requestURL: string,
isCacheHit: boolean = false,
Expand All @@ -53,17 +53,16 @@ export async function responseHandler(
let responseTransformerFunction: Function | undefined;
const responseContentType = response.headers?.get('content-type');
const isSuccessStatusCode = [200, 246].includes(response.status);

if (typeof provider == 'object') {
provider = provider.provider || '';
}
const provider = providerOptions.provider;

const providerConfig = Providers[provider];
let providerTransformers = Providers[provider]?.responseTransforms;

if (providerConfig?.getConfig) {
providerTransformers =
providerConfig.getConfig(gatewayRequest).responseTransforms;
providerTransformers = providerConfig.getConfig({
params: gatewayRequest,
providerOptions,
}).responseTransforms;
}

// Checking status 200 so that errors are not considered as stream mode.
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/services/responseService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export class ResponseService {
return await responseHandler(
response,
this.context.isStreaming,
this.context.provider,
this.context.providerOption,
responseTransformer,
url,
isCacheHit,
Expand Down
2 changes: 1 addition & 1 deletion src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
providerOptions
);
if (foundationModel) {
params.foundationModel = foundationModel;
providerOptions.foundationModel = foundationModel;
}
}
if (fn === 'retrieveFile') {
Expand Down
5 changes: 2 additions & 3 deletions src/providers/bedrock/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { AI21, ANTHROPIC, COHERE } from '../../globals';
import { Params } from '../../types/requestBody';
import { ProviderConfigs } from '../types';
import BedrockAPIConfig from './api';
import { BedrockCancelBatchResponseTransform } from './cancelBatch';
Expand Down Expand Up @@ -90,13 +89,13 @@ const BedrockConfig: ProviderConfigs = {
getBatchOutput: BedrockGetBatchOutputRequestHandler,
retrieveFileContent: BedrockRetrieveFileContentRequestHandler,
},
getConfig: (params: Params) => {
getConfig: ({ params, providerOptions }) => {
// To remove the region in case its a cross-region inference profile ID
// https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html
let config: ProviderConfigs = {};

if (params.model) {
let providerModel = params.foundationModel || params.model;
let providerModel = providerOptions.foundationModel || params.model;
providerModel = providerModel.replace(/^(us\.|eu\.)/, '');
const providerModelArray = providerModel?.split('.');
const provider = providerModelArray?.[0];
Expand Down
2 changes: 1 addition & 1 deletion src/providers/google-vertex-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import {

const VertexConfig: ProviderConfigs = {
api: VertexApiConfig,
getConfig: (params: Params) => {
getConfig: ({ params }) => {
const requestConfig = {
uploadFile: {},
createBatch: GoogleBatchCreateConfig,
Expand Down
3 changes: 1 addition & 2 deletions src/providers/stability-ai/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { ProviderConfigs } from '../types';
import StabilityAIAPIConfig from './api';
import { STABILITY_V1_MODELS } from './constants';
import {
StabilityAIImageGenerateV1Config,
StabilityAIImageGenerateV1ResponseTransform,
Expand All @@ -13,7 +12,7 @@ import { isStabilityV1Model } from './utils';

const StabilityAIConfig: ProviderConfigs = {
api: StabilityAIAPIConfig,
getConfig: (params: Params) => {
getConfig: ({ params }) => {
const model = params.model;
if (typeof model === 'string' && isStabilityV1Model(model)) {
return {
Expand Down
7 changes: 7 additions & 0 deletions src/providers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ export interface ProviderConfigs {
/** The configuration for each provider, indexed by provider name. */
[key: string]: any;
requestHandlers?: RequestHandlers;
getConfig?: ({
params,
providerOptions,
}: {
params: Params;
providerOptions: Options;
}) => any;
}

export interface BaseResponse {
Expand Down
41 changes: 28 additions & 13 deletions src/services/transformToProviderRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import ProviderConfigs from '../providers';
import { endpointStrings, ProviderConfig } from '../providers/types';
import { Options, Params } from '../types/requestBody';

// TODO: Refactor this file to use the providerOptions object instead of the provider string

/**
* Helper function to set a nested property in an object.
*
Expand Down Expand Up @@ -68,7 +70,7 @@ const getValue = (configParam: string, params: Params, paramConfig: any) => {
export const transformUsingProviderConfig = (
providerConfig: ProviderConfig,
params: Params,
providerOptions?: Options
providerOptions: Options
) => {
const transformedRequest: { [key: string]: any } = {};

Expand Down Expand Up @@ -137,7 +139,7 @@ const transformToProviderRequestJSON = (
// Get the configuration for the specified provider
let providerConfig = ProviderConfigs[provider];
if (providerConfig.getConfig) {
providerConfig = providerConfig.getConfig(params)[fn];
providerConfig = providerConfig.getConfig({ params, providerOptions })[fn];
} else {
providerConfig = providerConfig[fn];
}
Expand All @@ -152,11 +154,12 @@ const transformToProviderRequestJSON = (
const transformToProviderRequestFormData = (
provider: string,
params: Params,
fn: string
fn: string,
providerOptions: Options
): FormData => {
let providerConfig = ProviderConfigs[provider];
if (providerConfig.getConfig) {
providerConfig = providerConfig.getConfig(params)[fn];
providerConfig = providerConfig.getConfig({ params, providerOptions })[fn];
} else {
providerConfig = providerConfig[fn];
}
Expand Down Expand Up @@ -193,18 +196,23 @@ const transformToProviderRequestBody = (
provider: string,
requestBody: ReadableStream,
requestHeaders: Record<string, string>,
providerOptions: Options,
fn: string
) => {
if (ProviderConfigs[provider].getConfig) {
return ProviderConfigs[provider]
.getConfig({}, fn)
.requestTransforms[fn](requestBody, requestHeaders);
let providerConfig = ProviderConfigs[provider];
if (providerConfig.getConfig) {
providerConfig = providerConfig.getConfig({ params: {}, providerOptions })[
fn
];
} else {
return ProviderConfigs[provider].requestTransforms[fn](
requestBody,
requestHeaders
);
providerConfig = providerConfig[fn];
}

if (!providerConfig) {
throw new GatewayError(`${fn} is not supported by ${provider}`);
}

return providerConfig.requestTransforms[fn](requestBody, requestHeaders);
};

/**
Expand All @@ -230,6 +238,7 @@ export const transformToProviderRequest = (
provider,
requestBody as ReadableStream,
requestHeaders,
providerOptions,
fn
);
}
Expand All @@ -242,6 +251,7 @@ export const transformToProviderRequest = (
provider,
requestBody as ReadableStream,
requestHeaders,
providerOptions,
fn
);
}
Expand All @@ -258,7 +268,12 @@ export const transformToProviderRequest = (
providerAPIConfig.transformToFormData &&
providerAPIConfig.transformToFormData({ gatewayRequestBody: params })
)
return transformToProviderRequestFormData(provider, params as Params, fn);
return transformToProviderRequestFormData(
provider,
params as Params,
fn,
providerOptions
);
return transformToProviderRequestJSON(
provider,
params as Params,
Expand Down
3 changes: 2 additions & 1 deletion src/types/requestBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ interface Strategy {
*/
export interface Options {
/** The name of the provider. */
provider: string | undefined;
provider: string;
/** The name of the API key for the provider. */
virtualKey?: string;
/** The API key for the provider. */
Expand Down Expand Up @@ -95,6 +95,7 @@ export interface Options {
awsBedrockModel?: string;
awsServerSideEncryption?: string;
awsServerSideEncryptionKMSKeyId?: string;
foundationModel?: string;

/** Sagemaker specific */
amznSagemakerCustomAttributes?: string;
Expand Down