Skip to content

Commit 113f3ab

Browse files
author
Sherif Akoush
authored
fix(agent): set context deadline for grpc model server control plane (SeldonIO#5329)
* rename and introduce constants * add timeout our context for v2 ops load/unload * increase time * introduce v2config class * add test * add deadline context for Live and RepositoryIndexRequest calls * add timeout for other control plane ops on model servers * add test for grpc defaults
1 parent 3c59a48 commit 113f3ab

File tree

13 files changed

+158
-49
lines changed

13 files changed

+158
-49
lines changed

scheduler/pkg/agent/client_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,8 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
797797

798798
time.Sleep(50 * time.Millisecond)
799799

800-
v2Client := oip.NewV2Client("", backEndGRPCPort, log.New())
800+
v2Client := oip.NewV2Client(
801+
oip.GetV2ConfigWithDefaults("", backEndGRPCPort), log.New())
801802

802803
modelRepository := FakeModelRepository{}
803804
rpHTTP := FakeDependencyService{err: nil}

scheduler/pkg/agent/internal/testing_utils/mock_grpc_server.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net"
1616
"net/http"
1717
"sync"
18+
"time"
1819

1920
"github.com/jarcoal/httpmock"
2021
"google.golang.org/grpc"
@@ -74,10 +75,13 @@ func (s *V2State) IsModelLoaded(modelId string) bool {
7475
}
7576

7677
type MockGRPCMLServer struct {
77-
listener net.Listener
78-
server *grpc.Server
79-
models []interfaces.ServerModelInfo
80-
isReady bool
78+
listener net.Listener
79+
server *grpc.Server
80+
models []interfaces.ServerModelInfo
81+
isReady bool
82+
LoadSleep time.Duration
83+
UnloadSleep time.Duration
84+
ControlPlaneSleep time.Duration
8185
v2.UnimplementedGRPCInferenceServiceServer
8286
}
8387

@@ -120,21 +124,29 @@ func (m *MockGRPCMLServer) ServerReady(ctx context.Context, r *v2.ServerReadyReq
120124
}
121125

122126
func (m *MockGRPCMLServer) ServerLive(ctx context.Context, r *v2.ServerLiveRequest) (*v2.ServerLiveResponse, error) {
127+
// by default ControlPlaneSleep is 0
128+
time.Sleep(m.ControlPlaneSleep)
123129
return &v2.ServerLiveResponse{Live: true}, nil
124130
}
125131

126132
func (m *MockGRPCMLServer) RepositoryModelLoad(ctx context.Context, r *v2.RepositoryModelLoadRequest) (*v2.RepositoryModelLoadResponse, error) {
133+
// by default LoadSleep is 0
134+
time.Sleep(m.LoadSleep)
127135
return &v2.RepositoryModelLoadResponse{}, nil
128136
}
129137

130138
func (m *MockGRPCMLServer) RepositoryModelUnload(ctx context.Context, r *v2.RepositoryModelUnloadRequest) (*v2.RepositoryModelUnloadResponse, error) {
139+
// by default UnloadSleep is 0
140+
time.Sleep(m.UnloadSleep)
131141
if r.ModelName == ModelNameMissing {
132142
return nil, status.Error(codes.NotFound, fmt.Sprintf("Model %s not found", r.ModelName))
133143
}
134144
return &v2.RepositoryModelUnloadResponse{}, nil
135145
}
136146

137147
func (m *MockGRPCMLServer) RepositoryIndex(ctx context.Context, r *v2.RepositoryIndexRequest) (*v2.RepositoryIndexResponse, error) {
148+
// by default ControlPlaneSleep is 0
149+
time.Sleep(m.ControlPlaneSleep)
138150
ret := make([]*v2.RepositoryIndexResponse_ModelIndex, len(m.models))
139151
for idx, model := range m.models {
140152
ret[idx] = &v2.RepositoryIndexResponse_ModelIndex{Name: model.Name, State: string(model.State)}

scheduler/pkg/agent/modelserver_controlplane/factory/factory.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ func CreateModelServerControlPlane(
1919
config interfaces.ModelServerConfig,
2020
) (interfaces.ModelServerControlPlaneClient, error) {
2121
// we only support v2 for now
22-
return oip.NewV2Client(config.Host, config.Port, config.Logger), nil
22+
return oip.NewV2Client(
23+
oip.GetV2ConfigWithDefaults(config.Host, config.Port), config.Logger), nil
2324
}

scheduler/pkg/agent/modelserver_controlplane/oip/v2.go

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,46 @@ import (
2727
"github.com/seldonio/seldon-core/scheduler/v2/pkg/util"
2828
)
2929

30+
type V2Config struct {
31+
Host string
32+
GRPCPort int
33+
GRPCRetryBackoff time.Duration
34+
GRPRetryMaxCount uint
35+
GRPCMaxMsgSizeBytes int
36+
GRPCModelServerLoadTimeout time.Duration
37+
GRPCModelServerUnloadTimeout time.Duration
38+
GRPCControlPlaneTimeout time.Duration
39+
}
40+
3041
type V2Client struct {
3142
grpcClient v2.GRPCInferenceServiceClient
32-
host string
33-
grpcPort int
43+
v2Config V2Config
3444
logger log.FieldLogger
3545
}
3646

37-
func CreateV2GrpcConnection(host string, plainTxtPort int) (*grpc.ClientConn, error) {
47+
func CreateV2GrpcConnection(v2Config V2Config) (*grpc.ClientConn, error) {
3848
retryOpts := []grpc_retry.CallOption{
39-
grpc_retry.WithBackoff(grpc_retry.BackoffExponential(util.GrpcRetryBackoffMillisecs * time.Millisecond)),
40-
grpc_retry.WithMax(util.GrpcRetryMaxCount),
49+
grpc_retry.WithBackoff(grpc_retry.BackoffExponential(v2Config.GRPCRetryBackoff)),
50+
grpc_retry.WithMax(v2Config.GRPRetryMaxCount),
4151
}
4252

4353
opts := []grpc.DialOption{
4454
grpc.WithTransportCredentials(insecure.NewCredentials()),
45-
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(util.GrpcMaxMsgSizeBytes), grpc.MaxCallSendMsgSize(util.GrpcMaxMsgSizeBytes)),
55+
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(v2Config.GRPCMaxMsgSizeBytes), grpc.MaxCallSendMsgSize(v2Config.GRPCMaxMsgSizeBytes)),
4656
grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(retryOpts...)),
4757
grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(retryOpts...)),
4858
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
4959
}
50-
conn, err := grpc.Dial(fmt.Sprintf("%s:%d", host, plainTxtPort), opts...)
60+
conn, err := grpc.Dial(fmt.Sprintf("%s:%d", v2Config.Host, v2Config.GRPCPort), opts...)
5161
if err != nil {
5262
return nil, err
5363
}
5464

5565
return conn, nil
5666
}
5767

58-
func createV2ControlPlaneClient(host string, port int) (v2.GRPCInferenceServiceClient, error) {
59-
conn, err := CreateV2GrpcConnection(host, port)
68+
func createV2ControlPlaneClient(v2Config V2Config) (v2.GRPCInferenceServiceClient, error) {
69+
conn, err := CreateV2GrpcConnection(v2Config)
6070
if err != nil {
6171
// TODO: this could fail in later iterations, so close earlier connections
6272
conn.Close()
@@ -67,17 +77,29 @@ func createV2ControlPlaneClient(host string, port int) (v2.GRPCInferenceServiceC
6777
return client, nil
6878
}
6979

70-
func NewV2Client(host string, port int, logger log.FieldLogger) *V2Client {
71-
logger.Infof("V2 (OIP) Inference Server %s:%d", host, port)
80+
func GetV2ConfigWithDefaults(host string, port int) V2Config {
81+
return V2Config{
82+
Host: host,
83+
GRPCPort: port,
84+
GRPCRetryBackoff: util.GRPCRetryBackoff,
85+
GRPRetryMaxCount: util.GRPCRetryMaxCount,
86+
GRPCMaxMsgSizeBytes: util.GRPCMaxMsgSizeBytes,
87+
GRPCModelServerLoadTimeout: util.GRPCModelServerLoadTimeout,
88+
GRPCModelServerUnloadTimeout: util.GRPCModelServerUnloadTimeout,
89+
GRPCControlPlaneTimeout: util.GRPCControlPlaneTimeout,
90+
}
91+
}
7292

73-
grpcClient, err := createV2ControlPlaneClient(host, port)
93+
func NewV2Client(v2Config V2Config, logger log.FieldLogger) *V2Client {
94+
logger.Infof("V2 (OIP) Inference Server %s:%d", v2Config.Host, v2Config.GRPCPort)
95+
96+
grpcClient, err := createV2ControlPlaneClient(v2Config)
7497
if err != nil {
7598
return nil
7699
}
77100

78101
return &V2Client{
79-
host: host,
80-
grpcPort: port,
102+
v2Config: v2Config,
81103
grpcClient: grpcClient,
82104
logger: logger.WithField("Source", "V2InferenceServerClientGrpc"),
83105
}
@@ -91,7 +113,8 @@ func (v *V2Client) LoadModel(name string) *interfaces.ControlPlaneErr {
91113
}
92114

93115
func (v *V2Client) loadModelGrpc(name string) *interfaces.ControlPlaneErr {
94-
ctx := context.Background()
116+
ctx, cancel := context.WithTimeout(context.Background(), v.v2Config.GRPCModelServerLoadTimeout)
117+
defer cancel()
95118

96119
req := &v2.RepositoryModelLoadRequest{
97120
ModelName: name,
@@ -122,7 +145,8 @@ func (v *V2Client) UnloadModel(name string) *interfaces.ControlPlaneErr {
122145
}
123146

124147
func (v *V2Client) unloadModelGrpc(name string) *interfaces.ControlPlaneErr {
125-
ctx := context.Background()
148+
ctx, cancel := context.WithTimeout(context.Background(), v.v2Config.GRPCModelServerUnloadTimeout)
149+
defer cancel()
126150

127151
req := &v2.RepositoryModelUnloadRequest{
128152
ModelName: name,
@@ -165,7 +189,9 @@ func (v *V2Client) Live() error {
165189
}
166190

167191
func (v *V2Client) liveGrpc() (bool, error) {
168-
ctx := context.Background()
192+
ctx, cancel := context.WithTimeout(context.Background(), v.v2Config.GRPCControlPlaneTimeout)
193+
defer cancel()
194+
169195
req := &v2.ServerLiveRequest{}
170196

171197
res, err := v.grpcClient.ServerLive(ctx, req)
@@ -180,8 +206,10 @@ func (v *V2Client) GetModels() ([]interfaces.ServerModelInfo, error) {
180206
}
181207

182208
func (v *V2Client) getModelsGrpc() ([]interfaces.ServerModelInfo, error) {
209+
ctx, cancel := context.WithTimeout(context.Background(), v.v2Config.GRPCControlPlaneTimeout)
210+
defer cancel()
211+
183212
var models []interfaces.ServerModelInfo
184-
ctx := context.Background()
185213
req := &v2.RepositoryIndexRequest{}
186214

187215
res, err := v.grpcClient.RepositoryIndex(ctx, req)

scheduler/pkg/agent/modelserver_controlplane/oip/v2_test.go

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ import (
1515

1616
. "github.com/onsi/gomega"
1717
log "github.com/sirupsen/logrus"
18+
"google.golang.org/grpc/codes"
19+
"google.golang.org/grpc/status"
1820

1921
"github.com/seldonio/seldon-core/scheduler/v2/pkg/agent/interfaces"
2022
"github.com/seldonio/seldon-core/scheduler/v2/pkg/agent/internal/testing_utils"
2123
testing_utils2 "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
24+
"github.com/seldonio/seldon-core/scheduler/v2/pkg/util"
2225
)
2326

2427
func TestCommunicationErrors(t *testing.T) {
@@ -30,7 +33,7 @@ func TestCommunicationErrors(t *testing.T) {
3033
g.Expect(err.ErrCode).To(Equal(interfaces.V2CommunicationErrCode))
3134
}
3235

33-
func TestGrpcV2(t *testing.T) {
36+
func TestGRPCV2(t *testing.T) {
3437
g := NewGomegaWithT(t)
3538

3639
mockMLServer := &testing_utils.MockGRPCMLServer{}
@@ -46,7 +49,7 @@ func TestGrpcV2(t *testing.T) {
4649

4750
time.Sleep(10 * time.Millisecond)
4851

49-
v2Client := NewV2Client("", backEndGRPCPort, log.New())
52+
v2Client := NewV2Client(GetV2ConfigWithDefaults("", backEndGRPCPort), log.New())
5053

5154
dummModel := "dummy"
5255

@@ -72,6 +75,67 @@ func TestGrpcV2(t *testing.T) {
7275

7376
}
7477

78+
func TestGRPCV2Timeout(t *testing.T) {
79+
g := NewGomegaWithT(t)
80+
81+
unloadSleep := 5 * time.Second
82+
loadSleep := 2 * time.Second
83+
controlPlaneSleep := 1 * time.Second
84+
mockMLServer := &testing_utils.MockGRPCMLServer{
85+
UnloadSleep: unloadSleep, LoadSleep: loadSleep, ControlPlaneSleep: controlPlaneSleep}
86+
backEndGRPCPort, err := testing_utils2.GetFreePortForTest()
87+
if err != nil {
88+
t.Fatal(err)
89+
}
90+
_ = mockMLServer.Setup(uint(backEndGRPCPort))
91+
go func() {
92+
_ = mockMLServer.Start()
93+
}()
94+
defer mockMLServer.Stop()
95+
96+
time.Sleep(10 * time.Millisecond)
97+
98+
v2Config := GetV2ConfigWithDefaults("", backEndGRPCPort)
99+
v2Config.GRPCModelServerUnloadTimeout = unloadSleep / 2
100+
v2Config.GRPCModelServerLoadTimeout = loadSleep / 2
101+
v2Config.GRPCControlPlaneTimeout = controlPlaneSleep / 2
102+
v2Client := NewV2Client(v2Config, log.New())
103+
104+
dummModel := "dummy"
105+
106+
v2Err := v2Client.LoadModel(dummModel)
107+
g.Expect(v2Err).NotTo(BeNil())
108+
g.Expect(v2Err.ErrCode).To(Equal(int(codes.DeadlineExceeded)))
109+
110+
v2Err = v2Client.UnloadModel(dummModel)
111+
g.Expect(v2Err).NotTo(BeNil())
112+
g.Expect(v2Err.ErrCode).To(Equal(int(codes.DeadlineExceeded)))
113+
114+
err = v2Client.Live()
115+
g.Expect(err).NotTo(BeNil())
116+
e, _ := status.FromError(err)
117+
g.Expect(e.Code()).To(Equal(codes.DeadlineExceeded))
118+
119+
_, err = v2Client.getModelsGrpc()
120+
g.Expect(err).NotTo(BeNil())
121+
e, _ = status.FromError(err)
122+
g.Expect(e.Code()).To(Equal(codes.DeadlineExceeded))
123+
}
124+
125+
func TestDefaultV2Config(t *testing.T) {
126+
g := NewGomegaWithT(t)
127+
128+
v2Config := GetV2ConfigWithDefaults("", 0)
129+
g.Expect(v2Config.GRPCModelServerLoadTimeout).To(Equal(util.GRPCModelServerLoadTimeout))
130+
g.Expect(v2Config.GRPCModelServerUnloadTimeout).To(Equal(util.GRPCModelServerUnloadTimeout))
131+
g.Expect(v2Config.GRPCMaxMsgSizeBytes).To(Equal(util.GRPCMaxMsgSizeBytes))
132+
g.Expect(v2Config.GRPCControlPlaneTimeout).To(Equal(util.GRPCControlPlaneTimeout))
133+
g.Expect(v2Config.GRPCRetryBackoff).To(Equal(util.GRPCRetryBackoff))
134+
g.Expect(v2Config.GRPRetryMaxCount).To(Equal(uint(util.GRPCRetryMaxCount)))
135+
g.Expect(v2Config.Host).To(Equal(""))
136+
g.Expect(v2Config.GRPCPort).To(Equal(0))
137+
}
138+
75139
func TestGrpcV2WithError(t *testing.T) {
76140
g := NewGomegaWithT(t)
77141

@@ -81,7 +145,7 @@ func TestGrpcV2WithError(t *testing.T) {
81145
if err != nil {
82146
t.Fatal(err)
83147
}
84-
v2Client := NewV2Client("", backEndGRPCPort, log.New())
148+
v2Client := NewV2Client(GetV2ConfigWithDefaults("", backEndGRPCPort), log.New())
85149

86150
dummModel := "dummy"
87151

@@ -110,7 +174,7 @@ func TestGrpcV2WithRetry(t *testing.T) {
110174
go func() {
111175
_ = mockMLServer.Start()
112176
}()
113-
v2Client := NewV2Client("", backEndGRPCPort, log.New())
177+
v2Client := NewV2Client(GetV2ConfigWithDefaults("", backEndGRPCPort), log.New())
114178
err = v2Client.Live()
115179
g.Expect(err).To(BeNil())
116180
mockMLServer.Stop()

scheduler/pkg/agent/rproxy_grpc.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ func (rp *reverseGRPCProxy) Start() error {
110110
opts = append(opts, grpc.Creds(rp.tlsOptions.Cert.CreateServerTransportCredentials()))
111111
}
112112
opts = append(opts, grpc.MaxConcurrentStreams(grpcProxyMaxConcurrentStreams))
113-
opts = append(opts, grpc.MaxRecvMsgSize(util.GrpcMaxMsgSizeBytes))
114-
opts = append(opts, grpc.MaxSendMsgSize(util.GrpcMaxMsgSizeBytes))
113+
opts = append(opts, grpc.MaxRecvMsgSize(util.GRPCMaxMsgSizeBytes))
114+
opts = append(opts, grpc.MaxSendMsgSize(util.GRPCMaxMsgSizeBytes))
115115
opts = append(opts, grpc.StatsHandler(otelgrpc.NewServerHandler()))
116116
opts = append(opts, grpc.UnaryInterceptor(rp.metrics.UnaryServerInterceptor()))
117117
grpcServer := grpc.NewServer(opts...)
@@ -322,7 +322,8 @@ func (rp *reverseGRPCProxy) createV2CRPCClients(backendGRPCServerHost string, ba
322322
return nil, nil, err
323323
}
324324
for i := 0; i < size; i++ {
325-
conn, err := oip.CreateV2GrpcConnection(backendGRPCServerHost, backendGRPCServerPort)
325+
conn, err := oip.CreateV2GrpcConnection(
326+
oip.GetV2ConfigWithDefaults(backendGRPCServerHost, backendGRPCServerPort))
326327

327328
if err != nil {
328329
// TODO: this could fail in later iterations, so close earlier connections

scheduler/pkg/agent/rproxy_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ func TestLazyLoadRoundTripper(t *testing.T) {
387387
_ = mlserver.ListenAndServe()
388388
}()
389389

390-
time.Sleep(util.GrpcRetryBackoffMillisecs * time.Millisecond)
390+
time.Sleep(util.GRPCRetryBackoff)
391391

392392
defer func() {
393393
_ = mlserver.Shutdown(context.Background())

scheduler/pkg/agent/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ func (s *Server) drainServerReplicaImpl(serverName string, serverReplicaIdx int)
489489
s.waiter.wait(serverName, serverReplicaIdx)
490490

491491
// as we update envoy in batches and envoy is eventual consistent, give it time to settle down
492-
time.Sleep(util.EnvoyUpdateDefaultBatchWaitMillis + (time.Millisecond * serverDrainingExtraWaitMillis))
492+
time.Sleep(util.EnvoyUpdateDefaultBatchWait + (time.Millisecond * serverDrainingExtraWaitMillis))
493493
s.logger.Debugf("Finished draining models %v from server %s:%d", modelsChanged, serverName, serverReplicaIdx)
494494
}
495495

0 commit comments

Comments
 (0)