Skip to content

Commit 67e1a22

Browse files
test(scheduler): Included infer_stream test for REST (SeldonIO#6292)
* Included infer_stream test for REST * Fixed linting
1 parent 32bf7e6 commit 67e1a22

File tree

1 file changed

+62
-7
lines changed

1 file changed

+62
-7
lines changed

scheduler/pkg/agent/rproxy_test.go

Lines changed: 62 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t
1010
package agent
1111

1212
import (
13+
"bufio"
1314
"bytes"
1415
"context"
1516
"fmt"
@@ -49,6 +50,32 @@ func (mlserver *mockMLServerState) v2Infer(w http.ResponseWriter, req *http.Requ
4950
_, _ = w.Write([]byte("Model inference: " + modelName))
5051
}
5152

53+
func (mlserver *mockMLServerState) v2InferStream(w http.ResponseWriter, req *http.Request) {
54+
params := mux.Vars(req)
55+
modelName := params["model_name"]
56+
if _, ok := mlserver.modelsNotFound[modelName]; ok {
57+
http.NotFound(w, req)
58+
}
59+
60+
w.Header().Set("Access-Control-Allow-Origin", "*")
61+
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
62+
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
63+
w.Header().Set("Content-Type", "text/plain")
64+
65+
flusher, ok := w.(http.Flusher)
66+
if !ok {
67+
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
68+
return
69+
}
70+
71+
chunks := []string{"Model ", "inference: ", modelName}
72+
for _, chunk := range chunks {
73+
newLineChunk := chunk + "\n"
74+
_, _ = w.Write([]byte(newLineChunk)) // Write a chunk
75+
flusher.Flush() // Flush to send immediately
76+
}
77+
}
78+
5279
func (mlserver *mockMLServerState) v2Load(w http.ResponseWriter, req *http.Request) {
5380
params := mux.Vars(req)
5481
modelName := params["model_name"]
@@ -87,6 +114,7 @@ func (mlserver *mockMLServerState) isModelLoaded(modelId string) bool {
87114
func setupMockMLServer(mockMLServerState *mockMLServerState, serverPort int) *http.Server {
88115
rtr := mux.NewRouter()
89116
rtr.HandleFunc("/v2/models/{model_name:\\w+}/infer", mockMLServerState.v2Infer).Methods("POST")
117+
rtr.HandleFunc("/v2/models/{model_name:\\w+}/infer_stream", mockMLServerState.v2InferStream).Methods("POST")
90118
rtr.HandleFunc("/v2/repository/models/{model_name:\\w+}/load", mockMLServerState.v2Load).Methods("POST")
91119
rtr.HandleFunc("/v2/repository/models/{model_name:\\w+}/unload", mockMLServerState.v2Unload).Methods("POST")
92120
return &http.Server{Addr: ":" + strconv.Itoa(serverPort), Handler: rtr}
@@ -266,13 +294,21 @@ func TestReverseProxySmoke(t *testing.T) {
266294
}
267295

268296
// make a dummy predict call with any model name, URL does not matter, only headers
269-
inferV2Path := "/v2/models/RANDOM/infer"
270-
url := "http://localhost:" + strconv.Itoa(rpPort) + inferV2Path
271-
req, err := http.NewRequest(http.MethodPost, url, nil)
272-
g.Expect(err).To(BeNil())
273-
req.Header.Set("contentType", "application/json")
274-
req.Header.Set(util.SeldonModelHeader, test.modelExternalHeader)
275-
req.Header.Set(util.SeldonInternalModelHeader, test.modelToRequest)
297+
createRequest := func(endpoint string) *http.Request {
298+
inferV2Path := "/v2/models/RANDOM/" + endpoint
299+
logger.Debug("inferV2Path:", inferV2Path)
300+
301+
url := "http://localhost:" + strconv.Itoa(rpPort) + inferV2Path
302+
req, err := http.NewRequest(http.MethodPost, url, nil)
303+
g.Expect(err).To(BeNil())
304+
req.Header.Set("contentType", "application/json")
305+
req.Header.Set(util.SeldonModelHeader, test.modelExternalHeader)
306+
req.Header.Set(util.SeldonInternalModelHeader, test.modelToRequest)
307+
return req
308+
}
309+
310+
// infer request
311+
req := createRequest("infer")
276312
resp, err := http.DefaultClient.Do(req)
277313
g.Expect(err).To(BeNil())
278314

@@ -284,6 +320,25 @@ func TestReverseProxySmoke(t *testing.T) {
284320
g.Expect(strings.Contains(bodyString, test.modelToLoad)).To(BeTrue())
285321
}
286322

323+
// infer_stream request
324+
req = createRequest("infer_stream")
325+
resp, err = http.DefaultClient.Do(req)
326+
g.Expect(err).To(BeNil())
327+
328+
g.Expect(resp.StatusCode).To(Equal(test.statusCode))
329+
if test.statusCode == http.StatusOK {
330+
scanner := bufio.NewScanner(resp.Body)
331+
messages := make([]string, 0)
332+
for scanner.Scan() {
333+
messages = append(messages, scanner.Text())
334+
}
335+
336+
g.Expect(scanner.Err()).To(BeNil())
337+
338+
messages_concat := strings.Join(messages, "")
339+
g.Expect(strings.Contains(messages_concat, test.modelToLoad)).To(BeTrue())
340+
}
341+
287342
// test model scaling stats
288343
if test.statusCode == http.StatusOK {
289344
g.Expect(rpHTTP.modelScalingStatsCollector.ModelLagStats.Get(test.modelToRequest)).To(Equal(uint32(0)))

0 commit comments

Comments
 (0)