Skip to content

Commit b5df482

Browse files
authored
[serving] make http response codes configurable for exception cases (#2114)
1 parent c578e6f commit b5df482

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import ai.djl.serving.util.ConfigManager;
2727
import ai.djl.serving.util.NettyUtils;
2828
import ai.djl.serving.wlm.ModelInfo;
29+
import ai.djl.serving.wlm.util.WlmCapacityException;
2930
import ai.djl.serving.wlm.util.WlmException;
3031
import ai.djl.serving.workflow.Workflow;
3132
import ai.djl.translate.TranslateException;
@@ -468,28 +469,33 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
468469
}
469470

470471
void onException(Throwable t, ChannelHandlerContext ctx) {
471-
HttpResponseStatus status;
472+
int code;
472473
if (t instanceof TranslateException || t instanceof BadRequestException) {
473474
logger.debug(t.getMessage(), t);
474475
SERVER_METRIC.info("{}", RESPONSE_4_XX);
475-
status = HttpResponseStatus.BAD_REQUEST;
476+
code = config.getBadRequestErrorHttpCode();
476477
} else if (t instanceof WlmException) {
477478
logger.warn(t.getMessage(), t);
478479
SERVER_METRIC.info("{}", RESPONSE_5_XX);
479480
SERVER_METRIC.info("{}", WLM_ERROR);
480-
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
481+
if (t instanceof WlmCapacityException) {
482+
code = config.getThrottleErrorHttpCode();
483+
} else {
484+
code = config.getWlmErrorHttpCode();
485+
}
481486
if (!exceedErrorRate && config.onWlmError()) {
482487
exceedErrorRate = true;
483488
}
484489
} else {
485490
logger.warn("Unexpected error", t);
486491
SERVER_METRIC.info("{}", RESPONSE_5_XX);
487492
SERVER_METRIC.info("{}", SERVER_ERROR);
488-
status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
493+
code = config.getServerErrorHttpCode();
489494
if (!exceedErrorRate && config.onServerError()) {
490495
exceedErrorRate = true;
491496
}
492497
}
498+
HttpResponseStatus status = HttpResponseStatus.valueOf(code);
493499

494500
/*
495501
* We can load the models based on the configuration file. Since this Job is

serving/src/main/java/ai/djl/serving/util/ConfigManager.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ public final class ConfigManager {
8484
private static final String ERROR_RATE_SERVER = "error_rate_server";
8585
private static final String ERROR_RATE_MODEL = "error_rate_model";
8686
private static final String ERROR_RATE_ANY = "error_rate_any";
87+
private static final String BAD_REQUEST_ERROR_HTTP_CODE = "bad_request_http_code";
88+
private static final String WLM_ERROR_HTTP_CODE = "wlm_error_http_code";
89+
private static final String THROTTLE_ERROR_HTTP_CODE = "throttle_error_http_code";
90+
private static final String TIMEOUT_ERROR_HTTP_CODE = "timeout_http_code";
91+
private static final String SERVER_ERROR_HTTP_CODE = "server_error_http_code";
8792

8893
// Configurations which are not documented or enabled through environment variables
8994
private static final String USE_NATIVE_IO = "use_native_io";
@@ -443,6 +448,51 @@ public int getChunkedReadTimeout() {
443448
return getIntProperty(CHUNKED_READ_TIMEOUT, 60);
444449
}
445450

451+
/**
452+
* Returns the http response status code to use for bad request errors.
453+
*
454+
* @return the http response status code to use for bad request errors
455+
*/
456+
public int getBadRequestErrorHttpCode() {
457+
return getIntProperty(BAD_REQUEST_ERROR_HTTP_CODE, 400);
458+
}
459+
460+
/**
461+
* Returns the http response status code to use for WorkLoadManager errors.
462+
*
463+
* @return the http response status code to use for WorkLoadManager errors
464+
*/
465+
public int getWlmErrorHttpCode() {
466+
return getIntProperty(WLM_ERROR_HTTP_CODE, 503);
467+
}
468+
469+
/**
470+
* Returns the http response status code to use for throttling errors.
471+
*
472+
* @return the http response status code to use for throttling errors
473+
*/
474+
public int getThrottleErrorHttpCode() {
475+
return getIntProperty(THROTTLE_ERROR_HTTP_CODE, 503);
476+
}
477+
478+
/**
479+
* Returns the http response status code to use for Request Timeout errors.
480+
*
481+
* @return the http response status code to use for Request Timeout errors
482+
*/
483+
public int getTimeoutErrorHttpCode() {
484+
return getIntProperty(TIMEOUT_ERROR_HTTP_CODE, 400);
485+
}
486+
487+
/**
488+
* Returns the http response status code to use for generic Server errors.
489+
*
490+
* @return the http response status code to use for generic Server errors
491+
*/
492+
public int getServerErrorHttpCode() {
493+
return getIntProperty(SERVER_ERROR_HTTP_CODE, 500);
494+
}
495+
446496
/**
447497
* Returns the value with the specified key in this configuration.
448498
*

serving/src/test/java/ai/djl/serving/ModelServerTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -954,8 +954,9 @@ private void testThrottle() throws InterruptedException {
954954
if (CudaUtils.getGpuCount() <= 1) {
955955
// one request is not able to saturate workers in multi-GPU case
956956
// one of the requests will be throttled
957-
if ((httpStatus.code() != 503 || httpStatus2.code() != 200)
958-
&& (httpStatus2.code() != 503 || httpStatus.code() != 200)) {
957+
int throttleCode = configManager.getThrottleErrorHttpCode();
958+
if ((httpStatus.code() != throttleCode || httpStatus2.code() != 200)
959+
&& (httpStatus2.code() != throttleCode || httpStatus.code() != 200)) {
959960
logger.info("request 1 code: {}, request 2 code: {}", httpStatus, httpStatus2);
960961
Assert.fail("Expected one of the request be throttled.");
961962
}
@@ -1357,7 +1358,7 @@ private void testRegisterModelMissingUrl() throws InterruptedException {
13571358

13581359
if (!System.getProperty("os.name").startsWith("Win")) {
13591360
ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
1360-
assertEquals(resp.getCode(), HttpResponseStatus.BAD_REQUEST.code());
1361+
assertEquals(resp.getCode(), configManager.getBadRequestErrorHttpCode());
13611362
assertEquals(resp.getMessage(), "Parameter url is required.");
13621363
}
13631364
}
@@ -1453,7 +1454,7 @@ private void testServiceUnavailable() throws InterruptedException {
14531454

14541455
if (!System.getProperty("os.name").startsWith("Win")) {
14551456
ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
1456-
assertEquals(resp.getCode(), HttpResponseStatus.SERVICE_UNAVAILABLE.code());
1457+
assertEquals(resp.getCode(), configManager.getWlmErrorHttpCode());
14571458
assertEquals(resp.getMessage(), "All model workers has been shutdown: mlp_2");
14581459
}
14591460

0 commit comments

Comments
 (0)