Skip to content

Commit 5566339

Browse files
sywhangabhinav
andauthored
Ensure OnStart/OnStop hooks can only be called once (uber-go#931)
* Ensure OnStop hooks can only be called once It is possible for a user to erroneously call app.Run() then follow up by calling app.Stop(). In such a case, it is possible for the app.Stop() method to be called by two goroutines concurrently, resulting in a race. This adds a state in the App to keep track of whether Stop() has been invoked so that such a race can be prevented. Fix uber-go#930 Internal Ref: GO-1606 * use sync.Once and also add the check for OnStart hooks * Apply suggestions from code review Co-authored-by: Abhinav Gupta <[email protected]> Co-authored-by: Abhinav Gupta <[email protected]>
1 parent b98765e commit 5566339

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

app.go

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ type App struct {
300300
dones []chan os.Signal
301301
shutdownSig os.Signal
302302

303+
// Used to make sure Start/Stop is called only once.
304+
runStart sync.Once
305+
runStop sync.Once
306+
303307
osExit func(code int) // os.Exit override; used for testing only
304308
}
305309

@@ -658,21 +662,25 @@ var (
658662
// Note that Start short-circuits immediately if the New constructor
659663
// encountered any errors in application initialization.
660664
func (app *App) Start(ctx context.Context) (err error) {
661-
defer func() {
662-
app.log.LogEvent(&fxevent.Started{Err: err})
663-
}()
665+
app.runStart.Do(func() {
666+
defer func() {
667+
app.log.LogEvent(&fxevent.Started{Err: err})
668+
}()
664669

665-
if app.err != nil {
666-
// Some provides failed, short-circuit immediately.
667-
return app.err
668-
}
670+
if app.err != nil {
671+
// Some provides failed, short-circuit immediately.
672+
err = app.err
673+
return
674+
}
669675

670-
return withTimeout(ctx, &withTimeoutParams{
671-
hook: _onStartHook,
672-
callback: app.start,
673-
lifecycle: app.lifecycle,
674-
log: app.log,
676+
err = withTimeout(ctx, &withTimeoutParams{
677+
hook: _onStartHook,
678+
callback: app.start,
679+
lifecycle: app.lifecycle,
680+
log: app.log,
681+
})
675682
})
683+
return
676684
}
677685

678686
func (app *App) start(ctx context.Context) error {
@@ -700,16 +708,20 @@ func (app *App) start(ctx context.Context) error {
700708
// called are executed. However, all those hooks are executed, even if some
701709
// fail.
702710
func (app *App) Stop(ctx context.Context) (err error) {
703-
defer func() {
704-
app.log.LogEvent(&fxevent.Stopped{Err: err})
705-
}()
711+
app.runStop.Do(func() {
712+
// Protect the Stop hooks from being called multiple times.
713+
defer func() {
714+
app.log.LogEvent(&fxevent.Stopped{Err: err})
715+
}()
706716

707-
return withTimeout(ctx, &withTimeoutParams{
708-
hook: _onStopHook,
709-
callback: app.lifecycle.Stop,
710-
lifecycle: app.lifecycle,
711-
log: app.log,
717+
err = withTimeout(ctx, &withTimeoutParams{
718+
hook: _onStopHook,
719+
callback: app.lifecycle.Stop,
720+
lifecycle: app.lifecycle,
721+
log: app.log,
722+
})
712723
})
724+
return
713725
}
714726

715727
// Done returns a channel of signals to block on after starting the

app_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"reflect"
3232
"runtime"
3333
"strings"
34+
"sync"
3435
"testing"
3536
"time"
3637

@@ -1281,6 +1282,47 @@ func TestAppStart(t *testing.T) {
12811282
err := app.Start(context.Background()).Error()
12821283
assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning")
12831284
})
1285+
1286+
t.Run("Start/Stop should be called exactly once only.", func(t *testing.T) {
1287+
t.Parallel()
1288+
startCalled := 0
1289+
stopCalled := 0
1290+
app := fxtest.New(t,
1291+
Provide(Annotate(func() int { return 0 },
1292+
OnStart(func(context.Context) error {
1293+
startCalled += 1
1294+
return nil
1295+
}),
1296+
OnStop(func(context.Context) error {
1297+
stopCalled += 1
1298+
return nil
1299+
})),
1300+
),
1301+
Invoke(func(i int) {
1302+
assert.Equal(t, 0, i)
1303+
}),
1304+
)
1305+
var wg sync.WaitGroup
1306+
for i := 0; i < 10; i++ {
1307+
wg.Add(1)
1308+
go func() {
1309+
defer wg.Done()
1310+
app.Start(context.Background())
1311+
}()
1312+
}
1313+
wg.Wait()
1314+
assert.Equal(t, 1, startCalled)
1315+
for i := 0; i < 10; i++ {
1316+
wg.Add(1)
1317+
go func() {
1318+
defer wg.Done()
1319+
app.Stop(context.Background())
1320+
}()
1321+
}
1322+
wg.Wait()
1323+
assert.Equal(t, 1, stopCalled)
1324+
})
1325+
12841326
}
12851327

12861328
func TestAppStop(t *testing.T) {

0 commit comments

Comments
 (0)