Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ export async function tryPost(
fn,
c,
gatewayRequestURL: c.req.url,
params: params,
}));
const endpoint =
fn === 'proxy'
Expand Down
16 changes: 15 additions & 1 deletion src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { bedrockInvokeModels } from './constants';
import {
generateAWSHeaders,
getAssumedRoleCredentials,
getFoundationModelFromInferenceProfile,
providerAssumedRoleCredentials,
} from './utils';
import { GatewayError } from '../../errors/GatewayError';
Expand Down Expand Up @@ -101,7 +102,20 @@ const setRouteSpecificHeaders = (
};

const BedrockAPIConfig: BedrockAPIConfigInterface = {
getBaseURL: ({ providerOptions, fn, gatewayRequestURL }) => {
getBaseURL: async ({ c, providerOptions, fn, gatewayRequestURL, params }) => {
const model = decodeURIComponent(params?.model || '');
if (model.includes('arn:aws') && params) {
const foundationModel = model.includes('foundation-model/')
? model.split('/').pop()
: await getFoundationModelFromInferenceProfile(
c,
model,
providerOptions
);
if (foundationModel) {
params.foundationModel = foundationModel;
}
Comment on lines +108 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Code Refactor

Issue: The code doesn't handle the case where foundationModel extraction fails but still attempts to use it.
Fix: Add a check to ensure foundationModel is defined before setting it in params.
Impact: Prevents potential undefined values from being used in the model parameter.

Suggested change
const foundationModel = model.includes('foundation-model/')
? model.split('/').pop()
: await getFoundationModelFromInferenceProfile(
c,
model,
providerOptions
);
if (foundationModel) {
params.foundationModel = foundationModel;
}
const foundationModel = model.includes('foundation-model/')
? model.split('/').pop()
: await getFoundationModelFromInferenceProfile(
c,
model,
providerOptions
);
if (foundationModel && foundationModel.length > 0) {
params.foundationModel = foundationModel;
}

}
if (fn === 'retrieveFile') {
const s3URL = decodeURIComponent(
gatewayRequestURL.split('/v1/files/')[1]
Expand Down
2 changes: 1 addition & 1 deletion src/providers/bedrock/getBatchOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export const BedrockGetBatchOutputRequestHandler = async ({
// get s3 file id from batch details
// get file from s3
// return file
const baseUrl = BedrockAPIConfig.getBaseURL({
const baseUrl = await BedrockAPIConfig.getBaseURL({
providerOptions,
fn: 'retrieveBatch',
c,
Expand Down
3 changes: 2 additions & 1 deletion src/providers/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ const BedrockConfig: ProviderConfigs = {
let config: ProviderConfigs = {};

if (params.model) {
const providerModel = params?.model?.replace(/^(us\.|eu\.)/, '');
let providerModel = params.foundationModel || params.model;
providerModel = providerModel.replace(/^(us\.|eu\.)/, '');
const providerModelArray = providerModel?.split('.');
const provider = providerModelArray?.[0];
const model = providerModelArray?.slice(1).join('.');
Expand Down
2 changes: 1 addition & 1 deletion src/providers/bedrock/retrieveFileContent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const BedrockRetrieveFileContentRequestHandler = async ({
}) => {
try {
// construct the base url and endpoint
const baseURL = BedrockAPIConfig.getBaseURL({
const baseURL = await BedrockAPIConfig.getBaseURL({
providerOptions,
fn: 'retrieveFileContent',
c,
Expand Down
14 changes: 14 additions & 0 deletions src/providers/bedrock/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,17 @@ export interface BedrockFinetuneRecord {
outputModelName?: string;
outputModelArn?: string;
}

export interface BedrockInferenceProfile {
inferenceProfileName: string;
description: string;
createdAt: string;
updatedAt: string;
inferenceProfileArn: string;
models: {
modelArn: string;
}[];
inferenceProfileId: string;
status: string;
type: string;
}
79 changes: 78 additions & 1 deletion src/providers/bedrock/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
} from './chatComplete';
import { Options } from '../../types/requestBody';
import { GatewayError } from '../../errors/GatewayError';
import { BedrockFinetuneRecord } from './types';
import { BedrockFinetuneRecord, BedrockInferenceProfile } from './types';
import { FinetuneRequest } from '../types';

export const generateAWSHeaders = async (
Expand Down Expand Up @@ -404,3 +404,80 @@ export const populateHyperParameters = (value: FinetuneRequest) => {

return hyperParameters;
};

export const getInferenceProfile = async (
inferenceProfileIdentifier: string,
awsRegion: string,
awsAccessKeyId: string,
awsSecretAccessKey: string,
awsSessionToken?: string
) => {
const url = `https://bedrock.${awsRegion}.amazonaws.com/inference-profiles/${encodeURIComponent(decodeURIComponent(inferenceProfileIdentifier))}`;
Comment on lines +408 to +415
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔒 Security Issue Fix

Issue: The getInferenceProfile function doesn't validate the inferenceProfileIdentifier before using it in the URL, which could potentially lead to URL manipulation issues.
Fix: Add validation to ensure the inferenceProfileIdentifier is a valid ARN format before using it.
Impact: Prevents potential security issues related to URL manipulation.

Suggested change
export const getInferenceProfile = async (
inferenceProfileIdentifier: string,
awsRegion: string,
awsAccessKeyId: string,
awsSecretAccessKey: string,
awsSessionToken?: string
) => {
const url = `https://bedrock.${awsRegion}.amazonaws.com/inference-profiles/${encodeURIComponent(decodeURIComponent(inferenceProfileIdentifier))}`;
export const getInferenceProfile = async (
inferenceProfileIdentifier: string,
awsRegion: string,
awsAccessKeyId: string,
awsSecretAccessKey: string,
awsSessionToken?: string
) => {
if (!inferenceProfileIdentifier || !inferenceProfileIdentifier.startsWith('arn:aws')) {
throw new Error('Invalid inference profile identifier format');
}
const url = `https://bedrock.${awsRegion}.amazonaws.com/inference-profiles/${encodeURIComponent(decodeURIComponent(inferenceProfileIdentifier))}`;


const headers = await generateAWSHeaders(
undefined,
{ 'content-type': 'application/json' },
url,
'GET',
'bedrock',
awsRegion,
awsAccessKeyId,
awsSecretAccessKey,
awsSessionToken
);

try {
const response = await fetch(url, {
method: 'GET',
headers,
});

if (!response.ok) {
throw new Error(
`Failed to get inference profile: ${response.status} ${response.statusText}`
);
}

return (await response.json()) as BedrockInferenceProfile;
} catch (error) {
console.error('Error getting inference profile:', error);
throw error;
}
};

export const getFoundationModelFromInferenceProfile = async (
c: Context,
inferenceProfileIdentifier: string,
providerOptions: Options
) => {
try {
const getFromCacheByKey = c.get('getFromCacheByKey');
const putInCacheWithValue = c.get('putInCacheWithValue');
const cacheKey = `bedrock-inference-profile-${inferenceProfileIdentifier}`;
const cachedFoundationModel = getFromCacheByKey
? await getFromCacheByKey(env(c), cacheKey)
: null;
if (cachedFoundationModel) {
return cachedFoundationModel;
}

const inferenceProfile = await getInferenceProfile(
inferenceProfileIdentifier || '',
providerOptions.awsRegion || '',
providerOptions.awsAccessKeyId || '',
providerOptions.awsSecretAccessKey || '',
providerOptions.awsSessionToken || ''
);

// modelArn is always like arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2:1
const foundationModel = inferenceProfile?.models?.[0]?.modelArn
?.split('/')
?.pop();
if (putInCacheWithValue) {
putInCacheWithValue(env(c), cacheKey, foundationModel, 86400);
}
return foundationModel;
} catch (error) {
return null;
}
};
1 change: 1 addition & 0 deletions src/providers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export interface ProviderAPIConfig {
requestHeaders?: Record<string, string>;
c: Context;
gatewayRequestURL: string;
params?: Params;
}) => Promise<string> | string;
/** A function to generate the endpoint based on parameters */
getEndpoint: (args: {
Expand Down
1 change: 1 addition & 0 deletions src/types/requestBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ export interface Params {
// Embeddings specific
dimensions?: number;
parameters?: any;
[key: string]: any;
}

interface Examples {
Expand Down