Skip to content
23 changes: 21 additions & 2 deletions exec/bigmachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ type bigmachineExecutor struct {
// If the task is marked as exclusive, then one is added to their
// manager index.
managers []*machineManager

shutdownc chan struct{}
managersWG sync.WaitGroup
}

func newBigmachineExecutor(system bigmachine.System, params ...bigmachine.Param) *bigmachineExecutor {
Expand Down Expand Up @@ -150,13 +153,23 @@ func (b *bigmachineExecutor) Start(sess *Session) (shutdown func()) {
b.worker = &worker{
MachineCombiners: sess.machineCombiners,
}
b.shutdownc = make(chan struct{})

return b.b.Shutdown
return func() {
close(b.shutdownc)
b.managersWG.Wait()
b.b.Shutdown()
}
}

func (b *bigmachineExecutor) manager(i int) *machineManager {
b.mu.Lock()
defer b.mu.Unlock()
ctx, cancel := context.WithCancel(backgroundcontext.Get())
go func() {
<-b.shutdownc
cancel()
}()
for i >= len(b.managers) {
b.managers = append(b.managers, nil)
}
Expand All @@ -168,7 +181,13 @@ func (b *bigmachineExecutor) manager(i int) *machineManager {
maxLoad = 0
}
b.managers[i] = newMachineManager(b.b, b.params, b.status, b.sess.Parallelism(), maxLoad, b.worker)
go b.managers[i].Do(backgroundcontext.Get())

b.managersWG.Add(1)
go func() {
b.managers[i].Do(ctx)
b.managers[i].machinesWG.Wait()
b.managersWG.Done()
}()
}
return b.managers[i]
}
Expand Down
14 changes: 10 additions & 4 deletions exec/slicemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"sync"
"time"

"github.com/grailbio/base/backgroundcontext"
"github.com/grailbio/base/data"
"github.com/grailbio/base/errors"
"github.com/grailbio/base/log"
Expand Down Expand Up @@ -350,6 +349,8 @@ type machineManager struct {
schedQ scheduleRequestQ
schedc chan scheduleRequest
unschedc chan scheduleRequest

machinesWG sync.WaitGroup
}

// NewMachineManager returns a new machineManager parameterized by the
Expand Down Expand Up @@ -549,6 +550,14 @@ func (m *machineManager) Do(ctx context.Context) {
have/m.machprocs, have, pending/m.machprocs, pending)
go func() {
machines := startMachines(ctx, m.b, m.group, m.machprocs, needMachines, m.worker, m.params...)
for _, machine := range machines {
machine := machine
m.machinesWG.Add(1)
go func() {
machine.Go(ctx)
m.machinesWG.Done()
}()
}
startc <- startResult{
machines: machines,
nFailures: needMachines - len(machines),
Expand Down Expand Up @@ -636,9 +645,6 @@ func startMachines(ctx context.Context, b *bigmachine.B, group *status.Group, ma
Status: status,
maxTaskProcs: maxTaskProcs,
}
// TODO(marius): pass a context that's tied to the evaluation
// lifetime, or lifetime of the machine.
go sm.Go(backgroundcontext.Get())
slicemachines[i] = sm
}()
}
Expand Down