Skip to content

Commit e6c1b8b

Browse files
authored
[Serving] Implement SageMaker Secure Mode & support for multiple data sources (#2042)
1 parent eae8c78 commit e6c1b8b

File tree

14 files changed

+837
-2
lines changed

14 files changed

+837
-2
lines changed

gradle/libs.versions.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ httpcomponents = "4.5.14"
1717

1818
testng = "7.10.2"
1919
junit = "4.13.2"
20+
mockitocore = "5.12.0"
2021

2122
[libraries]
2223
huggingface-tokenizers = { module = "ai.djl.huggingface:tokenizers" }
@@ -41,3 +42,4 @@ prometheus-exposition-formats = { module = "io.prometheus:prometheus-metrics-exp
4142

4243
testng = { module = "org.testng:testng", version.ref = "testng" }
4344
junit = { module = "junit:junit", version.ref = "junit" }
45+
mockito-core = { module = "org.mockito:mockito-core", version.ref = "mockitocore" }

plugins/secure-mode/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# DJL Serving - Secure Mode Plugin
2+
3+
This plugin implements SageMaker Secure Mode for the model server, by performing checks for potentially unsafe files and options. It is configured by the SageMaker platform.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
plugins {
2+
ai.djl.javaProject
3+
}
4+
5+
dependencies {
6+
implementation(project(":serving"))
7+
implementation(project(":wlm"))
8+
9+
testImplementation(libs.testng) {
10+
exclude(group = "junit", module = "junit")
11+
}
12+
testImplementation(libs.mockito.core)
13+
14+
}
15+
16+
tasks {
17+
register<Copy>("copyJar") {
18+
from(jar) // here it automatically reads jar file produced from jar task
19+
into("../../serving/plugins")
20+
}
21+
22+
jar { finalizedBy("copyJar") }
23+
}

plugins/secure-mode/gradlew

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../gradlew
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.serving.plugins.securemode;
14+
15+
/** Thrown when Secure Mode encounters a violation during security checks. */
16+
public class IllegalConfigurationException extends RuntimeException {
17+
18+
static final long serialVersionUID = 1L;
19+
20+
/**
21+
* Constructs a {@link IllegalConfigurationException} with the specified detail message.
22+
*
23+
* @param message The detail message (which is saved for later retrieval by the {@link
24+
* #getMessage()} method)
25+
*/
26+
public IllegalConfigurationException(String message) {
27+
super(message);
28+
}
29+
30+
/**
31+
* Constructs a {@link IllegalConfigurationException} with the specified detail message and
32+
* cause.
33+
*
34+
* <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
35+
* incorporated into this exception's detail message.
36+
*
37+
* @param message The detail message (which is saved for later retrieval by the {@link
38+
* #getMessage()} method)
39+
* @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
40+
* method). (A null value is permitted, and indicates that the cause is nonexistent or
41+
* unknown.)
42+
*/
43+
public IllegalConfigurationException(String message, Throwable cause) {
44+
super(message, cause);
45+
}
46+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.serving.plugins.securemode;
14+
15+
import ai.djl.Device;
16+
import ai.djl.ModelException;
17+
import ai.djl.serving.wlm.ModelInfo;
18+
import ai.djl.serving.wlm.util.ModelServerListenerAdapter;
19+
20+
import java.io.IOException;
21+
import java.net.URISyntaxException;
22+
23+
class SecureModeModelServerListener extends ModelServerListenerAdapter {
24+
25+
@Override
26+
public void onModelLoading(ModelInfo<?, ?> model, Device device) {
27+
super.onModelLoading(model, device);
28+
29+
if (SecureModeUtils.isSecureMode()) {
30+
try {
31+
SecureModeUtils.validateSecurity();
32+
SecureModeUtils.reconcileSources(model.getModelUrl());
33+
} catch (ModelException e) {
34+
throw new IllegalConfigurationException("Secure Mode check failed", e);
35+
} catch (IOException | URISyntaxException e) {
36+
throw new IllegalConfigurationException(
37+
"Error while running Secure Mode checks", e);
38+
}
39+
}
40+
}
41+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.serving.plugins.securemode;
14+
15+
import ai.djl.serving.plugins.RequestHandler;
16+
import ai.djl.serving.wlm.util.EventManager;
17+
18+
import io.netty.channel.ChannelHandlerContext;
19+
import io.netty.handler.codec.http.FullHttpRequest;
20+
import io.netty.handler.codec.http.QueryStringDecoder;
21+
22+
/** A plugin for Secure Mode. */
23+
public class SecureModePlugin implements RequestHandler<Void> {
24+
25+
/** Constructs a new {@code SecureModePlugin} instance. */
26+
public SecureModePlugin() {
27+
if (SecureModeUtils.isSecureMode()) {
28+
EventManager.getInstance().addListener(new SecureModeModelServerListener());
29+
}
30+
}
31+
32+
/** {@inheritDoc} */
33+
@Override
34+
public boolean acceptInboundMessage(Object msg) {
35+
return false;
36+
}
37+
38+
/** {@inheritDoc} */
39+
@Override
40+
public Void handleRequest(
41+
ChannelHandlerContext ctx,
42+
FullHttpRequest req,
43+
QueryStringDecoder decoder,
44+
String[] segments) {
45+
return null;
46+
}
47+
}

0 commit comments

Comments
 (0)