@@ -10,7 +10,7 @@ import {
1010} from './chatComplete' ;
1111import { Options } from '../../types/requestBody' ;
1212import { GatewayError } from '../../errors/GatewayError' ;
13- import { BedrockFinetuneRecord } from './types' ;
13+ import { BedrockFinetuneRecord , BedrockInferenceProfile } from './types' ;
1414import { FinetuneRequest } from '../types' ;
1515
1616export const generateAWSHeaders = async (
@@ -404,3 +404,80 @@ export const populateHyperParameters = (value: FinetuneRequest) => {
404404
405405 return hyperParameters ;
406406} ;
407+
/**
 * Fetches a single Bedrock inference profile from the regional
 * `bedrock.<region>.amazonaws.com` endpoint.
 *
 * Request headers are obtained from `generateAWSHeaders` for a body-less
 * `GET` against the `bedrock` service (presumably SigV4 signing — confirm
 * against that helper).
 *
 * @param inferenceProfileIdentifier - Profile identifier; may be passed raw or
 *   already percent-encoded. It is decoded and then re-encoded so the URL path
 *   segment is encoded exactly once.
 * @param awsRegion - Region used both in the host name and for the headers.
 * @param awsAccessKeyId - Access key id forwarded to `generateAWSHeaders`.
 * @param awsSecretAccessKey - Secret access key forwarded to `generateAWSHeaders`.
 * @param awsSessionToken - Optional session token forwarded to `generateAWSHeaders`.
 * @returns The parsed JSON response body, typed as `BedrockInferenceProfile`
 *   (the payload is not validated at runtime).
 * @throws Error when the response status is not OK; `URIError` when the
 *   identifier contains a malformed percent-escape; any error raised by
 *   `generateAWSHeaders`, `fetch` or JSON parsing. Every failure is logged
 *   before being rethrown.
 */
export const getInferenceProfile = async (
  inferenceProfileIdentifier: string,
  awsRegion: string,
  awsAccessKeyId: string,
  awsSecretAccessKey: string,
  awsSessionToken?: string
): Promise<BedrockInferenceProfile> => {
  try {
    // Decode first so an identifier that arrives already percent-encoded is
    // not double-encoded. decodeURIComponent throws on malformed escapes, so
    // this lives inside the try to share the logging path below.
    const encodedIdentifier = encodeURIComponent(
      decodeURIComponent(inferenceProfileIdentifier)
    );
    const url = `https://bedrock.${awsRegion}.amazonaws.com/inference-profiles/${encodedIdentifier}`;

    const headers = await generateAWSHeaders(
      undefined, // no request body for a GET
      { 'content-type': 'application/json' },
      url,
      'GET',
      'bedrock',
      awsRegion,
      awsAccessKeyId,
      awsSecretAccessKey,
      awsSessionToken
    );

    const response = await fetch(url, {
      method: 'GET',
      headers,
    });

    if (!response.ok) {
      throw new Error(
        `Failed to get inference profile: ${response.status} ${response.statusText}`
      );
    }

    return (await response.json()) as BedrockInferenceProfile;
  } catch (error: unknown) {
    console.error('Error getting inference profile:', error);
    throw error;
  }
};
447+
/**
 * Resolves the foundation model id behind a Bedrock inference profile.
 *
 * Looks the identifier up through the optional `getFromCacheByKey` helper on
 * the context first; on a miss it calls `getInferenceProfile` and takes the
 * last `/`-separated segment of the first model's `modelArn`.
 *
 * @param c - Request context; `getFromCacheByKey` and `putInCacheWithValue`
 *   are read from it via `c.get(...)` and are both optional.
 * @param inferenceProfileIdentifier - Identifier of the inference profile;
 *   also used verbatim in the cache key.
 * @param providerOptions - Supplies `awsRegion`, `awsAccessKeyId`,
 *   `awsSecretAccessKey` and `awsSessionToken`; missing values are passed on
 *   as empty strings.
 * @returns The cached or freshly resolved model id; `undefined` when the
 *   profile response carries no model ARN; `null` when the cache lookup or
 *   the profile request throws.
 */
export const getFoundationModelFromInferenceProfile = async (
  c: Context,
  inferenceProfileIdentifier: string,
  providerOptions: Options
) => {
  // Passed as the last argument to putInCacheWithValue
  // (presumably a TTL in seconds, i.e. 24h — confirm against the cache helper).
  const cacheLifetime = 86400;

  try {
    const getFromCacheByKey = c.get('getFromCacheByKey');
    const putInCacheWithValue = c.get('putInCacheWithValue');
    const cacheKey = `bedrock-inference-profile-${inferenceProfileIdentifier}`;

    const cachedFoundationModel = getFromCacheByKey
      ? await getFromCacheByKey(env(c), cacheKey)
      : null;
    if (cachedFoundationModel) {
      return cachedFoundationModel;
    }

    const inferenceProfile = await getInferenceProfile(
      inferenceProfileIdentifier || '',
      providerOptions.awsRegion || '',
      providerOptions.awsAccessKeyId || '',
      providerOptions.awsSecretAccessKey || '',
      providerOptions.awsSessionToken || ''
    );

    // modelArn is always like arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2:1
    const foundationModel = inferenceProfile?.models?.[0]?.modelArn
      ?.split('/')
      ?.pop();

    // Only cache a real result: storing `undefined` would pin an empty entry
    // under this key until it expires.
    if (foundationModel && putInCacheWithValue) {
      try {
        // Awaited so a rejected write is handled here instead of surfacing as
        // an unhandled rejection; awaiting is a no-op if the helper is sync.
        await putInCacheWithValue(
          env(c),
          cacheKey,
          foundationModel,
          cacheLifetime
        );
      } catch (cacheError: unknown) {
        // Best effort: a failed cache write must not discard the resolved model.
        console.warn('Error caching inference profile model:', cacheError);
      }
    }
    return foundationModel;
  } catch {
    // Deliberate best-effort lookup: callers receive null instead of an error.
    return null;
  }
};
0 commit comments