@@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t
10
10
package agent
11
11
12
12
import (
13
+ "bufio"
13
14
"bytes"
14
15
"context"
15
16
"fmt"
@@ -49,6 +50,32 @@ func (mlserver *mockMLServerState) v2Infer(w http.ResponseWriter, req *http.Requ
49
50
_ , _ = w .Write ([]byte ("Model inference: " + modelName ))
50
51
}
51
52
53
+ func (mlserver * mockMLServerState ) v2InferStream (w http.ResponseWriter , req * http.Request ) {
54
+ params := mux .Vars (req )
55
+ modelName := params ["model_name" ]
56
+ if _ , ok := mlserver .modelsNotFound [modelName ]; ok {
57
+ http .NotFound (w , req )
58
+ }
59
+
60
+ w .Header ().Set ("Access-Control-Allow-Origin" , "*" )
61
+ w .Header ().Set ("Access-Control-Allow-Methods" , "POST, OPTIONS" )
62
+ w .Header ().Set ("Access-Control-Allow-Headers" , "Content-Type" )
63
+ w .Header ().Set ("Content-Type" , "text/plain" )
64
+
65
+ flusher , ok := w .(http.Flusher )
66
+ if ! ok {
67
+ http .Error (w , "Streaming unsupported" , http .StatusInternalServerError )
68
+ return
69
+ }
70
+
71
+ chunks := []string {"Model " , "inference: " , modelName }
72
+ for _ , chunk := range chunks {
73
+ newLineChunk := chunk + "\n "
74
+ _ , _ = w .Write ([]byte (newLineChunk )) // Write a chunk
75
+ flusher .Flush () // Flush to send immediately
76
+ }
77
+ }
78
+
52
79
func (mlserver * mockMLServerState ) v2Load (w http.ResponseWriter , req * http.Request ) {
53
80
params := mux .Vars (req )
54
81
modelName := params ["model_name" ]
@@ -87,6 +114,7 @@ func (mlserver *mockMLServerState) isModelLoaded(modelId string) bool {
87
114
func setupMockMLServer (mockMLServerState * mockMLServerState , serverPort int ) * http.Server {
88
115
rtr := mux .NewRouter ()
89
116
rtr .HandleFunc ("/v2/models/{model_name:\\ w+}/infer" , mockMLServerState .v2Infer ).Methods ("POST" )
117
+ rtr .HandleFunc ("/v2/models/{model_name:\\ w+}/infer_stream" , mockMLServerState .v2InferStream ).Methods ("POST" )
90
118
rtr .HandleFunc ("/v2/repository/models/{model_name:\\ w+}/load" , mockMLServerState .v2Load ).Methods ("POST" )
91
119
rtr .HandleFunc ("/v2/repository/models/{model_name:\\ w+}/unload" , mockMLServerState .v2Unload ).Methods ("POST" )
92
120
return & http.Server {Addr : ":" + strconv .Itoa (serverPort ), Handler : rtr }
@@ -266,13 +294,21 @@ func TestReverseProxySmoke(t *testing.T) {
266
294
}
267
295
268
296
// make a dummy predict call with any model name, URL does not matter, only headers
269
- inferV2Path := "/v2/models/RANDOM/infer"
270
- url := "http://localhost:" + strconv .Itoa (rpPort ) + inferV2Path
271
- req , err := http .NewRequest (http .MethodPost , url , nil )
272
- g .Expect (err ).To (BeNil ())
273
- req .Header .Set ("contentType" , "application/json" )
274
- req .Header .Set (util .SeldonModelHeader , test .modelExternalHeader )
275
- req .Header .Set (util .SeldonInternalModelHeader , test .modelToRequest )
297
+ createRequest := func (endpoint string ) * http.Request {
298
+ inferV2Path := "/v2/models/RANDOM/" + endpoint
299
+ logger .Debug ("inferV2Path:" , inferV2Path )
300
+
301
+ url := "http://localhost:" + strconv .Itoa (rpPort ) + inferV2Path
302
+ req , err := http .NewRequest (http .MethodPost , url , nil )
303
+ g .Expect (err ).To (BeNil ())
304
+ req .Header .Set ("contentType" , "application/json" )
305
+ req .Header .Set (util .SeldonModelHeader , test .modelExternalHeader )
306
+ req .Header .Set (util .SeldonInternalModelHeader , test .modelToRequest )
307
+ return req
308
+ }
309
+
310
+ // infer request
311
+ req := createRequest ("infer" )
276
312
resp , err := http .DefaultClient .Do (req )
277
313
g .Expect (err ).To (BeNil ())
278
314
@@ -284,6 +320,25 @@ func TestReverseProxySmoke(t *testing.T) {
284
320
g .Expect (strings .Contains (bodyString , test .modelToLoad )).To (BeTrue ())
285
321
}
286
322
323
+ // infer_stream request
324
+ req = createRequest ("infer_stream" )
325
+ resp , err = http .DefaultClient .Do (req )
326
+ g .Expect (err ).To (BeNil ())
327
+
328
+ g .Expect (resp .StatusCode ).To (Equal (test .statusCode ))
329
+ if test .statusCode == http .StatusOK {
330
+ scanner := bufio .NewScanner (resp .Body )
331
+ messages := make ([]string , 0 )
332
+ for scanner .Scan () {
333
+ messages = append (messages , scanner .Text ())
334
+ }
335
+
336
+ g .Expect (scanner .Err ()).To (BeNil ())
337
+
338
+ messages_concat := strings .Join (messages , "" )
339
+ g .Expect (strings .Contains (messages_concat , test .modelToLoad )).To (BeTrue ())
340
+ }
341
+
287
342
// test model scaling stats
288
343
if test .statusCode == http .StatusOK {
289
344
g .Expect (rpHTTP .modelScalingStatsCollector .ModelLagStats .Get (test .modelToRequest )).To (Equal (uint32 (0 )))
0 commit comments