6
6
"fmt"
7
7
"iter"
8
8
"log/slog"
9
+ "maps"
10
+ "slices"
9
11
"sync"
10
12
11
13
"github.com/disgoorg/snowflake/v2"
@@ -29,6 +31,9 @@ type ShardManager interface {
29
31
// OpenShard opens a specific shard.
30
32
OpenShard (ctx context.Context , shardID int ) error
31
33
34
+ // ResumeShard resumes a specific shard with the given sessionID and sequence.
35
+ ResumeShard (ctx context.Context , shardID int , sessionID string , sequence int ) error
36
+
32
37
// CloseShard closes a specific shard.
33
38
CloseShard (ctx context.Context , shardID int )
34
39
@@ -81,10 +86,8 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
81
86
shard .Close (context .TODO ())
82
87
83
88
m .shardsMu .Lock ()
84
- defer m .shardsMu .Unlock ()
85
-
86
89
delete (m .shards , shard .ShardID ())
87
- delete ( m . config . ShardIDs , shard . ShardID () )
90
+ defer m . shardsMu . Unlock ( )
88
91
89
92
newShardCount := shard .ShardCount () * m .config .ShardSplitCount
90
93
if newShardCount > m .config .ShardCount {
@@ -99,20 +102,25 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
99
102
}
100
103
101
104
var wg sync.WaitGroup
102
- for i := range newShardIDs {
103
- shardID := newShardIDs [i ]
105
+ for _ , shardID := range newShardIDs {
104
106
wg .Add (1 )
107
+
105
108
go func () {
106
109
defer wg .Done ()
107
- if err := m .config .RateLimiter .WaitBucket (context .TODO (), shardID ); err != nil {
110
+
111
+ if err := m .config .RateLimiter .WaitBucket (context .Background (), shardID ); err != nil {
108
112
m .config .Logger .Error ("failed to wait shard bucket" , slog .Any ("err" , err ), slog .Int ("shard_id" , shardID ))
109
113
return
110
114
}
111
115
defer m .config .RateLimiter .UnlockBucket (shardID )
112
116
113
117
newShard := m .config .GatewayCreateFunc (m .token , m .eventHandlerFunc , m .closeHandler , append (m .config .GatewayConfigOpts , gateway .WithShardID (shardID ), gateway .WithShardCount (newShardCount ))... )
118
+
119
+ m .shardsMu .Lock ()
114
120
m .shards [shardID ] = newShard
115
- if err := newShard .Open (context .TODO ()); err != nil {
121
+ m .shardsMu .Unlock ()
122
+
123
+ if err := newShard .Open (context .Background ()); err != nil {
116
124
m .config .Logger .Error ("failed to re shard" , slog .Any ("err" , err ), slog .Int ("shard_id" , shardID ))
117
125
}
118
126
}()
@@ -122,27 +130,41 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error, _ bool
122
130
}
123
131
124
132
func (m * shardManagerImpl ) Open (ctx context.Context ) {
125
- m .config .Logger .Debug ("opening shards" , slog .String ("shard_ids" , fmt .Sprint (m .config .ShardIDs )))
126
- var wg sync.WaitGroup
133
+ m .config .Logger .Debug ("opening shards" , slog .String ("shard_ids" , fmt .Sprint (slices .Collect (maps .Keys (m .config .ShardIDs )))))
127
134
128
- m .shardsMu .Lock ()
129
- defer m .shardsMu .Unlock ()
130
- for shardID := range m .config .ShardIDs {
131
- if _ , ok := m .shards [shardID ]; ok {
135
+ var wg sync.WaitGroup
136
+ for shardID , shardState := range m .config .ShardIDs {
137
+ m .shardsMu .Lock ()
138
+ _ , ok := m .shards [shardID ]
139
+ m .shardsMu .Unlock ()
140
+ if ok {
132
141
continue
133
142
}
134
143
135
144
wg .Add (1 )
136
145
go func () {
137
146
defer wg .Done ()
147
+
138
148
if err := m .config .RateLimiter .WaitBucket (ctx , shardID ); err != nil {
139
149
m .config .Logger .Error ("failed to wait shard bucket" , slog .Any ("err" , err ), slog .Int ("shard_id" , shardID ))
140
150
return
141
151
}
142
152
defer m .config .RateLimiter .UnlockBucket (shardID )
143
153
144
- shard := m .config .GatewayCreateFunc (m .token , m .eventHandlerFunc , m .closeHandler , append (m .config .GatewayConfigOpts , gateway .WithShardID (shardID ), gateway .WithShardCount (m .config .ShardCount ))... )
154
+ opts := append (m .config .GatewayConfigOpts , gateway .WithShardID (shardID ), gateway .WithShardCount (m .config .ShardCount ))
155
+ if shardState .SessionID != "" {
156
+ opts = append (opts , gateway .WithSessionID (shardState .SessionID ))
157
+ }
158
+ if shardState .Sequence != 0 {
159
+ opts = append (opts , gateway .WithSequence (shardState .Sequence ))
160
+ }
161
+
162
+ shard := m .config .GatewayCreateFunc (m .token , m .eventHandlerFunc , m .closeHandler , opts ... )
163
+
164
+ m .shardsMu .Lock ()
145
165
m .shards [shardID ] = shard
166
+ m .shardsMu .Unlock ()
167
+
146
168
if err := shard .Open (ctx ); err != nil {
147
169
m .config .Logger .Error ("failed to open shard" , slog .Any ("err" , err ), slog .Int ("shard_id" , shardID ))
148
170
}
@@ -152,40 +174,52 @@ func (m *shardManagerImpl) Open(ctx context.Context) {
152
174
}
153
175
154
176
func (m * shardManagerImpl ) Close (ctx context.Context ) {
155
- m .config .Logger .Debug ("closing shards" , slog .String ("shard_ids" , fmt .Sprint (m . config . ShardIDs )))
177
+ m .config .Logger .Debug ("closing shards" , slog .String ("shard_ids" , fmt .Sprint (slices . Collect ( maps . Keys ( m . shards )) )))
156
178
var wg sync.WaitGroup
157
179
158
180
m .shardsMu .Lock ()
159
181
defer m .shardsMu .Unlock ()
160
- for shardID := range m .shards {
161
- shard := m .shards [shardID ]
162
- delete (m .shards , shardID )
182
+ for _ , shard := range m .shards {
163
183
wg .Add (1 )
164
184
go func () {
165
185
defer wg .Done ()
166
186
shard .Close (ctx )
167
187
}()
168
188
}
169
189
wg .Wait ()
190
+ m .shards = map [int ]gateway.Gateway {}
170
191
}
171
192
172
193
func (m * shardManagerImpl ) OpenShard (ctx context.Context , shardID int ) error {
173
- return m .openShard (ctx , shardID , m .config .ShardCount )
194
+ return m .openShard (ctx , shardID , m .config .ShardCount , "" , 0 )
195
+ }
196
+
197
+ func (m * shardManagerImpl ) ResumeShard (ctx context.Context , shardID int , sessionID string , sequence int ) error {
198
+ return m .openShard (ctx , shardID , m .config .ShardCount , sessionID , sequence )
174
199
}
175
200
176
- func (m * shardManagerImpl ) openShard (ctx context.Context , shardID int , shardCount int ) error {
177
- m .config .Logger .Debug ("opening shard" , slog .Int ("shard_id" , shardID ))
201
+ func (m * shardManagerImpl ) openShard (ctx context.Context , shardID int , shardCount int , sessionID string , sequence int ) error {
202
+ m .config .Logger .Debug ("opening shard" , slog .Int ("shard_id" , shardID ), slog . Int ( "shard_count" , shardCount ), slog . String ( "session_id" , sessionID ), slog . Int ( "sequence" , sequence ) )
178
203
179
204
if err := m .config .RateLimiter .WaitBucket (ctx , shardID ); err != nil {
180
205
return err
181
206
}
182
207
defer m .config .RateLimiter .UnlockBucket (shardID )
183
- shard := m .config .GatewayCreateFunc (m .token , m .eventHandlerFunc , m .closeHandler , append (m .config .GatewayConfigOpts , gateway .WithShardID (shardID ), gateway .WithShardCount (shardCount ))... )
208
+
209
+ opts := append (m .config .GatewayConfigOpts , gateway .WithShardID (shardID ), gateway .WithShardCount (shardCount ))
210
+ if sessionID != "" {
211
+ opts = append (opts , gateway .WithSessionID (sessionID ))
212
+ }
213
+ if sequence != 0 {
214
+ opts = append (opts , gateway .WithSequence (sequence ))
215
+ }
216
+
217
+ shard := m .config .GatewayCreateFunc (m .token , m .eventHandlerFunc , m .closeHandler , opts ... )
184
218
185
219
m .shardsMu .Lock ()
186
- defer m .shardsMu .Unlock ()
187
- m .config .ShardIDs [shardID ] = struct {}{}
188
220
m .shards [shardID ] = shard
221
+ defer m .shardsMu .Unlock ()
222
+
189
223
return shard .Open (ctx )
190
224
}
191
225
@@ -220,6 +254,7 @@ func (m *shardManagerImpl) Shards() iter.Seq[gateway.Gateway] {
220
254
return func (yield func (gateway.Gateway ) bool ) {
221
255
m .shardsMu .Lock ()
222
256
defer m .shardsMu .Unlock ()
257
+
223
258
for _ , shard := range m .shards {
224
259
if ! yield (shard ) {
225
260
return
0 commit comments