Skip to content

Commit c08b44f

Browse files
authored
Merge pull request #977 from narengogi/bedrock/kms-key-support-for-put-requests
feat: kms key support for put requests in bedrock
2 parents 16828d2 + 62823fc commit c08b44f

File tree

6 files changed

+65
-19
lines changed

6 files changed

+65
-19
lines changed

src/handlers/handlerUtils.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ export async function tryPost(
368368
params,
369369
requestBody,
370370
fn,
371-
requestHeaders
371+
requestHeaders,
372+
providerOption
372373
)
373374
: requestBody;
374375
}
@@ -388,7 +389,8 @@ export async function tryPost(
388389
params,
389390
requestBody,
390391
fn,
391-
requestHeaders
392+
requestHeaders,
393+
providerOption
392394
)
393395
: requestBody;
394396
}
@@ -954,6 +956,12 @@ export function constructConfigFromRequestHeaders(
954956
awsBedrockModel:
955957
requestHeaders[`x-${POWERED_BY}-aws-bedrock-model`] ||
956958
requestHeaders[`x-${POWERED_BY}-provider-model`],
959+
awsServerSideEncryption:
960+
requestHeaders[`x-${POWERED_BY}-amz-server-side-encryption`],
961+
awsServerSideEncryptionKMSKeyId:
962+
requestHeaders[
963+
`x-${POWERED_BY}-amz-server-side-encryption-aws-kms-key-id`
964+
],
957965
};
958966

959967
const sagemakerConfig = {

src/providers/bedrock/api.ts

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ const BEDROCK_FINETUNE_ENDPOINTS: endpointStrings[] = [
5353
'cancelFinetune',
5454
];
5555

56-
const S3_ENDPOINTS: endpointStrings[] = [
56+
const ENDPOINTS_TO_ROUTE_TO_S3 = [
5757
'retrieveFileContent',
5858
'getBatchOutput',
5959
'retrieveFile',
6060
'retrieveFileContent',
6161
'uploadFile',
62+
'initiateMultipartUpload',
6263
];
6364

6465
const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
@@ -69,6 +70,33 @@ const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
6970
return AWS_GET_METHODS.includes(fn as endpointStrings) ? 'GET' : 'POST';
7071
};
7172

73+
const getService = (fn: endpointStrings) => {
74+
return ENDPOINTS_TO_ROUTE_TO_S3.includes(fn as endpointStrings)
75+
? 's3'
76+
: 'bedrock';
77+
};
78+
79+
const setRouteSpecificHeaders = (
80+
fn: string,
81+
headers: Record<string, string>,
82+
providerOptions: Options
83+
) => {
84+
if (fn === 'retrieveFile') {
85+
headers['x-amz-object-attributes'] = 'ObjectSize';
86+
}
87+
if (fn === 'initiateMultipartUpload') {
88+
if (providerOptions.awsServerSideEncryptionKMSKeyId) {
89+
headers['x-amz-server-side-encryption-aws-kms-key-id'] =
90+
providerOptions.awsServerSideEncryptionKMSKeyId;
91+
headers['x-amz-server-side-encryption'] = 'aws:kms';
92+
}
93+
if (providerOptions.awsServerSideEncryption) {
94+
headers['x-amz-server-side-encryption'] =
95+
providerOptions.awsServerSideEncryption;
96+
}
97+
}
98+
};
99+
72100
const BedrockAPIConfig: BedrockAPIConfigInterface = {
73101
getBaseURL: ({ providerOptions, fn, gatewayRequestURL }) => {
74102
if (fn === 'retrieveFile') {
@@ -99,6 +127,7 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
99127
transformedRequestUrl,
100128
}) => {
101129
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
130+
const service = getService(fn as endpointStrings);
102131

103132
const headers: Record<string, string> = {
104133
'content-type': 'application/json',
@@ -107,18 +136,13 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
107136
if (method === 'PUT' || method === 'GET') {
108137
delete headers['content-type'];
109138
}
110-
if (fn === 'retrieveFile') {
111-
headers['x-amz-object-attributes'] = 'ObjectSize';
112-
}
139+
140+
setRouteSpecificHeaders(fn, headers, providerOptions);
113141

114142
if (providerOptions.awsAuthType === 'assumedRole') {
115143
await providerAssumedRoleCredentials(c, providerOptions);
116144
}
117145

118-
const service = S3_ENDPOINTS.includes(fn as endpointStrings)
119-
? 's3'
120-
: 'bedrock';
121-
122146
let finalRequestBody = transformedRequestBody;
123147

124148
if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {

src/providers/bedrock/createBatch.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { BEDROCK } from '../../globals';
2+
import { Options } from '../../types/requestBody';
23
import {
34
CreateBatchRequest,
45
CreateBatchResponse,
@@ -43,13 +44,16 @@ export const BedrockCreateBatchConfig: ProviderConfig = {
4344
output_data_config: {
4445
param: 'outputDataConfig',
4546
required: true,
46-
default: (params: BedrockCreateBatchRequest) => {
47+
default: (params: BedrockCreateBatchRequest, providerOptions: Options) => {
4748
const inputFileId = decodeURIComponent(params.input_file_id);
4849
const s3URLToContainingFolder =
4950
inputFileId.split('/').slice(0, -1).join('/') + '/';
5051
return {
5152
s3OutputDataConfig: {
5253
s3Uri: s3URLToContainingFolder,
54+
...(providerOptions.awsServerSideEncryptionKMSKeyId && {
55+
s3EncryptionKeyId: providerOptions.awsServerSideEncryptionKMSKeyId,
56+
}),
5357
},
5458
};
5559
},

src/providers/bedrock/uploadFile.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class AwsMultipartUploadHandler {
4545
const headers = await BedrockAPIConfig.headers({
4646
c: this.c,
4747
providerOptions: this.providerOptions,
48-
fn: 'uploadFile',
48+
fn: 'initiateMultipartUpload',
4949
transformedRequestBody: {},
5050
transformedRequestUrl: this.url.toString(),
5151
});

src/services/transformToProviderRequest.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { GatewayError } from '../errors/GatewayError';
22
import ProviderConfigs from '../providers';
33
import { endpointStrings, ProviderConfig } from '../providers/types';
4-
import { Params } from '../types/requestBody';
4+
import { Options, Params } from '../types/requestBody';
55

66
/**
77
* Helper function to set a nested property in an object.
@@ -66,7 +66,8 @@ const getValue = (configParam: string, params: Params, paramConfig: any) => {
6666

6767
export const transformUsingProviderConfig = (
6868
providerConfig: ProviderConfig,
69-
params: Params
69+
params: Params,
70+
providerOptions?: Options
7071
) => {
7172
const transformedRequest: { [key: string]: any } = {};
7273

@@ -99,7 +100,7 @@ export const transformUsingProviderConfig = (
99100
) {
100101
let value;
101102
if (typeof paramConfig.default === 'function') {
102-
value = paramConfig.default(params);
103+
value = paramConfig.default(params, providerOptions);
103104
} else {
104105
value = paramConfig.default;
105106
}
@@ -129,7 +130,8 @@ export const transformUsingProviderConfig = (
129130
const transformToProviderRequestJSON = (
130131
provider: string,
131132
params: Params,
132-
fn: string
133+
fn: string,
134+
providerOptions: Options
133135
): { [key: string]: any } => {
134136
// Get the configuration for the specified provider
135137
let providerConfig = ProviderConfigs[provider];
@@ -143,7 +145,7 @@ const transformToProviderRequestJSON = (
143145
throw new GatewayError(`${fn} is not supported by ${provider}`);
144146
}
145147

146-
return transformUsingProviderConfig(providerConfig, params);
148+
return transformUsingProviderConfig(providerConfig, params, providerOptions);
147149
};
148150

149151
const transformToProviderRequestFormData = (
@@ -218,7 +220,8 @@ export const transformToProviderRequest = (
218220
params: Params,
219221
requestBody: Params | FormData | ArrayBuffer | ReadableStream | ArrayBuffer,
220222
fn: endpointStrings,
221-
requestHeaders: Record<string, string>
223+
requestHeaders: Record<string, string>,
224+
providerOptions: Options
222225
) => {
223226
// this returns a ReadableStream
224227
if (fn === 'uploadFile') {
@@ -242,7 +245,12 @@ export const transformToProviderRequest = (
242245
providerAPIConfig.transformToFormData({ gatewayRequestBody: params })
243246
)
244247
return transformToProviderRequestFormData(provider, params as Params, fn);
245-
return transformToProviderRequestJSON(provider, params as Params, fn);
248+
return transformToProviderRequestJSON(
249+
provider,
250+
params as Params,
251+
fn,
252+
providerOptions
253+
);
246254
};
247255

248256
export default transformToProviderRequest;

src/types/requestBody.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ export interface Options {
9090
awsS3Bucket?: string;
9191
awsS3ObjectKey?: string;
9292
awsBedrockModel?: string;
93+
awsServerSideEncryption?: string;
94+
awsServerSideEncryptionKMSKeyId?: string;
9395

9496
/** Sagemaker specific */
9597
amznSagemakerCustomAttributes?: string;

0 commit comments

Comments
 (0)