test(scheduler): Included infer_stream test for REST #6292

Merged 2 commits on Feb 26, 2025
scheduler/pkg/agent/rproxy_test.go: 62 additions & 7 deletions
@@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t
package agent

import (
"bufio"
"bytes"
"context"
"fmt"
@@ -49,6 +50,32 @@ func (mlserver *mockMLServerState) v2Infer(w http.ResponseWriter, req *http.Requ
	_, _ = w.Write([]byte("Model inference: " + modelName))
}

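// v2InferStream mimics a streaming REST endpoint: it writes the inference
// response as a sequence of newline-delimited plain-text chunks, flushing
// after each one so the client receives them incrementally.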
func (mlserver *mockMLServerState) v2InferStream(w http.ResponseWriter, req *http.Request) {
	params := mux.Vars(req)
	modelName := params["model_name"]
	if _, ok := mlserver.modelsNotFound[modelName]; ok {
		http.NotFound(w, req)
		return
	}

	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
	w.Header().Set("Content-Type", "text/plain")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	chunks := []string{"Model ", "inference: ", modelName}
	for _, chunk := range chunks {
		newLineChunk := chunk + "\n"
		_, _ = w.Write([]byte(newLineChunk)) // Write a chunk
		flusher.Flush()                      // Flush so the client receives it immediately
	}
}

func (mlserver *mockMLServerState) v2Load(w http.ResponseWriter, req *http.Request) {
	params := mux.Vars(req)
	modelName := params["model_name"]
@@ -87,6 +114,7 @@ func (mlserver *mockMLServerState) isModelLoaded(modelId string) bool {
func setupMockMLServer(mockMLServerState *mockMLServerState, serverPort int) *http.Server {
	rtr := mux.NewRouter()
	rtr.HandleFunc("/v2/models/{model_name:\\w+}/infer", mockMLServerState.v2Infer).Methods("POST")
rtr.HandleFunc("/v2/models/{model_name:\\w+}/infer_stream", mockMLServerState.v2InferStream).Methods("POST")
rtr.HandleFunc("/v2/repository/models/{model_name:\\w+}/load", mockMLServerState.v2Load).Methods("POST")
rtr.HandleFunc("/v2/repository/models/{model_name:\\w+}/unload", mockMLServerState.v2Unload).Methods("POST")
return &http.Server{Addr: ":" + strconv.Itoa(serverPort), Handler: rtr}
@@ -266,13 +294,21 @@ func TestReverseProxySmoke(t *testing.T) {
		}

		// make a dummy predict call with any model name, URL does not matter, only headers
		inferV2Path := "/v2/models/RANDOM/infer"
		url := "http://localhost:" + strconv.Itoa(rpPort) + inferV2Path
		req, err := http.NewRequest(http.MethodPost, url, nil)
		g.Expect(err).To(BeNil())
		req.Header.Set("contentType", "application/json")
		req.Header.Set(util.SeldonModelHeader, test.modelExternalHeader)
		req.Header.Set(util.SeldonInternalModelHeader, test.modelToRequest)
		createRequest := func(endpoint string) *http.Request {
			inferV2Path := "/v2/models/RANDOM/" + endpoint
			logger.Debug("inferV2Path:", inferV2Path)

			url := "http://localhost:" + strconv.Itoa(rpPort) + inferV2Path
			req, err := http.NewRequest(http.MethodPost, url, nil)
			g.Expect(err).To(BeNil())
			req.Header.Set("contentType", "application/json")
			req.Header.Set(util.SeldonModelHeader, test.modelExternalHeader)
			req.Header.Set(util.SeldonInternalModelHeader, test.modelToRequest)
			return req
		}

		// infer request
		req := createRequest("infer")
		resp, err := http.DefaultClient.Do(req)
		g.Expect(err).To(BeNil())

@@ -284,6 +320,25 @@ g.Expect(strings.Contains(bodyString, test.modelToLoad)).To(BeTrue())
			g.Expect(strings.Contains(bodyString, test.modelToLoad)).To(BeTrue())
		}

		// infer_stream request
		req = createRequest("infer_stream")
		resp, err = http.DefaultClient.Do(req)
		g.Expect(err).To(BeNil())

		g.Expect(resp.StatusCode).To(Equal(test.statusCode))
		if test.statusCode == http.StatusOK {
			scanner := bufio.NewScanner(resp.Body)
			messages := make([]string, 0)
			for scanner.Scan() {
				messages = append(messages, scanner.Text())
			}

			g.Expect(scanner.Err()).To(BeNil())

			messagesConcat := strings.Join(messages, "")
			g.Expect(strings.Contains(messagesConcat, test.modelToLoad)).To(BeTrue())
		}

		// test model scaling stats
		if test.statusCode == http.StatusOK {
			g.Expect(rpHTTP.modelScalingStatsCollector.ModelLagStats.Get(test.modelToRequest)).To(Equal(uint32(0)))
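
For reference, a minimal sketch of how a client outside the test suite might consume the new streaming endpoint through the reverse proxy. The port, model name, and use of panic for error handling are illustrative assumptions, not taken from this PR:

package main

import (
	"bufio"
	"fmt"
	"net/http"
)

func main() {
	// Assumed proxy address and model name, for illustration only.
	url := "http://localhost:9998/v2/models/iris/infer_stream"

	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The mock server writes newline-delimited chunks and flushes each one,
	// so a line scanner surfaces them as they arrive.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		fmt.Println("chunk:", scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		panic(err)
	}
}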