Skip to content

Commit 9107883

Browse files
committed
enhance(controller): support dev mode proxy
1 parent ea65f8c commit 9107883

File tree

16 files changed

+617
-72
lines changed

16 files changed

+617
-72
lines changed

server/controller/src/main/java/ai/starwhale/mlops/common/ProxyServlet.java

Lines changed: 219 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,31 @@
1616

1717
package ai.starwhale.mlops.common;
1818

19-
import static ai.starwhale.mlops.domain.job.ModelServingService.MODEL_SERVICE_PREFIX;
20-
21-
import ai.starwhale.mlops.domain.job.ModelServingService;
22-
import ai.starwhale.mlops.domain.job.mapper.ModelServingMapper;
19+
import ai.starwhale.mlops.common.proxy.Service;
20+
import ai.starwhale.mlops.common.proxy.WebServerInTask;
21+
import ai.starwhale.mlops.configuration.FeaturesProperties;
22+
import java.io.EOFException;
2323
import java.io.IOException;
24+
import java.io.InputStream;
25+
import java.io.OutputStream;
2426
import java.net.NoRouteToHostException;
27+
import java.net.Socket;
2528
import java.net.URI;
2629
import java.net.URISyntaxException;
2730
import java.net.UnknownHostException;
28-
import java.util.Date;
31+
import java.nio.charset.StandardCharsets;
32+
import java.util.List;
33+
import java.util.concurrent.ExecutorService;
34+
import java.util.concurrent.Executors;
35+
import java.util.concurrent.Future;
36+
import java.util.stream.Collectors;
2937
import javax.servlet.ServletException;
3038
import javax.servlet.http.HttpServlet;
3139
import javax.servlet.http.HttpServletRequest;
3240
import javax.servlet.http.HttpServletResponse;
41+
import javax.servlet.http.HttpUpgradeHandler;
42+
import javax.servlet.http.WebConnection;
43+
import lombok.extern.slf4j.Slf4j;
3344
import org.apache.http.HttpHost;
3445
import org.apache.http.HttpRequest;
3546
import org.apache.http.HttpResponse;
@@ -39,20 +50,31 @@
3950
import org.apache.http.entity.InputStreamEntity;
4051
import org.apache.http.impl.client.HttpClientBuilder;
4152
import org.apache.http.message.BasicHttpEntityEnclosingRequest;
53+
import org.apache.http.util.EntityUtils;
4254
import org.springframework.stereotype.Component;
4355
import org.springframework.util.StringUtils;
4456

57+
@Slf4j
4558
@Component
4659
public class ProxyServlet extends HttpServlet {
47-
protected ModelServingMapper modelServingMapper;
4860
protected HttpClient httpClient;
61+
protected ExecutorService exec;
62+
private final List<Service> services;
4963

50-
public ProxyServlet(ModelServingMapper modelServingMapper) {
51-
this.modelServingMapper = modelServingMapper;
64+
public ProxyServlet(FeaturesProperties featuresProperties, List<Service> services) {
65+
this.services = services.stream()
66+
.filter(service -> {
67+
if (!featuresProperties.isJobProxyEnabled()) {
68+
return service instanceof WebServerInTask;
69+
}
70+
return true;
71+
})
72+
.collect(Collectors.toList());
5273
}
5374

5475
@Override
55-
public void init() {
76+
public void init() throws ServletException {
77+
exec = Executors.newCachedThreadPool();
5678
httpClient = HttpClientBuilder.create().setMaxConnTotal(-1).build();
5779
}
5880

@@ -68,18 +90,111 @@ public void service(HttpServletRequest req, HttpServletResponse res) throws Serv
6890
throw new RuntimeException(e);
6991
}
7092

93+
// check if it is a websocket request
94+
if (req.getHeader("Upgrade") != null && req.getHeader("Upgrade").equalsIgnoreCase("websocket")) {
95+
workWithWebSocket(req, res, uri);
96+
return;
97+
}
98+
7199
var host = new HttpHost(URIUtils.extractHost(uri));
72100
var path = uri.getPath();
73101
if (StringUtils.hasText(req.getQueryString())) {
74102
path = path + "?" + req.getQueryString();
75103
}
76104
var request = generateRequest(req, path);
105+
HttpResponse response = null;
77106
try {
78-
var response = httpClient.execute(host, request);
107+
response = httpClient.execute(host, request);
79108
generateResponse(response, res);
80109
} catch (UnknownHostException | NoRouteToHostException | HttpHostConnectException e) {
81110
// return 502 if host or port is unavailable
82111
res.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
112+
} catch (Exception e) {
113+
// return 500 if any other exception
114+
res.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
115+
res.getWriter().println(e.getMessage());
116+
} finally {
117+
if (response != null) {
118+
EntityUtils.consumeQuietly(response.getEntity());
119+
}
120+
}
121+
}
122+
123+
// inspired by https://stackoverflow.com/questions/69482092
124+
private void workWithWebSocket(HttpServletRequest req, HttpServletResponse res, URI uri)
125+
throws IOException, ServletException {
126+
Socket sock = new Socket(uri.getHost(), uri.getPort());
127+
boolean closeSocket = false;
128+
try {
129+
var sockIn = sock.getInputStream();
130+
var sockOut = sock.getOutputStream();
131+
132+
// prepare request header
133+
StringBuilder sb = new StringBuilder(512);
134+
135+
var path = uri.getPath();
136+
if (StringUtils.hasText(req.getQueryString())) {
137+
path = path + "?" + req.getQueryString();
138+
}
139+
140+
sb.append("GET ").append(path).append(" HTTP/1.1");
141+
sb.append("\r\n");
142+
var en = req.getHeaderNames();
143+
while (en.hasMoreElements()) {
144+
var n = en.nextElement();
145+
String header = req.getHeader(n);
146+
sb.append(n).append(": ").append(header).append("\r\n");
147+
}
148+
sb.append("\r\n");
149+
150+
sockOut.write(sb.toString().getBytes(StandardCharsets.UTF_8));
151+
sockOut.flush();
152+
153+
StringBuilder responseBytes = new StringBuilder(512);
154+
int b = 0;
155+
while (b != -1) {
156+
b = sockIn.read();
157+
if (b != -1) {
158+
responseBytes.append((char) b);
159+
var len = responseBytes.length();
160+
if (len >= 4
161+
&& responseBytes.charAt(len - 4) == '\r'
162+
&& responseBytes.charAt(len - 3) == '\n'
163+
&& responseBytes.charAt(len - 2) == '\r'
164+
&& responseBytes.charAt(len - 1) == '\n'
165+
) {
166+
break;
167+
}
168+
}
169+
}
170+
171+
var rows = responseBytes.toString().split("\r\n");
172+
var response = rows[0];
173+
174+
int idx1 = response.indexOf(' ');
175+
int idx2 = response.indexOf(' ', idx1 + 1);
176+
177+
for (int i = 1; i < rows.length; i++) {
178+
String line = rows[i];
179+
int idx3 = line.indexOf(":");
180+
var k = line.substring(0, idx3);
181+
var headerField = line.substring(idx3 + 2);
182+
res.setHeader(k, headerField);
183+
}
184+
185+
int respCode = Integer.parseInt(response.substring(idx1 + 1, idx2));
186+
if (respCode != HttpServletResponse.SC_SWITCHING_PROTOCOLS) {
187+
res.setStatus(respCode);
188+
res.flushBuffer();
189+
closeSocket = true;
190+
} else {
191+
var uh = req.upgrade(WsUpgradeHandler.class);
192+
uh.preInit(exec, sockIn, sockOut, sock);
193+
}
194+
} finally {
195+
if (closeSocket) {
196+
sock.close();
197+
}
83198
}
84199
}
85200

@@ -112,40 +227,120 @@ protected void generateResponse(HttpResponse origin, HttpServletResponse resp) t
112227
for (var header : headers) {
113228
resp.addHeader(header.getName(), header.getValue());
114229
}
230+
if (code == HttpServletResponse.SC_NOT_MODIFIED) {
231+
resp.setIntHeader("Content-Length", 0);
232+
return;
233+
}
234+
115235
var entity = origin.getEntity();
116236
if (entity == null) {
117237
return;
118238
}
119-
entity.writeTo(resp.getOutputStream());
239+
if (!entity.isChunked()) {
240+
entity.writeTo(resp.getOutputStream());
241+
return;
242+
}
243+
var in = entity.getContent();
244+
var out = resp.getOutputStream();
245+
var buffer = new byte[1024];
246+
int len;
247+
while ((len = in.read(buffer)) != -1) {
248+
out.write(buffer, 0, len);
249+
}
250+
out.flush();
120251
}
121252

122253
/**
123-
* get target url to proxy, only support model serving service for now
254+
* get target url to proxy
124255
*
125256
* @param uri original uri
126257
* @return target url
127258
*/
128259
public String getTarget(String uri) {
129260
uri = StringUtils.trimLeadingCharacter(uri, '/');
130-
var parts = uri.split("/", 3);
261+
var parts = uri.split("/", 2);
131262
if (parts.length < 2) {
132263
throw new IllegalArgumentException("can not parse uri " + uri);
133264
}
134-
if (!parts[0].equals(MODEL_SERVICE_PREFIX)) {
135-
throw new IllegalArgumentException("can not recognize prefix " + parts[0]);
265+
var prefix = parts[0];
266+
267+
// find the service by prefix
268+
var service = services.stream().filter(s -> s.getPrefix().equals(prefix)).findFirst();
269+
if (service.isEmpty()) {
270+
throw new IllegalArgumentException("can not find service for prefix " + prefix);
271+
}
272+
return service.get().getTarget(parts[1]);
273+
}
274+
275+
public static class WsUpgradeHandler implements HttpUpgradeHandler {
276+
ExecutorService exec;
277+
InputStream sockIn;
278+
OutputStream sockOut;
279+
Socket sock;
280+
Future<?> future;
281+
282+
public WsUpgradeHandler() {
283+
}
284+
285+
public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
286+
this.exec = exec;
287+
this.sockIn = sockIn;
288+
this.sockOut = sockOut;
289+
this.sock = sock;
136290
}
137-
var id = Long.parseLong(parts[1]);
138291

139-
if (modelServingMapper.find(id) == null) {
140-
throw new IllegalArgumentException("can not find model serving entry " + parts[1]);
292+
@Override
293+
public void init(WebConnection wc) {
294+
try {
295+
var servletIn = wc.getInputStream();
296+
var servletOut = wc.getOutputStream();
297+
future = exec.submit(() -> {
298+
// read from sockIn and write to servletOut
299+
try {
300+
var buffer = new byte[1024];
301+
int len;
302+
while ((len = sockIn.read(buffer)) != -1) {
303+
servletOut.write(buffer, 0, len);
304+
servletOut.flush();
305+
}
306+
} catch (IOException ex) {
307+
log.error("error in websocket handler", ex);
308+
}
309+
310+
return null;
311+
});
312+
313+
// read from servletIn and write to sockOut
314+
var buffer = new byte[1024];
315+
int len;
316+
while ((len = servletIn.read(buffer)) != -1) {
317+
sockOut.write(buffer, 0, len);
318+
sockOut.flush();
319+
}
320+
321+
future.get();
322+
} catch (InterruptedException | EOFException ex) {
323+
log.info("websocket closed");
324+
} catch (Exception e) {
325+
log.error("error in websocket handler", e);
326+
} finally {
327+
if (future != null) {
328+
future.cancel(true);
329+
}
330+
}
141331
}
142-
modelServingMapper.updateLastVisitTime(id, new Date());
143332

144-
var svc = ModelServingService.getServiceName(id);
145-
var handler = "";
146-
if (parts.length == 3) {
147-
handler = parts[2];
333+
@Override
334+
public void destroy() {
335+
if (future != null) {
336+
future.cancel(true);
337+
}
338+
try {
339+
sock.close();
340+
} catch (IOException ex) {
341+
log.error("error closing socket", ex);
342+
}
148343
}
149-
return String.format("http://%s/%s", svc, handler);
344+
150345
}
151346
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package ai.starwhale.mlops.common.proxy;
18+
19+
import ai.starwhale.mlops.domain.job.ModelServingService;
20+
import ai.starwhale.mlops.domain.job.mapper.ModelServingMapper;
21+
import java.util.Date;
22+
import org.springframework.stereotype.Component;
23+
24+
/**
25+
* This class is used to proxy the model serving service.
26+
* The model serving service uri is like "model-serving/1/xxx", the first number is the id of the model serving entry.
27+
* The proxy will find the target host by the id.
28+
* The proxy will update the last visit time of the model serving entry which is used to do the garbage collection.
29+
*/
30+
@Component
31+
public class ModelServing implements Service {
32+
private final ModelServingMapper modelServingMapper;
33+
34+
public static final String MODEL_SERVICE_PREFIX = "model-serving";
35+
36+
public ModelServing(ModelServingMapper modelServingMapper) {
37+
this.modelServingMapper = modelServingMapper;
38+
}
39+
40+
@Override
41+
public String getPrefix() {
42+
return MODEL_SERVICE_PREFIX;
43+
}
44+
45+
@Override
46+
public String getTarget(String uri) {
47+
var parts = uri.split("/", 2);
48+
49+
var id = Long.parseLong(parts[0]);
50+
51+
if (modelServingMapper.find(id) == null) {
52+
throw new IllegalArgumentException("can not find model serving entry " + parts[1]);
53+
}
54+
modelServingMapper.updateLastVisitTime(id, new Date());
55+
56+
var svc = ModelServingService.getServiceName(id);
57+
var handler = "";
58+
if (parts.length == 2) {
59+
handler = parts[1];
60+
}
61+
return String.format("http://%s/%s", svc, handler);
62+
}
63+
}

0 commit comments

Comments
 (0)