Skip to content

Commit a6fe2d9

Browse files
authored
Merge pull request #1118 from narengogi/feat/bedrock-inference-profiles
feature: inference profiles for bedrock
2 parents fb39359 + e982ce3 commit a6fe2d9

File tree

9 files changed

+114
-5
lines changed

9 files changed

+114
-5
lines changed

src/handlers/handlerUtils.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ export async function tryPost(
323323
fn,
324324
c,
325325
gatewayRequestURL: c.req.url,
326+
params: params,
326327
}));
327328
const endpoint =
328329
fn === 'proxy'

src/providers/bedrock/api.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { bedrockInvokeModels } from './constants';
55
import {
66
generateAWSHeaders,
77
getAssumedRoleCredentials,
8+
getFoundationModelFromInferenceProfile,
89
providerAssumedRoleCredentials,
910
} from './utils';
1011
import { GatewayError } from '../../errors/GatewayError';
@@ -101,7 +102,20 @@ const setRouteSpecificHeaders = (
101102
};
102103

103104
const BedrockAPIConfig: BedrockAPIConfigInterface = {
104-
getBaseURL: ({ providerOptions, fn, gatewayRequestURL }) => {
105+
getBaseURL: async ({ c, providerOptions, fn, gatewayRequestURL, params }) => {
106+
const model = decodeURIComponent(params?.model || '');
107+
if (model.includes('arn:aws') && params) {
108+
const foundationModel = model.includes('foundation-model/')
109+
? model.split('/').pop()
110+
: await getFoundationModelFromInferenceProfile(
111+
c,
112+
model,
113+
providerOptions
114+
);
115+
if (foundationModel) {
116+
params.foundationModel = foundationModel;
117+
}
118+
}
105119
if (fn === 'retrieveFile') {
106120
const s3URL = decodeURIComponent(
107121
gatewayRequestURL.split('/v1/files/')[1]

src/providers/bedrock/getBatchOutput.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export const BedrockGetBatchOutputRequestHandler = async ({
5353
// get s3 file id from batch details
5454
// get file from s3
5555
// return file
56-
const baseUrl = BedrockAPIConfig.getBaseURL({
56+
const baseUrl = await BedrockAPIConfig.getBaseURL({
5757
providerOptions,
5858
fn: 'retrieveBatch',
5959
c,

src/providers/bedrock/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ const BedrockConfig: ProviderConfigs = {
8989
let config: ProviderConfigs = {};
9090

9191
if (params.model) {
92-
const providerModel = params?.model?.replace(/^(us\.|eu\.)/, '');
92+
let providerModel = params.foundationModel || params.model;
93+
providerModel = providerModel.replace(/^(us\.|eu\.)/, '');
9394
const providerModelArray = providerModel?.split('.');
9495
const provider = providerModelArray?.[0];
9596
const model = providerModelArray?.slice(1).join('.');

src/providers/bedrock/retrieveFileContent.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export const BedrockRetrieveFileContentRequestHandler = async ({
1919
}) => {
2020
try {
2121
// construct the base url and endpoint
22-
const baseURL = BedrockAPIConfig.getBaseURL({
22+
const baseURL = await BedrockAPIConfig.getBaseURL({
2323
providerOptions,
2424
fn: 'retrieveFileContent',
2525
c,

src/providers/bedrock/types.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,17 @@ export interface BedrockFinetuneRecord {
6464
outputModelName?: string;
6565
outputModelArn?: string;
6666
}
67+
68+
/**
 * Shape of the JSON body returned when a single Bedrock inference profile is
 * fetched (see `getInferenceProfile` in `./utils`).
 *
 * Only the fields AWS documents as always present are required here; the
 * rest are optional so that consumers are forced to guard against absence
 * instead of trusting a field that may not be in the response.
 */
export interface BedrockInferenceProfile {
  /** Name of the inference profile. */
  inferenceProfileName: string;
  /** Free-form description; not guaranteed to be present. */
  description?: string;
  /** Creation timestamp as serialized in the response; may be absent. */
  createdAt?: string;
  /** Last-update timestamp as serialized in the response; may be absent. */
  updatedAt?: string;
  /** ARN of the inference profile. */
  inferenceProfileArn: string;
  /** Models that the inference profile routes to. */
  models: {
    /** ARN of a routed model; individual entries may omit it. */
    modelArn?: string;
  }[];
  /** Unique identifier of the inference profile. */
  inferenceProfileId: string;
  /** Status of the inference profile as reported by AWS. */
  status: string;
  /** Kind of inference profile as reported by AWS. */
  type: string;
}

src/providers/bedrock/utils.ts

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {
1010
} from './chatComplete';
1111
import { Options } from '../../types/requestBody';
1212
import { GatewayError } from '../../errors/GatewayError';
13-
import { BedrockFinetuneRecord } from './types';
13+
import { BedrockFinetuneRecord, BedrockInferenceProfile } from './types';
1414
import { FinetuneRequest } from '../types';
1515

1616
export const generateAWSHeaders = async (
@@ -404,3 +404,80 @@ export const populateHyperParameters = (value: FinetuneRequest) => {
404404

405405
return hyperParameters;
406406
};
407+
408+
/**
 * Fetches a single Bedrock inference profile with a signed GET request to
 * `https://bedrock.<awsRegion>.amazonaws.com/inference-profiles/<identifier>`.
 *
 * @param inferenceProfileIdentifier - Identifier of the profile. It is
 *   URL-decoded and then re-encoded, so both raw and already-encoded values
 *   end up encoded exactly once in the request path.
 * @param awsRegion - AWS region used for the host name and passed to
 *   `generateAWSHeaders`.
 * @param awsAccessKeyId - Access key id passed to `generateAWSHeaders`.
 * @param awsSecretAccessKey - Secret access key passed to `generateAWSHeaders`.
 * @param awsSessionToken - Optional session token passed to
 *   `generateAWSHeaders`.
 * @returns The parsed response body, typed as `BedrockInferenceProfile`.
 * @throws Error when the identifier or region is empty, or when the response
 *   status is not ok; errors raised while sending the request are logged and
 *   rethrown unchanged.
 */
export const getInferenceProfile = async (
  inferenceProfileIdentifier: string,
  awsRegion: string,
  awsAccessKeyId: string,
  awsSecretAccessKey: string,
  awsSessionToken?: string
) => {
  // An empty identifier leaves the path ending in `/inference-profiles/` and an
  // empty region yields the host `bedrock..amazonaws.com`. Neither addresses a
  // single profile, and a successful answer would be mis-cast below, so fail
  // before signing or sending anything.
  if (!inferenceProfileIdentifier || !awsRegion) {
    throw new Error(
      'Failed to get inference profile: inference profile identifier and AWS region are required'
    );
  }

  const url = `https://bedrock.${awsRegion}.amazonaws.com/inference-profiles/${encodeURIComponent(decodeURIComponent(inferenceProfileIdentifier))}`;

  // No request body is sent, hence `undefined` as the first argument.
  const headers = await generateAWSHeaders(
    undefined,
    { 'content-type': 'application/json' },
    url,
    'GET',
    'bedrock',
    awsRegion,
    awsAccessKeyId,
    awsSecretAccessKey,
    awsSessionToken
  );

  try {
    const response = await fetch(url, {
      method: 'GET',
      headers,
    });

    if (!response.ok) {
      throw new Error(
        `Failed to get inference profile: ${response.status} ${response.statusText}`
      );
    }

    return (await response.json()) as BedrockInferenceProfile;
  } catch (error) {
    // Log here for visibility, then let the caller decide how to recover.
    console.error('Error getting inference profile:', error);
    throw error;
  }
};
447+
448+
/**
 * Resolves the foundation model id behind a Bedrock inference profile.
 *
 * The id is taken from the last `/`-separated segment of the first model ARN
 * listed on the profile. If the request context exposes `getFromCacheByKey` /
 * `putInCacheWithValue`, the result is read from and written to that cache
 * under `bedrock-inference-profile-<identifier>` with a TTL argument of 86400
 * (presumably seconds, i.e. one day — confirm against the cache helper).
 *
 * The lookup is best-effort: every failure resolves to `null` so the caller
 * can fall back to the model identifier it already has.
 *
 * @param c - Request context providing the optional cache helpers.
 * @param inferenceProfileIdentifier - Identifier of the inference profile.
 * @param providerOptions - Provider options carrying the AWS region and
 *   credentials used for the lookup.
 * @returns The foundation model id (or the cached value), or `null` when it
 *   cannot be determined.
 */
export const getFoundationModelFromInferenceProfile = async (
  c: Context,
  inferenceProfileIdentifier: string,
  providerOptions: Options
) => {
  try {
    const getFromCacheByKey = c.get('getFromCacheByKey');
    const putInCacheWithValue = c.get('putInCacheWithValue');
    const cacheKey = `bedrock-inference-profile-${inferenceProfileIdentifier}`;

    if (getFromCacheByKey) {
      try {
        const cachedFoundationModel = await getFromCacheByKey(env(c), cacheKey);
        if (cachedFoundationModel) {
          return cachedFoundationModel;
        }
      } catch {
        // The cache is only an optimization: a failed read must not prevent
        // the lookup against AWS below.
      }
    }

    const inferenceProfile = await getInferenceProfile(
      inferenceProfileIdentifier,
      providerOptions.awsRegion || '',
      providerOptions.awsAccessKeyId || '',
      providerOptions.awsSecretAccessKey || '',
      providerOptions.awsSessionToken || ''
    );

    // modelArn is expected to look like
    // arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2:1
    const foundationModel = inferenceProfile?.models?.[0]?.modelArn
      ?.split('/')
      ?.pop();

    // Never cache a miss: storing `undefined` would pin an unusable entry for
    // the whole TTL without ever short-circuiting a later lookup.
    if (!foundationModel) {
      return null;
    }

    if (putInCacheWithValue) {
      try {
        // Awaited so that an asynchronous cache failure is handled here
        // instead of surfacing as an unhandled promise rejection.
        await putInCacheWithValue(env(c), cacheKey, foundationModel, 86400);
      } catch {
        // A failed cache write must not discard a successfully resolved model.
      }
    }
    return foundationModel;
  } catch (error) {
    // Deliberate best-effort: any lookup failure resolves to null.
    return null;
  }
};

src/providers/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export interface ProviderAPIConfig {
5050
requestHeaders?: Record<string, string>;
5151
c: Context;
5252
gatewayRequestURL: string;
53+
params?: Params;
5354
}) => Promise<string> | string;
5455
/** A function to generate the endpoint based on parameters */
5556
getEndpoint: (args: {

src/types/requestBody.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ export interface Params {
429429
// Embeddings specific
430430
dimensions?: number;
431431
parameters?: any;
432+
[key: string]: any;
432433
}
433434

434435
interface Examples {

0 commit comments

Comments
 (0)