Skip to content

Commit b03b02e

Browse files
authored
Add ResumeShard method and ShardState struct to manage shard session and sequence (#460)
1 parent ac1984f commit b03b02e

File tree

6 files changed

+92
-41
lines changed

6 files changed

+92
-41
lines changed

bot/event_manager.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type EventManager interface {
3535
RemoveEventListeners(eventListeners ...EventListener)
3636

3737
// HandleGatewayEvent calls the correct GatewayEventHandler for the payload
38-
HandleGatewayEvent(gatewayEventType gateway.EventType, sequenceNumber int, shardID int, event gateway.EventData)
38+
HandleGatewayEvent(gateway gateway.Gateway, eventType gateway.EventType, sequenceNumber int, event gateway.EventData)
3939

4040
// HandleHTTPEvent calls the HTTPServerEventHandler for the payload
4141
HandleHTTPEvent(respondFunc httpserver.RespondFunc, event httpserver.EventInteractionCreate)
@@ -91,7 +91,7 @@ type GatewayEventHandler interface {
9191
HandleGatewayEvent(client *Client, sequenceNumber int, shardID int, event gateway.EventData)
9292
}
9393

94-
// NewGatewayEventHandler returns a new GatewayEventHandler for the given GatewayEventType and handler func
94+
// NewGatewayEventHandler returns a new GatewayEventHandler for the given gateway.EventType and handler func
9595
func NewGatewayEventHandler[T gateway.EventData](eventType gateway.EventType, handleFunc func(client *Client, sequenceNumber int, shardID int, event T)) GatewayEventHandler {
9696
return &genericGatewayEventHandler[T]{eventType: eventType, handleFunc: handleFunc}
9797
}
@@ -128,13 +128,13 @@ type eventManagerImpl struct {
128128
httpServerHandler HTTPServerEventHandler
129129
}
130130

131-
func (e *eventManagerImpl) HandleGatewayEvent(gatewayEventType gateway.EventType, sequenceNumber int, shardID int, event gateway.EventData) {
131+
func (e *eventManagerImpl) HandleGatewayEvent(gateway gateway.Gateway, eventType gateway.EventType, sequenceNumber int, event gateway.EventData) {
132132
e.mu.Lock()
133133
defer e.mu.Unlock()
134-
if handler, ok := e.gatewayHandlers[gatewayEventType]; ok {
135-
handler.HandleGatewayEvent(e.client, sequenceNumber, shardID, event)
134+
if handler, ok := e.gatewayHandlers[eventType]; ok {
135+
handler.HandleGatewayEvent(e.client, sequenceNumber, gateway.ShardID(), event)
136136
} else {
137-
e.logger.Warn("no handler for Gateway event found", slog.Any("event_type", gatewayEventType))
137+
e.logger.Warn("no handler for Gateway event found", slog.Any("event_type", eventType))
138138
}
139139
}
140140

gateway/gateway.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ const (
8989

9090
type (
9191
// EventHandlerFunc is a function that is called when an event is received.
92-
EventHandlerFunc func(gatewayEventType EventType, sequenceNumber int, shardID int, event EventData)
92+
EventHandlerFunc func(gateway Gateway, eventType EventType, sequenceNumber int, event EventData)
9393

9494
// CreateFunc is a type that is used to create a new Gateway(s).
9595
CreateFunc func(token string, eventHandlerFunc EventHandlerFunc, closeHandlerFUnc CloseHandlerFunc, opts ...ConfigOpt) Gateway
@@ -579,7 +579,7 @@ loop:
579579

580580
// push message to the command manager
581581
if g.config.EnableRawEvents {
582-
g.eventHandlerFunc(EventTypeRaw, message.S, g.config.ShardID, EventRaw{
582+
g.eventHandlerFunc(g, EventTypeRaw, message.S, EventRaw{
583583
EventType: message.T,
584584
Payload: bytes.NewReader(message.RawD),
585585
})
@@ -589,7 +589,7 @@ loop:
589589
g.config.Logger.Debug("unknown event received", slog.String("event", string(message.T)), slog.String("data", string(unknownEvent)))
590590
continue
591591
}
592-
g.eventHandlerFunc(message.T, message.S, g.config.ShardID, eventData)
592+
g.eventHandlerFunc(g, message.T, message.S, eventData)
593593

594594
case OpcodeHeartbeat:
595595
g.sendHeartbeat()
@@ -622,7 +622,7 @@ loop:
622622

623623
case OpcodeHeartbeatACK:
624624
newHeartbeat := time.Now().UTC()
625-
g.eventHandlerFunc(EventTypeHeartbeatAck, message.S, g.config.ShardID, EventHeartbeatAck{
625+
g.eventHandlerFunc(g, EventTypeHeartbeatAck, message.S, EventHeartbeatAck{
626626
LastHeartbeat: g.lastHeartbeatReceived,
627627
NewHeartbeat: newHeartbeat,
628628
})

sharding/shard_manager.go

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"iter"
88
"log/slog"
9+
"maps"
10+
"slices"
911
"sync"
1012

1113
"github.com/disgoorg/snowflake/v2"
@@ -29,6 +31,9 @@ type ShardManager interface {
2931
// OpenShard opens a specific shard.
3032
OpenShard(ctx context.Context, shardID int) error
3133

34+
// ResumeShard resumes a specific shard with the given sessionID and sequence.
35+
ResumeShard(ctx context.Context, shardID int, sessionID string, sequence int) error
36+
3237
// CloseShard closes a specific shard.
3338
CloseShard(ctx context.Context, shardID int)
3439

@@ -81,10 +86,8 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
8186
shard.Close(context.TODO())
8287

8388
m.shardsMu.Lock()
84-
defer m.shardsMu.Unlock()
85-
8689
delete(m.shards, shard.ShardID())
87-
delete(m.config.ShardIDs, shard.ShardID())
90+
defer m.shardsMu.Unlock()
8891

8992
newShardCount := shard.ShardCount() * m.config.ShardSplitCount
9093
if newShardCount > m.config.ShardCount {
@@ -99,20 +102,25 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
99102
}
100103

101104
var wg sync.WaitGroup
102-
for i := range newShardIDs {
103-
shardID := newShardIDs[i]
105+
for _, shardID := range newShardIDs {
104106
wg.Add(1)
107+
105108
go func() {
106109
defer wg.Done()
107-
if err := m.config.RateLimiter.WaitBucket(context.TODO(), shardID); err != nil {
110+
111+
if err := m.config.RateLimiter.WaitBucket(context.Background(), shardID); err != nil {
108112
m.config.Logger.Error("failed to wait shard bucket", slog.Any("err", err), slog.Int("shard_id", shardID))
109113
return
110114
}
111115
defer m.config.RateLimiter.UnlockBucket(shardID)
112116

113117
newShard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(newShardCount))...)
118+
119+
m.shardsMu.Lock()
114120
m.shards[shardID] = newShard
115-
if err := newShard.Open(context.TODO()); err != nil {
121+
m.shardsMu.Unlock()
122+
123+
if err := newShard.Open(context.Background()); err != nil {
116124
m.config.Logger.Error("failed to re shard", slog.Any("err", err), slog.Int("shard_id", shardID))
117125
}
118126
}()
@@ -122,27 +130,41 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
122130
}
123131

124132
func (m *shardManagerImpl) Open(ctx context.Context) {
125-
m.config.Logger.Debug("opening shards", slog.String("shard_ids", fmt.Sprint(m.config.ShardIDs)))
126-
var wg sync.WaitGroup
133+
m.config.Logger.Debug("opening shards", slog.String("shard_ids", fmt.Sprint(slices.Collect(maps.Keys(m.config.ShardIDs)))))
127134

128-
m.shardsMu.Lock()
129-
defer m.shardsMu.Unlock()
130-
for shardID := range m.config.ShardIDs {
131-
if _, ok := m.shards[shardID]; ok {
135+
var wg sync.WaitGroup
136+
for shardID, shardState := range m.config.ShardIDs {
137+
m.shardsMu.Lock()
138+
_, ok := m.shards[shardID]
139+
m.shardsMu.Unlock()
140+
if ok {
132141
continue
133142
}
134143

135144
wg.Add(1)
136145
go func() {
137146
defer wg.Done()
147+
138148
if err := m.config.RateLimiter.WaitBucket(ctx, shardID); err != nil {
139149
m.config.Logger.Error("failed to wait shard bucket", slog.Any("err", err), slog.Int("shard_id", shardID))
140150
return
141151
}
142152
defer m.config.RateLimiter.UnlockBucket(shardID)
143153

144-
shard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(m.config.ShardCount))...)
154+
opts := append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(m.config.ShardCount))
155+
if shardState.SessionID != "" {
156+
opts = append(opts, gateway.WithSessionID(shardState.SessionID))
157+
}
158+
if shardState.Sequence != 0 {
159+
opts = append(opts, gateway.WithSequence(shardState.Sequence))
160+
}
161+
162+
shard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, opts...)
163+
164+
m.shardsMu.Lock()
145165
m.shards[shardID] = shard
166+
m.shardsMu.Unlock()
167+
146168
if err := shard.Open(ctx); err != nil {
147169
m.config.Logger.Error("failed to open shard", slog.Any("err", err), slog.Int("shard_id", shardID))
148170
}
@@ -152,40 +174,52 @@ func (m *shardManagerImpl) Open(ctx context.Context) {
152174
}
153175

154176
func (m *shardManagerImpl) Close(ctx context.Context) {
155-
m.config.Logger.Debug("closing shards", slog.String("shard_ids", fmt.Sprint(m.config.ShardIDs)))
177+
m.config.Logger.Debug("closing shards", slog.String("shard_ids", fmt.Sprint(slices.Collect(maps.Keys(m.shards)))))
156178
var wg sync.WaitGroup
157179

158180
m.shardsMu.Lock()
159181
defer m.shardsMu.Unlock()
160-
for shardID := range m.shards {
161-
shard := m.shards[shardID]
162-
delete(m.shards, shardID)
182+
for _, shard := range m.shards {
163183
wg.Add(1)
164184
go func() {
165185
defer wg.Done()
166186
shard.Close(ctx)
167187
}()
168188
}
169189
wg.Wait()
190+
m.shards = map[int]gateway.Gateway{}
170191
}
171192

172193
func (m *shardManagerImpl) OpenShard(ctx context.Context, shardID int) error {
173-
return m.openShard(ctx, shardID, m.config.ShardCount)
194+
return m.openShard(ctx, shardID, m.config.ShardCount, "", 0)
195+
}
196+
197+
func (m *shardManagerImpl) ResumeShard(ctx context.Context, shardID int, sessionID string, sequence int) error {
198+
return m.openShard(ctx, shardID, m.config.ShardCount, sessionID, sequence)
174199
}
175200

176-
func (m *shardManagerImpl) openShard(ctx context.Context, shardID int, shardCount int) error {
177-
m.config.Logger.Debug("opening shard", slog.Int("shard_id", shardID))
201+
func (m *shardManagerImpl) openShard(ctx context.Context, shardID int, shardCount int, sessionID string, sequence int) error {
202+
m.config.Logger.Debug("opening shard", slog.Int("shard_id", shardID), slog.Int("shard_count", shardCount), slog.String("session_id", sessionID), slog.Int("sequence", sequence))
178203

179204
if err := m.config.RateLimiter.WaitBucket(ctx, shardID); err != nil {
180205
return err
181206
}
182207
defer m.config.RateLimiter.UnlockBucket(shardID)
183-
shard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(shardCount))...)
208+
209+
opts := append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(shardCount))
210+
if sessionID != "" {
211+
opts = append(opts, gateway.WithSessionID(sessionID))
212+
}
213+
if sequence != 0 {
214+
opts = append(opts, gateway.WithSequence(sequence))
215+
}
216+
217+
shard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, opts...)
184218

185219
m.shardsMu.Lock()
186-
defer m.shardsMu.Unlock()
187-
m.config.ShardIDs[shardID] = struct{}{}
188220
m.shards[shardID] = shard
221+
defer m.shardsMu.Unlock()
222+
189223
return shard.Open(ctx)
190224
}
191225

@@ -220,6 +254,7 @@ func (m *shardManagerImpl) Shards() iter.Seq[gateway.Gateway] {
220254
return func(yield func(gateway.Gateway) bool) {
221255
m.shardsMu.Lock()
222256
defer m.shardsMu.Unlock()
257+
223258
for _, shard := range m.shards {
224259
if !yield(shard) {
225260
return

sharding/shard_manager_config.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,20 @@ func defaultConfig() config {
1414
}
1515
}
1616

17+
// ShardState is used to tell a [gateway.Gateway] managed by the [ShardManager] which session & sequence it should use when starting the shard.
18+
// This is useful for resuming shards when using the [ShardManager].
19+
type ShardState struct {
20+
// SessionID is the session ID of the shard. This is used to resume the shard.
21+
SessionID string
22+
// Sequence is the sequence number of the shard. This is used to resume the shard.
23+
Sequence int
24+
}
25+
1726
type config struct {
1827
// Logger is the logger of the ShardManager. Defaults to log.Default()
1928
Logger *slog.Logger
2029
// ShardIDs is a map of shardIDs the ShardManager should manage. Leave this nil to manage all shards.
21-
ShardIDs map[int]struct{}
30+
ShardIDs map[int]ShardState
2231
// ShardCount is the total shard count of the ShardManager. Leave this at 0 to let Discord calculate the shard count for you.
2332
ShardCount int
2433
// ShardSplitCount is the count a shard should be split into if it is too large. This is only used if AutoScaling is enabled.
@@ -63,13 +72,20 @@ func WithLogger(logger *slog.Logger) ConfigOpt {
6372
// WithShardIDs sets the shardIDs the ShardManager should manage.
6473
func WithShardIDs(shardIDs ...int) ConfigOpt {
6574
return func(config *config) {
66-
config.ShardIDs = map[int]struct{}{}
75+
config.ShardIDs = map[int]ShardState{}
6776
for _, shardID := range shardIDs {
68-
config.ShardIDs[shardID] = struct{}{}
77+
config.ShardIDs[shardID] = ShardState{}
6978
}
7079
}
7180
}
7281

82+
// WithShardIDsWithStates sets the shardIDs and their [ShardState] the ShardManager should manage.
83+
func WithShardIDsWithStates(shards map[int]ShardState) ConfigOpt {
84+
return func(config *config) {
85+
config.ShardIDs = shards
86+
}
87+
}
88+
7389
// WithShardCount sets the shard count of the ShardManager.
7490
func WithShardCount(shardCount int) ConfigOpt {
7591
return func(config *config) {

voice/conn.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ func (c *connImpl) HandleVoiceServerUpdate(update botgateway.EventVoiceServerUpd
194194
}()
195195
}
196196

197-
func (c *connImpl) handleMessage(op Opcode, data GatewayMessageData) {
197+
func (c *connImpl) handleMessage(gateway Gateway, op Opcode, data GatewayMessageData) {
198198
switch d := data.(type) {
199199
case GatewayMessageDataReady:
200200
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -238,7 +238,7 @@ func (c *connImpl) handleMessage(op Opcode, data GatewayMessageData) {
238238
}
239239
}
240240
if c.config.EventHandlerFunc != nil {
241-
c.config.EventHandlerFunc(op, data)
241+
c.config.EventHandlerFunc(c.gateway, op, data)
242242
}
243243
}
244244

voice/gateway.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ const (
4646

4747
type (
4848
// EventHandlerFunc is a function that handles a voice gateway event.
49-
EventHandlerFunc func(opCode Opcode, data GatewayMessageData)
49+
EventHandlerFunc func(gateway Gateway, opCode Opcode, data GatewayMessageData)
5050

5151
// CloseHandlerFunc is a function that handles a voice gateway close.
5252
CloseHandlerFunc func(gateway Gateway, err error, reconnect bool)
@@ -313,7 +313,7 @@ loop:
313313
}
314314
g.lastHeartbeatReceived = time.Now().UTC()
315315
}
316-
g.eventHandlerFunc(message.Op, message.D)
316+
g.eventHandlerFunc(g, message.Op, message.D)
317317
}
318318
}
319319

0 commit comments

Comments
 (0)