613 changes: 341 additions & 272 deletions apis/go/mlops/scheduler/scheduler.pb.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion apis/mlops/scheduler/scheduler.proto
@@ -390,7 +390,12 @@ message ControlPlaneSubscriptionRequest {
}

message ControlPlaneResponse {

enum Event {
UNKNOWN_EVENT = 0;
SEND_SERVERS = 1; // initial sync for the servers
SEND_RESOURCES = 2; // send models / pipelines / experiments
}
Event event = 1;
}

// [END Messages]
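For context, here is a minimal sketch (not part of this diff, imports omitted) of how a subscriber to the control-plane stream might branch on the new Event field; it mirrors the operator-side handling in operator/scheduler/client.go below. The pb alias for the generated scheduler package is as used elsewhere in this PR, while consumeControlPlane, handleServers and handleResources are hypothetical names for illustration only.

// Illustrative sketch only: dispatch on ControlPlaneResponse_Event from a subscriber's receive loop.
// handleServers and handleResources are hypothetical stand-ins for the operator's real handlers.
func consumeControlPlane(ctx context.Context, stream pb.Scheduler_SubscribeControlPlaneClient) error {
	for {
		resp, err := stream.Recv()
		if err != nil {
			return err // io.EOF or a transport error ends the subscription
		}
		switch resp.GetEvent() {
		case pb.ControlPlaneResponse_SEND_SERVERS:
			err = handleServers(ctx) // hypothetical: re-send the known servers to the scheduler
		case pb.ControlPlaneResponse_SEND_RESOURCES:
			err = handleResources(ctx) // hypothetical: re-send experiments, pipelines and models
		default:
			err = fmt.Errorf("unknown control plane event %v", resp.GetEvent())
		}
		if err != nil {
			return err
		}
	}
}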
47 changes: 25 additions & 22 deletions operator/scheduler/client.go
@@ -142,34 +142,37 @@ func (s *SchedulerClient) startEventHanders(namespace string, conn *grpc.ClientC
}()
}

func (s *SchedulerClient) handleStateOnReconnect(context context.Context, grpcClient scheduler.SchedulerClient, namespace string) error {
// on new reconnects we send a list of servers to the schedule
err := s.handleRegisteredServers(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send registered server to scheduler")
}

if err == nil {
err = s.handleExperiments(context, grpcClient, namespace)
func (s *SchedulerClient) handleStateOnReconnect(context context.Context, grpcClient scheduler.SchedulerClient, namespace string, operation scheduler.ControlPlaneResponse_Event) error {
switch operation {
case scheduler.ControlPlaneResponse_SEND_SERVERS:
// on new reconnects we send a list of servers to the scheduler
err := s.handleRegisteredServers(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send experiments to scheduler")
s.logger.Error(err, "Failed to send registered server to scheduler")
}
}

if err == nil {
err = s.handlePipelines(context, grpcClient, namespace)
return err
case scheduler.ControlPlaneResponse_SEND_RESOURCES:
err := s.handleExperiments(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send pipelines to scheduler")
s.logger.Error(err, "Failed to send experiments to scheduler")
}
}

if err == nil {
err = s.handleModels(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send models to scheduler")
if err == nil {
err = s.handlePipelines(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send pipelines to scheduler")
}
}
if err == nil {
err = s.handleModels(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send models to scheduler")
}
}
return err
default:
s.logger.Info("Unknown operation", "operation", operation)
return fmt.Errorf("unknown operation %v", operation)
}
return err
}

func (s *SchedulerClient) RemoveConnection(namespace string) {
2 changes: 1 addition & 1 deletion operator/scheduler/control_plane.go
@@ -50,7 +50,7 @@ func (s *SchedulerClient) SubscribeControlPlaneEvents(ctx context.Context, grpcC
logger.Info("Received event to handle state", "event", event)

fn := func() error {
return s.handleStateOnReconnect(ctx, grpcClient, namespace)
return s.handleStateOnReconnect(ctx, grpcClient, namespace, event.GetEvent())
}
_, err = execWithTimeout(fn, execTimeOut)
if err != nil {
12 changes: 8 additions & 4 deletions operator/scheduler/utils_test.go
@@ -109,7 +109,7 @@ func (s *mockSchedulerServerSubscribeGrpcClient) Recv() (*scheduler.ServerStatus
// Control Plane subscribe mock grpc client

type mockControlPlaneSubscribeGrpcClient struct {
sent bool
sent int
grpc.ClientStream
}

@@ -120,9 +120,13 @@ func newMockControlPlaneSubscribeGrpcClient() *mockControlPlaneSubscribeGrpcClie
}

func (s *mockControlPlaneSubscribeGrpcClient) Recv() (*scheduler.ControlPlaneResponse, error) {
if !s.sent {
s.sent = true
return &scheduler.ControlPlaneResponse{}, nil
if s.sent == 0 {
s.sent++
return &scheduler.ControlPlaneResponse{Event: scheduler.ControlPlaneResponse_SEND_SERVERS}, nil
}
if s.sent == 1 {
s.sent++
return &scheduler.ControlPlaneResponse{Event: scheduler.ControlPlaneResponse_SEND_RESOURCES}, nil
}
return nil, io.EOF
}
20 changes: 18 additions & 2 deletions scheduler/pkg/server/control_plane.go
@@ -23,8 +23,15 @@ func (s *SchedulerServer) SubscribeControlPlane(req *pb.ControlPlaneSubscription
return err
}

fin := make(chan bool)
s.synchroniser.WaitReady()

err = s.sendResourcesMarker(stream)
if err != nil {
logger.WithError(err).Errorf("Failed to send resources marker to %s", req.GetSubscriberName())
return err
}

fin := make(chan bool)
s.controlPlaneStream.mu.Lock()
s.controlPlaneStream.streams[stream] = &ControlPlaneSubsription{
name: req.GetSubscriberName(),
@@ -61,11 +68,20 @@ func (s *SchedulerServer) StopSendControlPlaneEvents() {
// this is to mark the initial start of a new stream (at application level)
// as otherwise the other side sometimes doesn't know if the scheduler has established a new stream explicitly
func (s *SchedulerServer) sendStartServerStreamMarker(stream pb.Scheduler_SubscribeControlPlaneServer) error {
ssr := &pb.ControlPlaneResponse{}
ssr := &pb.ControlPlaneResponse{Event: pb.ControlPlaneResponse_SEND_SERVERS}
_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if err != nil {
return err
}
return nil
}

// this marks the stage at which the controller should send its resources (models, pipelines, experiments) to the scheduler
func (s *SchedulerServer) sendResourcesMarker(stream pb.Scheduler_SubscribeControlPlaneServer) error {
ssr := &pb.ControlPlaneResponse{Event: pb.ControlPlaneResponse_SEND_RESOURCES}
_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if err != nil {
return err
}
return nil
}
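sendStartServerStreamMarker and sendResourcesMarker above differ only in the event they carry; a single parameterised helper, sketched below purely for illustration (it is not what this diff implements), could serve both. Either way, the handshake a new subscriber sees is SEND_SERVERS followed by SEND_RESOURCES, which is what TestSubscribeControlPlane below asserts.

// Sketch only, not part of this diff: one helper parameterised by the event,
// which both marker functions could delegate to.
func (s *SchedulerServer) sendEventMarker(stream pb.Scheduler_SubscribeControlPlaneServer, event pb.ControlPlaneResponse_Event) error {
	ssr := &pb.ControlPlaneResponse{Event: event}
	_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
	return err
}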
91 changes: 91 additions & 0 deletions scheduler/pkg/server/control_plane_test.go
@@ -10,15 +10,22 @@ the Change License after the Change Date as each is defined in accordance with t
package server

import (
"context"
"fmt"
"testing"
"time"

. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/synchroniser"
)

func TestStartServerStream(t *testing.T) {
@@ -67,7 +74,91 @@ func TestStartServerStream(t *testing.T) {
}

g.Expect(msr).ToNot(BeNil())
g.Expect(msr.Event).To(Equal(pb.ControlPlaneResponse_SEND_SERVERS))
}

err = test.server.sendResourcesMarker(stream)
if test.err {
g.Expect(err).ToNot(BeNil())
} else {
g.Expect(err).To(BeNil())

var msr *pb.ControlPlaneResponse
select {
case next := <-stream.msgs:
msr = next
default:
t.Fail()
}

g.Expect(msr).ToNot(BeNil())
g.Expect(msr.Event).To(Equal(pb.ControlPlaneResponse_SEND_RESOURCES))
}
})
}
}

func TestSubscribeControlPlane(t *testing.T) {
log.SetLevel(log.DebugLevel)
g := NewGomegaWithT(t)

type test struct {
name string
}
tests := []test{
{
name: "simple",
},
}

getStream := func(context context.Context, port int) (*grpc.ClientConn, pb.Scheduler_SubscribeControlPlaneClient) {
conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
grpcClient := pb.NewSchedulerClient(conn)
client, _ := grpcClient.SubscribeControlPlane(
context,
&pb.ControlPlaneSubscriptionRequest{SubscriberName: "dummy"},
)
return conn, client
}

createTestScheduler := func() *SchedulerServer {
logger := log.New()
logger.SetLevel(log.WarnLevel)

eventHub, err := coordinator.NewEventHub(logger)
g.Expect(err).To(BeNil())

sync := synchroniser.NewSimpleSynchroniser(time.Duration(10 * time.Millisecond))

s := NewSchedulerServer(logger, nil, nil, nil, nil, eventHub, sync)
sync.Signals(1)

return s
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := createTestScheduler()
port, err := testing_utils.GetFreePortForTest()
if err != nil {
t.Fatal(err)
}
err = server.startServer(uint(port), false)
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)

conn, client := getStream(context.Background(), port)

msg, _ := client.Recv()
g.Expect(msg.GetEvent()).To(Equal(pb.ControlPlaneResponse_SEND_SERVERS))

msg, _ = client.Recv()
g.Expect(msg.Event).To(Equal(pb.ControlPlaneResponse_SEND_RESOURCES))

conn.Close()
server.StopSendControlPlaneEvents()
})
}
}
11 changes: 8 additions & 3 deletions scheduler/pkg/store/memory.go
@@ -89,7 +89,11 @@ func (m *MemoryStore) addModelVersionIfNotExists(req *agent.ModelVersion) (*Mode
}

func (m *MemoryStore) addNextModelVersion(model *Model, pbmodel *pb.Model) {
version := uint32(1)
// if we start from a clean state, let's use the generation id as the starting version
// this ensures that we have monotonically increasing version numbers
// and we never reset back to 1
generation := pbmodel.GetMeta().GetKubernetesMeta().GetGeneration()
version := max(uint32(1), uint32(generation))
if model.Latest() != nil {
version = model.Latest().GetVersion() + 1
}
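A small worked example of the rule above, with hypothetical values: if the scheduler restarts with a clean store while the Model resource is already at Kubernetes generation 4, the first recreated version is 4 rather than 1, and later versions keep incrementing from there.

// Worked example of the starting-version rule (values are hypothetical).
gen := uint32(4)             // Model generation when the scheduler restarts with a clean store
first := max(uint32(1), gen) // -> 4: the version never resets to 1, so it stays monotonic for the operator
next := first + 1            // -> 5: subsequent updates use model.Latest().GetVersion() + 1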
@@ -329,7 +333,7 @@ func (m *MemoryStore) updateLoadedModelsImpl(
modelVersion = model.Latest()
}

// resevere memory for existing replicas that are not already loading or loaded
// reserve memory for existing replicas that are not already loading or loaded
replicaStateUpdated := false
for replicaIdx := range assignedReplicaIds {
if existingState, ok := modelVersion.replicas[replicaIdx]; !ok {
@@ -370,7 +374,8 @@ func (m *MemoryStore) updateLoadedModelsImpl(
// in cases where we did have a previous ScheduleFailed, we need to reflect the change here
// this could be in the cases where we are scaling down a model and the new replica count can be all deployed
// and always send an update for deleted models, so the operator will remove them from k8s
if replicaStateUpdated || modelVersion.state.State == ScheduleFailed || model.IsDeleted() {
// also send an update for progressing models so the operator can refresh their status, e.g. after a network glitch during which the model generation was updated
if replicaStateUpdated || modelVersion.state.State == ScheduleFailed || model.IsDeleted() || modelVersion.state.State == ModelProgressing {
logger.Debugf("Updating model status for model %s server %s", modelKey, serverKey)
modelVersion.server = serverKey
m.updateModelStatus(true, model.IsDeleted(), modelVersion, model.GetLastAvailableModelVersion())