Skip to content

feat(scheduler): Add max elapsed duration for model load/unload #5819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion scheduler/cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ const (
maxElapsedTimeReadySubServiceBeforeStart = 15 * time.Minute // 15 mins is the default MaxElapsedTime
// period for subservice ready "cron"
periodReadySubService = 60 * time.Second
// max time to wait for a model server to load a model, including retries
maxLoadElapsedTime = 120 * time.Minute
// max time to wait for a model server to unload a model, including retries
maxUnloadElapsedTime = 15 * time.Minute // 15 mins is the default MaxElapsedTime
// number of retries for loading a model onto a server
maxLoadRetryCount = 10
maxLoadRetryCount = 5
// number of retries for unloading a model from a server
maxUnloadRetryCount = 1
)
Expand Down Expand Up @@ -275,6 +279,8 @@ func main() {
periodReadySubService,
maxElapsedTimeReadySubServiceBeforeStart,
maxElapsedTimeReadySubServiceAfterStart,
maxLoadElapsedTime,
maxUnloadElapsedTime,
maxLoadRetryCount,
maxUnloadRetryCount,
),
Expand Down
12 changes: 9 additions & 3 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ type ClientSettings struct {
periodReadySubService time.Duration
maxElapsedTimeReadySubServiceBeforeStart time.Duration
maxElapsedTimeReadySubServiceAfterStart time.Duration
maxLoadElapsedTime time.Duration
maxUnloadElapsedTime time.Duration
maxLoadRetryCount uint8
maxUnloadRetryCount uint8
}
Expand All @@ -100,7 +102,9 @@ func NewClientSettings(
schedulerTlsPort int,
periodReadySubService,
maxElapsedTimeReadySubServiceBeforeStart,
maxElapsedTimeReadySubServiceAfterStart time.Duration,
maxElapsedTimeReadySubServiceAfterStart,
maxLoadElapsedTime,
maxUnloadElapsedTime time.Duration,
maxLoadRetryCount,
maxUnloadRetryCount uint8,
) *ClientSettings {
Expand All @@ -113,6 +117,8 @@ func NewClientSettings(
periodReadySubService: periodReadySubService,
maxElapsedTimeReadySubServiceBeforeStart: maxElapsedTimeReadySubServiceBeforeStart,
maxElapsedTimeReadySubServiceAfterStart: maxElapsedTimeReadySubServiceAfterStart,
maxLoadElapsedTime: maxLoadElapsedTime,
maxUnloadElapsedTime: maxUnloadElapsedTime,
maxLoadRetryCount: maxLoadRetryCount,
maxUnloadRetryCount: maxUnloadRetryCount,
}
Expand Down Expand Up @@ -598,7 +604,7 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
loaderFn := func() error {
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
}
if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, logger); err != nil {
if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
return err
}
Expand Down Expand Up @@ -640,7 +646,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
unloaderFn := func() error {
return c.stateManager.UnloadModelVersion(modifiedModelVersionRequest)
}
if err := backoffWithMaxNumRetry(unloaderFn, c.settings.maxUnloadRetryCount, logger); err != nil {
if err := backoffWithMaxNumRetry(unloaderFn, c.settings.maxUnloadRetryCount, c.settings.maxUnloadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_UNLOAD_FAILED, err)
return err
}
Expand Down
12 changes: 6 additions & 6 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestClientCreate(t *testing.T) {
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository, v2Client,
test.replicaConfig, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestLoadModel(t *testing.T) {
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))

client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository, v2Client, test.replicaConfig, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())

Expand Down Expand Up @@ -515,7 +515,7 @@ parameters:
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository,
v2Client, test.replicaConfig, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService,
Expand Down Expand Up @@ -657,7 +657,7 @@ func TestUnloadModel(t *testing.T) {
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository, v2Client, test.replicaConfig, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
mockAgentV2Server := &mockAgentV2Server{models: []string{}}
Expand Down Expand Up @@ -715,7 +715,7 @@ func TestClientClose(t *testing.T) {
drainerServicePort, _ := testing_utils2.GetFreePortForTest()
drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository, v2Client,
&pb.ReplicaConfig{MemoryBytes: 1000}, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
Expand Down Expand Up @@ -815,7 +815,7 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
_ = drainerService.Start()
}()
client := NewClient(
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, period, maxTimeBeforeStart, maxTimeAfterStart, 1, 1),
NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, period, maxTimeBeforeStart, maxTimeAfterStart, 1*time.Minute, 1*time.Minute, 1, 1),
logger, modelRepository, v2Client,
&pb.ReplicaConfig{MemoryBytes: 1000}, "default",
rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
Expand Down
7 changes: 3 additions & 4 deletions scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ func isReadyChecker(
return nil
}

func backoffWithMaxNumRetry(fn func() error, count uint8, logger log.FieldLogger) error {
backoffWithMax := backoff.NewExponentialBackOff()
// Wait for model repo to be ready
func backoffWithMaxNumRetry(fn func() error, count uint8, maxElapsedTime time.Duration, logger log.FieldLogger) error {
backoffWithMax := backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(maxElapsedTime))
i := 0
logFailure := func(err error, delay time.Duration) {
logger.WithError(err).Errorf("Retry op #%d", i)
Expand Down Expand Up @@ -112,7 +111,7 @@ func (b *backOffWithMaxCount) Reset() {
}

func (b *backOffWithMaxCount) NextBackOff() time.Duration {
if b.currentCount >= b.maxCount {
if b.currentCount >= b.maxCount-1 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm: does this now retry once if b.maxCount is one? Or do we now consider the initial function call as part of the "maxCount"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backoff.RetryNotify will run the function at least once and, if an error is returned, will then use the backoff policy to decide on retries. From that perspective the above code was slightly wrong; it is now fixed, and the test TestBackOffPolicyWithMaxCount has been updated accordingly.

Copy link
Member

@lc525 lc525 Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the question is whether, for a config parameter named something like max*RetryCount, one would expect a maximum of max*RetryCount retries, or would instead expect the function to run a total of max*RetryCount times. Either way should be fine as long as we're explicit to users (when this becomes configurable externally).

Copy link
Contributor Author

@sakoush sakoush Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a ticket to create docs clarifying these semantic differences and what we actually do in Core 2.
Currently we have: one would expect a maximum number of max*RetryCount retries, bounded by the maximum elapsed duration.

return backoff.Stop
} else {
b.currentCount++
Expand Down
31 changes: 23 additions & 8 deletions scheduler/pkg/agent/client_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
log "github.com/sirupsen/logrus"
)

func TestBackOffPolicyWithMax(t *testing.T) {
func TestBackOffPolicyWithMaxCount(t *testing.T) {
t.Logf("Started")
logger := log.New()
log.SetLevel(log.DebugLevel)
Expand Down Expand Up @@ -54,7 +54,7 @@ func TestBackOffPolicyWithMax(t *testing.T) {
fn := func() error {
return test.err
}
count := uint8(0)
count := uint8(1) // first call is not a retry
policyWithMax := newBackOffWithMaxCount(test.count, &policy)
logFailure := func(err error, delay time.Duration) {
logger.WithError(err).Errorf("retry")
Expand All @@ -66,7 +66,7 @@ func TestBackOffPolicyWithMax(t *testing.T) {
if test.err != nil {
g.Expect(count).To(Equal(test.count))
} else {
g.Expect(count).To(Equal(uint8(0)))
g.Expect(count).To(Equal(uint8(1)))
}
})
}
Expand All @@ -76,26 +76,41 @@ func TestFnWrapperWithMax(t *testing.T) {
t.Logf("Started")
logger := log.New()
log.SetLevel(log.DebugLevel)
g := NewGomegaWithT(t)

type test struct {
name string
count uint8
name string
count uint8
maxElapsedTime time.Duration
expectedCount uint8
}
tests := []test{
{
name: "simple",
count: 3,
name: "count > maxElapsedTime",
count: 4,
expectedCount: 4,
maxElapsedTime: 30 * time.Second,
},
{
name: "count < maxElapsedTime",
count: 4,
expectedCount: 1,
maxElapsedTime: 1 * time.Millisecond,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

retries := uint8(0)
fn := func() error {
time.Sleep(1 * time.Millisecond)
retries++
return fmt.Errorf("error")
}
_ = backoffWithMaxNumRetry(fn, test.count, logger)
_ = backoffWithMaxNumRetry(fn, test.count, test.maxElapsedTime, logger)
// if we are here we are done
g.Expect(retries).To(Equal(test.expectedCount))
})
}
}
2 changes: 1 addition & 1 deletion scheduler/pkg/util/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
GRPCRetryMaxCount = 5 // around 3.2s in total wait duration
GRPCMaxMsgSizeBytes = 1000 * 1024 * 1024
EnvoyUpdateDefaultBatchWait = 250 * time.Millisecond
GRPCModelServerLoadTimeout = 30 * time.Minute // How long to wait for a model to load? think of LLM Load, maybe should be a config
GRPCModelServerLoadTimeout = 60 * time.Minute // How long to wait for a model to load? think of LLM Load, maybe should be a config
GRPCModelServerUnloadTimeout = 2 * time.Minute
GRPCControlPlaneTimeout = 1 * time.Minute // For control plane operations except load/unload
)
Loading