Skip to content

Commit 800b2ea

Browse files
sywhangabhinav
andcommitted
Ensure OnStart/OnStop hooks can only be called once (uber-go#931)
* Ensure OnStop hooks can only be called once It is possible for a user to erroneously call app.Run() then follow up by calling app.Stop(). In such a case, it is possible for the app.Stop() method to be called by two goroutines concurrently, resulting in a race. This adds a state in the App to keep track of whether Stop() has been invoked so that such a race can be prevented. Fix uber-go#930 Internal Ref: GO-1606 * use sync.Once and also add the check for OnStart hooks * Apply suggestions from code review Co-authored-by: Abhinav Gupta <[email protected]> Co-authored-by: Abhinav Gupta <[email protected]>
1 parent 40a4624 commit 800b2ea

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

app.go

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ type App struct {
300300
dones []chan os.Signal
301301
shutdownSig os.Signal
302302

303+
// Used to make sure Start/Stop is called only once.
304+
runStart sync.Once
305+
runStop sync.Once
306+
303307
osExit func(code int) // os.Exit override; used for testing only
304308
}
305309

@@ -661,21 +665,25 @@ var (
661665
// Note that Start short-circuits immediately if the New constructor
662666
// encountered any errors in application initialization.
663667
func (app *App) Start(ctx context.Context) (err error) {
664-
defer func() {
665-
app.log.LogEvent(&fxevent.Started{Err: err})
666-
}()
668+
app.runStart.Do(func() {
669+
defer func() {
670+
app.log.LogEvent(&fxevent.Started{Err: err})
671+
}()
667672

668-
if app.err != nil {
669-
// Some provides failed, short-circuit immediately.
670-
return app.err
671-
}
673+
if app.err != nil {
674+
// Some provides failed, short-circuit immediately.
675+
err = app.err
676+
return
677+
}
672678

673-
return withTimeout(ctx, &withTimeoutParams{
674-
hook: _onStartHook,
675-
callback: app.start,
676-
lifecycle: app.lifecycle,
677-
log: app.log,
679+
err = withTimeout(ctx, &withTimeoutParams{
680+
hook: _onStartHook,
681+
callback: app.start,
682+
lifecycle: app.lifecycle,
683+
log: app.log,
684+
})
678685
})
686+
return
679687
}
680688

681689
func (app *App) start(ctx context.Context) error {
@@ -703,16 +711,20 @@ func (app *App) start(ctx context.Context) error {
703711
// called are executed. However, all those hooks are executed, even if some
704712
// fail.
705713
func (app *App) Stop(ctx context.Context) (err error) {
706-
defer func() {
707-
app.log.LogEvent(&fxevent.Stopped{Err: err})
708-
}()
714+
app.runStop.Do(func() {
715+
// Protect the Stop hooks from being called multiple times.
716+
defer func() {
717+
app.log.LogEvent(&fxevent.Stopped{Err: err})
718+
}()
709719

710-
return withTimeout(ctx, &withTimeoutParams{
711-
hook: _onStopHook,
712-
callback: app.lifecycle.Stop,
713-
lifecycle: app.lifecycle,
714-
log: app.log,
720+
err = withTimeout(ctx, &withTimeoutParams{
721+
hook: _onStopHook,
722+
callback: app.lifecycle.Stop,
723+
lifecycle: app.lifecycle,
724+
log: app.log,
725+
})
715726
})
727+
return
716728
}
717729

718730
// Done returns a channel of signals to block on after starting the

app_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"reflect"
3232
"runtime"
3333
"strings"
34+
"sync"
3435
"testing"
3536
"time"
3637

@@ -1281,6 +1282,47 @@ func TestAppStart(t *testing.T) {
12811282
err := app.Start(context.Background()).Error()
12821283
assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning")
12831284
})
1285+
1286+
t.Run("Start/Stop should be called exactly once only.", func(t *testing.T) {
1287+
t.Parallel()
1288+
startCalled := 0
1289+
stopCalled := 0
1290+
app := fxtest.New(t,
1291+
Provide(Annotate(func() int { return 0 },
1292+
OnStart(func(context.Context) error {
1293+
startCalled += 1
1294+
return nil
1295+
}),
1296+
OnStop(func(context.Context) error {
1297+
stopCalled += 1
1298+
return nil
1299+
})),
1300+
),
1301+
Invoke(func(i int) {
1302+
assert.Equal(t, 0, i)
1303+
}),
1304+
)
1305+
var wg sync.WaitGroup
1306+
for i := 0; i < 10; i++ {
1307+
wg.Add(1)
1308+
go func() {
1309+
defer wg.Done()
1310+
app.Start(context.Background())
1311+
}()
1312+
}
1313+
wg.Wait()
1314+
assert.Equal(t, 1, startCalled)
1315+
for i := 0; i < 10; i++ {
1316+
wg.Add(1)
1317+
go func() {
1318+
defer wg.Done()
1319+
app.Stop(context.Background())
1320+
}()
1321+
}
1322+
wg.Wait()
1323+
assert.Equal(t, 1, stopCalled)
1324+
})
1325+
12841326
}
12851327

12861328
func TestAppStop(t *testing.T) {

0 commit comments

Comments
 (0)