Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ export async function tryPost(
params,
requestBody,
fn,
requestHeaders
requestHeaders,
providerOption
)
: requestBody;
}
Expand All @@ -388,7 +389,8 @@ export async function tryPost(
params,
requestBody,
fn,
requestHeaders
requestHeaders,
providerOption
)
: requestBody;
}
Expand Down Expand Up @@ -954,6 +956,12 @@ export function constructConfigFromRequestHeaders(
awsBedrockModel:
requestHeaders[`x-${POWERED_BY}-aws-bedrock-model`] ||
requestHeaders[`x-${POWERED_BY}-provider-model`],
awsServerSideEncryption:
requestHeaders[`x-${POWERED_BY}-amz-server-side-encryption`],
awsServerSideEncryptionKMSKeyId:
requestHeaders[
`x-${POWERED_BY}-amz-server-side-encryption-aws-kms-key-id`
],
};

const sagemakerConfig = {
Expand Down
40 changes: 32 additions & 8 deletions src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ const BEDROCK_FINETUNE_ENDPOINTS: endpointStrings[] = [
'cancelFinetune',
];

const S3_ENDPOINTS: endpointStrings[] = [
const ENDPOINTS_TO_ROUTE_TO_S3 = [
'retrieveFileContent',
'getBatchOutput',
'retrieveFile',
'retrieveFileContent',
'uploadFile',
'initiateMultipartUpload',
];

const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
Expand All @@ -69,6 +70,33 @@ const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
return AWS_GET_METHODS.includes(fn as endpointStrings) ? 'GET' : 'POST';
};

const getService = (fn: endpointStrings) => {
return ENDPOINTS_TO_ROUTE_TO_S3.includes(fn as endpointStrings)
? 's3'
: 'bedrock';
};

const setRouteSpecificHeaders = (
fn: string,
headers: Record<string, string>,
providerOptions: Options
) => {
if (fn === 'retrieveFile') {
headers['x-amz-object-attributes'] = 'ObjectSize';
}
if (fn === 'initiateMultipartUpload') {
if (providerOptions.awsServerSideEncryptionKMSKeyId) {
headers['x-amz-server-side-encryption-aws-kms-key-id'] =
providerOptions.awsServerSideEncryptionKMSKeyId;
headers['x-amz-server-side-encryption'] = 'aws:kms';
}
if (providerOptions.awsServerSideEncryption) {
headers['x-amz-server-side-encryption'] =
providerOptions.awsServerSideEncryption;
}
}
};

const BedrockAPIConfig: BedrockAPIConfigInterface = {
getBaseURL: ({ providerOptions, fn, gatewayRequestURL }) => {
if (fn === 'retrieveFile') {
Expand Down Expand Up @@ -99,6 +127,7 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
transformedRequestUrl,
}) => {
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
const service = getService(fn as endpointStrings);

const headers: Record<string, string> = {
'content-type': 'application/json',
Expand All @@ -107,18 +136,13 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
if (method === 'PUT' || method === 'GET') {
delete headers['content-type'];
}
if (fn === 'retrieveFile') {
headers['x-amz-object-attributes'] = 'ObjectSize';
}

setRouteSpecificHeaders(fn, headers, providerOptions);

if (providerOptions.awsAuthType === 'assumedRole') {
await providerAssumedRoleCredentials(c, providerOptions);
}

const service = S3_ENDPOINTS.includes(fn as endpointStrings)
? 's3'
: 'bedrock';

let finalRequestBody = transformedRequestBody;

if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {
Expand Down
6 changes: 5 additions & 1 deletion src/providers/bedrock/createBatch.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { BEDROCK } from '../../globals';
import { Options } from '../../types/requestBody';
import {
CreateBatchRequest,
CreateBatchResponse,
Expand Down Expand Up @@ -43,13 +44,16 @@ export const BedrockCreateBatchConfig: ProviderConfig = {
output_data_config: {
param: 'outputDataConfig',
required: true,
default: (params: BedrockCreateBatchRequest) => {
default: (params: BedrockCreateBatchRequest, providerOptions: Options) => {
const inputFileId = decodeURIComponent(params.input_file_id);
const s3URLToContainingFolder =
inputFileId.split('/').slice(0, -1).join('/') + '/';
return {
s3OutputDataConfig: {
s3Uri: s3URLToContainingFolder,
...(providerOptions.awsServerSideEncryptionKMSKeyId && {
s3EncryptionKeyId: providerOptions.awsServerSideEncryptionKMSKeyId,
}),
},
};
},
Expand Down
2 changes: 1 addition & 1 deletion src/providers/bedrock/uploadFile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AwsMultipartUploadHandler {
const headers = await BedrockAPIConfig.headers({
c: this.c,
providerOptions: this.providerOptions,
fn: 'uploadFile',
fn: 'initiateMultipartUpload',
transformedRequestBody: {},
transformedRequestUrl: this.url.toString(),
});
Expand Down
22 changes: 15 additions & 7 deletions src/services/transformToProviderRequest.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { GatewayError } from '../errors/GatewayError';
import ProviderConfigs from '../providers';
import { endpointStrings, ProviderConfig } from '../providers/types';
import { Params } from '../types/requestBody';
import { Options, Params } from '../types/requestBody';

/**
* Helper function to set a nested property in an object.
Expand Down Expand Up @@ -66,7 +66,8 @@ const getValue = (configParam: string, params: Params, paramConfig: any) => {

export const transformUsingProviderConfig = (
providerConfig: ProviderConfig,
params: Params
params: Params,
providerOptions?: Options
) => {
const transformedRequest: { [key: string]: any } = {};

Expand Down Expand Up @@ -99,7 +100,7 @@ export const transformUsingProviderConfig = (
) {
let value;
if (typeof paramConfig.default === 'function') {
value = paramConfig.default(params);
value = paramConfig.default(params, providerOptions);
} else {
value = paramConfig.default;
}
Expand Down Expand Up @@ -129,7 +130,8 @@ export const transformUsingProviderConfig = (
const transformToProviderRequestJSON = (
provider: string,
params: Params,
fn: string
fn: string,
providerOptions: Options
): { [key: string]: any } => {
// Get the configuration for the specified provider
let providerConfig = ProviderConfigs[provider];
Expand All @@ -143,7 +145,7 @@ const transformToProviderRequestJSON = (
throw new GatewayError(`${fn} is not supported by ${provider}`);
}

return transformUsingProviderConfig(providerConfig, params);
return transformUsingProviderConfig(providerConfig, params, providerOptions);
};

const transformToProviderRequestFormData = (
Expand Down Expand Up @@ -218,7 +220,8 @@ export const transformToProviderRequest = (
params: Params,
requestBody: Params | FormData | ArrayBuffer | ReadableStream | ArrayBuffer,
fn: endpointStrings,
requestHeaders: Record<string, string>
requestHeaders: Record<string, string>,
providerOptions: Options
) => {
// this returns a ReadableStream
if (fn === 'uploadFile') {
Expand All @@ -242,7 +245,12 @@ export const transformToProviderRequest = (
providerAPIConfig.transformToFormData({ gatewayRequestBody: params })
)
return transformToProviderRequestFormData(provider, params as Params, fn);
return transformToProviderRequestJSON(provider, params as Params, fn);
return transformToProviderRequestJSON(
provider,
params as Params,
fn,
providerOptions
);
};

export default transformToProviderRequest;
2 changes: 2 additions & 0 deletions src/types/requestBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ export interface Options {
awsS3Bucket?: string;
awsS3ObjectKey?: string;
awsBedrockModel?: string;
awsServerSideEncryption?: string;
awsServerSideEncryptionKMSKeyId?: string;

/** Sagemaker specific */
amznSagemakerCustomAttributes?: string;
Expand Down