16
16
17
17
package ai .starwhale .mlops .common ;
18
18
19
- import static ai .starwhale .mlops .domain . job . ModelServingService . MODEL_SERVICE_PREFIX ;
20
-
21
- import ai .starwhale .mlops .domain . job . ModelServingService ;
22
- import ai . starwhale . mlops . domain . job . mapper . ModelServingMapper ;
19
+ import ai .starwhale .mlops .common . proxy . Service ;
20
+ import ai . starwhale . mlops . common . proxy . WebServerInTask ;
21
+ import ai .starwhale .mlops .configuration . FeaturesProperties ;
22
+ import java . io . EOFException ;
23
23
import java .io .IOException ;
24
+ import java .io .InputStream ;
25
+ import java .io .OutputStream ;
24
26
import java .net .NoRouteToHostException ;
27
+ import java .net .Socket ;
25
28
import java .net .URI ;
26
29
import java .net .URISyntaxException ;
27
30
import java .net .UnknownHostException ;
28
- import java .util .Date ;
31
+ import java .nio .charset .StandardCharsets ;
32
+ import java .util .List ;
33
+ import java .util .concurrent .ExecutorService ;
34
+ import java .util .concurrent .Executors ;
35
+ import java .util .concurrent .Future ;
36
+ import java .util .stream .Collectors ;
29
37
import javax .servlet .ServletException ;
30
38
import javax .servlet .http .HttpServlet ;
31
39
import javax .servlet .http .HttpServletRequest ;
32
40
import javax .servlet .http .HttpServletResponse ;
41
+ import javax .servlet .http .HttpUpgradeHandler ;
42
+ import javax .servlet .http .WebConnection ;
43
+ import lombok .extern .slf4j .Slf4j ;
33
44
import org .apache .http .HttpHost ;
34
45
import org .apache .http .HttpRequest ;
35
46
import org .apache .http .HttpResponse ;
39
50
import org .apache .http .entity .InputStreamEntity ;
40
51
import org .apache .http .impl .client .HttpClientBuilder ;
41
52
import org .apache .http .message .BasicHttpEntityEnclosingRequest ;
53
+ import org .apache .http .util .EntityUtils ;
42
54
import org .springframework .stereotype .Component ;
43
55
import org .springframework .util .StringUtils ;
44
56
57
+ @ Slf4j
45
58
@ Component
46
59
public class ProxyServlet extends HttpServlet {
47
- protected ModelServingMapper modelServingMapper ;
48
60
protected HttpClient httpClient ;
61
+ protected ExecutorService exec ;
62
+ private final List <Service > services ;
49
63
50
- public ProxyServlet (ModelServingMapper modelServingMapper ) {
51
- this .modelServingMapper = modelServingMapper ;
64
+ public ProxyServlet (FeaturesProperties featuresProperties , List <Service > services ) {
65
+ this .services = services .stream ()
66
+ .filter (service -> {
67
+ if (!featuresProperties .isJobProxyEnabled ()) {
68
+ return service instanceof WebServerInTask ;
69
+ }
70
+ return true ;
71
+ })
72
+ .collect (Collectors .toList ());
52
73
}
53
74
54
75
@ Override
55
- public void init () {
76
+ public void init () throws ServletException {
77
+ exec = Executors .newCachedThreadPool ();
56
78
httpClient = HttpClientBuilder .create ().setMaxConnTotal (-1 ).build ();
57
79
}
58
80
@@ -68,18 +90,111 @@ public void service(HttpServletRequest req, HttpServletResponse res) throws Serv
68
90
throw new RuntimeException (e );
69
91
}
70
92
93
+ // check if it is a websocket request
94
+ if (req .getHeader ("Upgrade" ) != null && req .getHeader ("Upgrade" ).equalsIgnoreCase ("websocket" )) {
95
+ workWithWebSocket (req , res , uri );
96
+ return ;
97
+ }
98
+
71
99
var host = new HttpHost (URIUtils .extractHost (uri ));
72
100
var path = uri .getPath ();
73
101
if (StringUtils .hasText (req .getQueryString ())) {
74
102
path = path + "?" + req .getQueryString ();
75
103
}
76
104
var request = generateRequest (req , path );
105
+ HttpResponse response = null ;
77
106
try {
78
- var response = httpClient .execute (host , request );
107
+ response = httpClient .execute (host , request );
79
108
generateResponse (response , res );
80
109
} catch (UnknownHostException | NoRouteToHostException | HttpHostConnectException e ) {
81
110
// return 502 if host or port is unavailable
82
111
res .setStatus (HttpServletResponse .SC_BAD_GATEWAY );
112
+ } catch (Exception e ) {
113
+ // return 500 if any other exception
114
+ res .setStatus (HttpServletResponse .SC_INTERNAL_SERVER_ERROR );
115
+ res .getWriter ().println (e .getMessage ());
116
+ } finally {
117
+ if (response != null ) {
118
+ EntityUtils .consumeQuietly (response .getEntity ());
119
+ }
120
+ }
121
+ }
122
+
123
+ // inspired by https://stackoverflow.com/questions/69482092
124
+ private void workWithWebSocket (HttpServletRequest req , HttpServletResponse res , URI uri )
125
+ throws IOException , ServletException {
126
+ Socket sock = new Socket (uri .getHost (), uri .getPort ());
127
+ boolean closeSocket = false ;
128
+ try {
129
+ var sockIn = sock .getInputStream ();
130
+ var sockOut = sock .getOutputStream ();
131
+
132
+ // prepare request header
133
+ StringBuilder sb = new StringBuilder (512 );
134
+
135
+ var path = uri .getPath ();
136
+ if (StringUtils .hasText (req .getQueryString ())) {
137
+ path = path + "?" + req .getQueryString ();
138
+ }
139
+
140
+ sb .append ("GET " ).append (path ).append (" HTTP/1.1" );
141
+ sb .append ("\r \n " );
142
+ var en = req .getHeaderNames ();
143
+ while (en .hasMoreElements ()) {
144
+ var n = en .nextElement ();
145
+ String header = req .getHeader (n );
146
+ sb .append (n ).append (": " ).append (header ).append ("\r \n " );
147
+ }
148
+ sb .append ("\r \n " );
149
+
150
+ sockOut .write (sb .toString ().getBytes (StandardCharsets .UTF_8 ));
151
+ sockOut .flush ();
152
+
153
+ StringBuilder responseBytes = new StringBuilder (512 );
154
+ int b = 0 ;
155
+ while (b != -1 ) {
156
+ b = sockIn .read ();
157
+ if (b != -1 ) {
158
+ responseBytes .append ((char ) b );
159
+ var len = responseBytes .length ();
160
+ if (len >= 4
161
+ && responseBytes .charAt (len - 4 ) == '\r'
162
+ && responseBytes .charAt (len - 3 ) == '\n'
163
+ && responseBytes .charAt (len - 2 ) == '\r'
164
+ && responseBytes .charAt (len - 1 ) == '\n'
165
+ ) {
166
+ break ;
167
+ }
168
+ }
169
+ }
170
+
171
+ var rows = responseBytes .toString ().split ("\r \n " );
172
+ var response = rows [0 ];
173
+
174
+ int idx1 = response .indexOf (' ' );
175
+ int idx2 = response .indexOf (' ' , idx1 + 1 );
176
+
177
+ for (int i = 1 ; i < rows .length ; i ++) {
178
+ String line = rows [i ];
179
+ int idx3 = line .indexOf (":" );
180
+ var k = line .substring (0 , idx3 );
181
+ var headerField = line .substring (idx3 + 2 );
182
+ res .setHeader (k , headerField );
183
+ }
184
+
185
+ int respCode = Integer .parseInt (response .substring (idx1 + 1 , idx2 ));
186
+ if (respCode != HttpServletResponse .SC_SWITCHING_PROTOCOLS ) {
187
+ res .setStatus (respCode );
188
+ res .flushBuffer ();
189
+ closeSocket = true ;
190
+ } else {
191
+ var uh = req .upgrade (WsUpgradeHandler .class );
192
+ uh .preInit (exec , sockIn , sockOut , sock );
193
+ }
194
+ } finally {
195
+ if (closeSocket ) {
196
+ sock .close ();
197
+ }
83
198
}
84
199
}
85
200
@@ -112,40 +227,120 @@ protected void generateResponse(HttpResponse origin, HttpServletResponse resp) t
112
227
for (var header : headers ) {
113
228
resp .addHeader (header .getName (), header .getValue ());
114
229
}
230
+ if (code == HttpServletResponse .SC_NOT_MODIFIED ) {
231
+ resp .setIntHeader ("Content-Length" , 0 );
232
+ return ;
233
+ }
234
+
115
235
var entity = origin .getEntity ();
116
236
if (entity == null ) {
117
237
return ;
118
238
}
119
- entity .writeTo (resp .getOutputStream ());
239
+ if (!entity .isChunked ()) {
240
+ entity .writeTo (resp .getOutputStream ());
241
+ return ;
242
+ }
243
+ var in = entity .getContent ();
244
+ var out = resp .getOutputStream ();
245
+ var buffer = new byte [1024 ];
246
+ int len ;
247
+ while ((len = in .read (buffer )) != -1 ) {
248
+ out .write (buffer , 0 , len );
249
+ }
250
+ out .flush ();
120
251
}
121
252
122
253
/**
123
- * get target url to proxy, only support model serving service for now
254
+ * get target url to proxy
124
255
*
125
256
* @param uri original uri
126
257
* @return target url
127
258
*/
128
259
public String getTarget (String uri ) {
129
260
uri = StringUtils .trimLeadingCharacter (uri , '/' );
130
- var parts = uri .split ("/" , 3 );
261
+ var parts = uri .split ("/" , 2 );
131
262
if (parts .length < 2 ) {
132
263
throw new IllegalArgumentException ("can not parse uri " + uri );
133
264
}
134
- if (!parts [0 ].equals (MODEL_SERVICE_PREFIX )) {
135
- throw new IllegalArgumentException ("can not recognize prefix " + parts [0 ]);
265
+ var prefix = parts [0 ];
266
+
267
+ // find the service by prefix
268
+ var service = services .stream ().filter (s -> s .getPrefix ().equals (prefix )).findFirst ();
269
+ if (service .isEmpty ()) {
270
+ throw new IllegalArgumentException ("can not find service for prefix " + prefix );
271
+ }
272
+ return service .get ().getTarget (parts [1 ]);
273
+ }
274
+
275
+ public static class WsUpgradeHandler implements HttpUpgradeHandler {
276
+ ExecutorService exec ;
277
+ InputStream sockIn ;
278
+ OutputStream sockOut ;
279
+ Socket sock ;
280
+ Future <?> future ;
281
+
282
+ public WsUpgradeHandler () {
283
+ }
284
+
285
+ public void preInit (ExecutorService exec , InputStream sockIn , OutputStream sockOut , Socket sock ) {
286
+ this .exec = exec ;
287
+ this .sockIn = sockIn ;
288
+ this .sockOut = sockOut ;
289
+ this .sock = sock ;
136
290
}
137
- var id = Long .parseLong (parts [1 ]);
138
291
139
- if (modelServingMapper .find (id ) == null ) {
140
- throw new IllegalArgumentException ("can not find model serving entry " + parts [1 ]);
292
+ @ Override
293
+ public void init (WebConnection wc ) {
294
+ try {
295
+ var servletIn = wc .getInputStream ();
296
+ var servletOut = wc .getOutputStream ();
297
+ future = exec .submit (() -> {
298
+ // read from sockIn and write to servletOut
299
+ try {
300
+ var buffer = new byte [1024 ];
301
+ int len ;
302
+ while ((len = sockIn .read (buffer )) != -1 ) {
303
+ servletOut .write (buffer , 0 , len );
304
+ servletOut .flush ();
305
+ }
306
+ } catch (IOException ex ) {
307
+ log .error ("error in websocket handler" , ex );
308
+ }
309
+
310
+ return null ;
311
+ });
312
+
313
+ // read from servletIn and write to sockOut
314
+ var buffer = new byte [1024 ];
315
+ int len ;
316
+ while ((len = servletIn .read (buffer )) != -1 ) {
317
+ sockOut .write (buffer , 0 , len );
318
+ sockOut .flush ();
319
+ }
320
+
321
+ future .get ();
322
+ } catch (InterruptedException | EOFException ex ) {
323
+ log .info ("websocket closed" );
324
+ } catch (Exception e ) {
325
+ log .error ("error in websocket handler" , e );
326
+ } finally {
327
+ if (future != null ) {
328
+ future .cancel (true );
329
+ }
330
+ }
141
331
}
142
- modelServingMapper .updateLastVisitTime (id , new Date ());
143
332
144
- var svc = ModelServingService .getServiceName (id );
145
- var handler = "" ;
146
- if (parts .length == 3 ) {
147
- handler = parts [2 ];
333
+ @ Override
334
+ public void destroy () {
335
+ if (future != null ) {
336
+ future .cancel (true );
337
+ }
338
+ try {
339
+ sock .close ();
340
+ } catch (IOException ex ) {
341
+ log .error ("error closing socket" , ex );
342
+ }
148
343
}
149
- return String . format ( "http://%s/%s" , svc , handler );
344
+
150
345
}
151
346
}
0 commit comments