Skip to content

Commit 4b233f9

Browse files
author
Sherif Akoush
authored
feat(scheduler): Add max elapsed duration for model load/unload (#5819)
* add max load elapsed time in client settings * add default maz elapsed time to 2 hours * increase default a single load operation timeout to an hour * adjust test after api change * remove outdate comment * add max unload elapsed time, defaulting to 15 minutes including retries. * add test coverage * fix fmt * fix spelling mistake * reduce the numner of retries to 5 by default * add rename test
1 parent 27b2858 commit 4b233f9

File tree

6 files changed

+49
-23
lines changed

6 files changed

+49
-23
lines changed

scheduler/cmd/agent/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ const (
4747
maxElapsedTimeReadySubServiceBeforeStart = 15 * time.Minute // 15 mins is the default MaxElapsedTime
4848
// period for subservice ready "cron"
4949
periodReadySubService = 60 * time.Second
50+
// max time to wait for a model server to load a model, including retries
51+
maxLoadElapsedTime = 120 * time.Minute
52+
// max time to wait for a model server to unload a model, including retries
53+
maxUnloadElapsedTime = 15 * time.Minute // 15 mins is the default MaxElapsedTime
5054
// number of retries for loading a model onto a server
51-
maxLoadRetryCount = 10
55+
maxLoadRetryCount = 5
5256
// number of retries for unloading a model onto a server
5357
maxUnloadRetryCount = 1
5458
)
@@ -275,6 +279,8 @@ func main() {
275279
periodReadySubService,
276280
maxElapsedTimeReadySubServiceBeforeStart,
277281
maxElapsedTimeReadySubServiceAfterStart,
282+
maxLoadElapsedTime,
283+
maxUnloadElapsedTime,
278284
maxLoadRetryCount,
279285
maxUnloadRetryCount,
280286
),

scheduler/pkg/agent/client.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ type ClientSettings struct {
8888
periodReadySubService time.Duration
8989
maxElapsedTimeReadySubServiceBeforeStart time.Duration
9090
maxElapsedTimeReadySubServiceAfterStart time.Duration
91+
maxLoadElapsedTime time.Duration
92+
maxUnloadElapsedTime time.Duration
9193
maxLoadRetryCount uint8
9294
maxUnloadRetryCount uint8
9395
}
@@ -100,7 +102,9 @@ func NewClientSettings(
100102
schedulerTlsPort int,
101103
periodReadySubService,
102104
maxElapsedTimeReadySubServiceBeforeStart,
103-
maxElapsedTimeReadySubServiceAfterStart time.Duration,
105+
maxElapsedTimeReadySubServiceAfterStart,
106+
maxLoadElapsedTime,
107+
maxUnloadElapsedTime time.Duration,
104108
maxLoadRetryCount,
105109
maxUnloadRetryCount uint8,
106110
) *ClientSettings {
@@ -113,6 +117,8 @@ func NewClientSettings(
113117
periodReadySubService: periodReadySubService,
114118
maxElapsedTimeReadySubServiceBeforeStart: maxElapsedTimeReadySubServiceBeforeStart,
115119
maxElapsedTimeReadySubServiceAfterStart: maxElapsedTimeReadySubServiceAfterStart,
120+
maxLoadElapsedTime: maxLoadElapsedTime,
121+
maxUnloadElapsedTime: maxUnloadElapsedTime,
116122
maxLoadRetryCount: maxLoadRetryCount,
117123
maxUnloadRetryCount: maxUnloadRetryCount,
118124
}
@@ -598,7 +604,7 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
598604
loaderFn := func() error {
599605
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
600606
}
601-
if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, logger); err != nil {
607+
if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
602608
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
603609
return err
604610
}
@@ -640,7 +646,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
640646
unloaderFn := func() error {
641647
return c.stateManager.UnloadModelVersion(modifiedModelVersionRequest)
642648
}
643-
if err := backoffWithMaxNumRetry(unloaderFn, c.settings.maxUnloadRetryCount, logger); err != nil {
649+
if err := backoffWithMaxNumRetry(unloaderFn, c.settings.maxUnloadRetryCount, c.settings.maxUnloadElapsedTime, logger); err != nil {
644650
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_UNLOAD_FAILED, err)
645651
return err
646652
}

scheduler/pkg/agent/client_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ func TestClientCreate(t *testing.T) {
206206
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
207207
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
208208
client := NewClient(
209-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
209+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
210210
logger, modelRepository, v2Client,
211211
test.replicaConfig, "default",
212212
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
@@ -365,7 +365,7 @@ func TestLoadModel(t *testing.T) {
365365
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
366366

367367
client := NewClient(
368-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
368+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
369369
logger, modelRepository, v2Client, test.replicaConfig, "default",
370370
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
371371

@@ -514,7 +514,7 @@ parameters:
514514
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
515515
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
516516
client := NewClient(
517-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
517+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
518518
logger, modelRepository,
519519
v2Client, test.replicaConfig, "default",
520520
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService,
@@ -656,7 +656,7 @@ func TestUnloadModel(t *testing.T) {
656656
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
657657
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
658658
client := NewClient(
659-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
659+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
660660
logger, modelRepository, v2Client, test.replicaConfig, "default",
661661
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
662662
mockAgentV2Server := &mockAgentV2Server{models: []string{}}
@@ -714,7 +714,7 @@ func TestClientClose(t *testing.T) {
714714
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
715715
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
716716
client := NewClient(
717-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
717+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
718718
logger, modelRepository, v2Client,
719719
&pb.ReplicaConfig{MemoryBytes: 1000}, "default",
720720
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
@@ -813,7 +813,7 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
813813
_ = drainerService.Start()
814814
}()
815815
client := NewClient(
816-
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, period, maxTimeBeforeStart, maxTimeAfterStart, 1, 1),
816+
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, period, maxTimeBeforeStart, maxTimeAfterStart, 1*time.Minute, 1*time.Minute, 1, 1),
817817
logger, modelRepository, v2Client,
818818
&pb.ReplicaConfig{MemoryBytes: 1000}, "default",
819819
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())

scheduler/pkg/agent/client_utils.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,8 @@ func isReadyChecker(
8181
return nil
8282
}
8383

84-
func backoffWithMaxNumRetry(fn func() error, count uint8, logger log.FieldLogger) error {
85-
backoffWithMax := backoff.NewExponentialBackOff()
86-
// Wait for model repo to be ready
84+
func backoffWithMaxNumRetry(fn func() error, count uint8, maxElapsedTime time.Duration, logger log.FieldLogger) error {
85+
backoffWithMax := backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(maxElapsedTime))
8786
i := 0
8887
logFailure := func(err error, delay time.Duration) {
8988
logger.WithError(err).Errorf("Retry op #%d", i)
@@ -112,7 +111,7 @@ func (b *backOffWithMaxCount) Reset() {
112111
}
113112

114113
func (b *backOffWithMaxCount) NextBackOff() time.Duration {
115-
if b.currentCount >= b.maxCount {
114+
if b.currentCount >= b.maxCount-1 {
116115
return backoff.Stop
117116
} else {
118117
b.currentCount++

scheduler/pkg/agent/client_utils_test.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
log "github.com/sirupsen/logrus"
2020
)
2121

22-
func TestBackOffPolicyWithMax(t *testing.T) {
22+
func TestBackOffPolicyWithMaxCount(t *testing.T) {
2323
t.Logf("Started")
2424
logger := log.New()
2525
log.SetLevel(log.DebugLevel)
@@ -54,7 +54,7 @@ func TestBackOffPolicyWithMax(t *testing.T) {
5454
fn := func() error {
5555
return test.err
5656
}
57-
count := uint8(0)
57+
count := uint8(1) // first call is not a retry
5858
policyWithMax := newBackOffWithMaxCount(test.count, &policy)
5959
logFailure := func(err error, delay time.Duration) {
6060
logger.WithError(err).Errorf("retry")
@@ -66,7 +66,7 @@ func TestBackOffPolicyWithMax(t *testing.T) {
6666
if test.err != nil {
6767
g.Expect(count).To(Equal(test.count))
6868
} else {
69-
g.Expect(count).To(Equal(uint8(0)))
69+
g.Expect(count).To(Equal(uint8(1)))
7070
}
7171
})
7272
}
@@ -76,26 +76,41 @@ func TestFnWrapperWithMax(t *testing.T) {
7676
t.Logf("Started")
7777
logger := log.New()
7878
log.SetLevel(log.DebugLevel)
79+
g := NewGomegaWithT(t)
7980

8081
type test struct {
81-
name string
82-
count uint8
82+
name string
83+
count uint8
84+
maxElapsedTime time.Duration
85+
expectedCount uint8
8386
}
8487
tests := []test{
8588
{
86-
name: "simple",
87-
count: 3,
89+
name: "count > maxElapsedTime",
90+
count: 4,
91+
expectedCount: 4,
92+
maxElapsedTime: 30 * time.Second,
93+
},
94+
{
95+
name: "count < maxElapsedTime",
96+
count: 4,
97+
expectedCount: 1,
98+
maxElapsedTime: 1 * time.Millisecond,
8899
},
89100
}
90101

91102
for _, test := range tests {
92103
t.Run(test.name, func(t *testing.T) {
93104

105+
retries := uint8(0)
94106
fn := func() error {
107+
time.Sleep(1 * time.Millisecond)
108+
retries++
95109
return fmt.Errorf("error")
96110
}
97-
_ = backoffWithMaxNumRetry(fn, test.count, logger)
111+
_ = backoffWithMaxNumRetry(fn, test.count, test.maxElapsedTime, logger)
98112
// if we are here we are done
113+
g.Expect(retries).To(Equal(test.expectedCount))
99114
})
100115
}
101116
}

scheduler/pkg/util/constants.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const (
2727
GRPCRetryMaxCount = 5 // around 3.2s in total wait duration
2828
GRPCMaxMsgSizeBytes = 1000 * 1024 * 1024
2929
EnvoyUpdateDefaultBatchWait = 250 * time.Millisecond
30-
GRPCModelServerLoadTimeout = 30 * time.Minute // How long to wait for a model to load? think of LLM Load, maybe should be a config
30+
GRPCModelServerLoadTimeout = 60 * time.Minute // How long to wait for a model to load? think of LLM Load, maybe should be a config
3131
GRPCModelServerUnloadTimeout = 2 * time.Minute
3232
GRPCControlPlaneTimeout = 1 * time.Minute // For control plane operations except load/unload
3333
)

0 commit comments

Comments
 (0)