@@ -9,13 +9,16 @@ import (
9
9
"log/slog"
10
10
"math/rand/v2"
11
11
"net"
12
+ "net/url"
13
+ "strconv"
12
14
"sync"
13
15
"syscall"
14
16
"time"
15
17
16
- "github.com/disgoorg/disgo/discord"
17
18
"github.com/disgoorg/json/v2"
18
19
"github.com/gorilla/websocket"
20
+
21
+ "github.com/disgoorg/disgo/discord"
19
22
)
20
23
21
24
// Version defines which discord API version disgo should use to connect to discord.
@@ -169,8 +172,8 @@ type gatewayImpl struct {
169
172
closeHandlerFunc CloseHandlerFunc
170
173
token string
171
174
172
- transport transport
173
- transportMu sync.Mutex
175
+ conn transport
176
+ connMux sync.Mutex
174
177
heartbeatCancel context.CancelFunc
175
178
status Status
176
179
statusMu sync.Mutex
@@ -211,9 +214,9 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
211
214
func (g * gatewayImpl ) open (ctx context.Context ) error {
212
215
g .config .Logger .DebugContext (ctx , "opening gateway connection" )
213
216
214
- g .transportMu .Lock ()
215
- if g .transport != nil {
216
- g .transportMu .Unlock ()
217
+ g .connMux .Lock ()
218
+ if g .conn != nil {
219
+ g .connMux .Unlock ()
217
220
return discord .ErrGatewayAlreadyConnected
218
221
}
219
222
g .statusMu .Lock ()
@@ -223,7 +226,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
223
226
if g .config .LastSequenceReceived == nil || g .config .SessionID == nil {
224
227
if err := g .config .IdentifyRateLimiter .Wait (ctx , g .config .ShardID ); err != nil {
225
228
g .config .Logger .ErrorContext (ctx , "failed to wait for identify rate limiter" , slog .Any ("err" , err ))
226
- g .transportMu .Unlock ()
229
+ g .connMux .Unlock ()
227
230
return fmt .Errorf ("failed to wait for identify rate limiter: %w" , err )
228
231
}
229
232
defer g .config .IdentifyRateLimiter .Unlock (g .config .ShardID )
@@ -233,15 +236,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
233
236
if g .config .ResumeURL != nil && g .config .EnableResumeURL {
234
237
wsURL = * g .config .ResumeURL
235
238
}
236
- gatewayURL := fmt .Sprintf ("%s?v=%d&encoding=json" , wsURL , Version )
237
239
238
- switch g .config .Compression {
239
- case ZlibStreamCompression :
240
- gatewayURL += "&compress=zlib-stream"
241
- case ZstdStreamCompression :
242
- gatewayURL += "&compress=zstd-stream"
240
+ values := url.Values {}
241
+ values .Set ("v" , strconv .Itoa (Version ))
242
+ values .Set ("encoding" , "json" )
243
+
244
+ if g .config .Compression .isStreamCompression () {
245
+ values .Set ("compress" , string (g .config .Compression ))
243
246
}
244
247
248
+ gatewayURL := wsURL + "?" + values .Encode ()
249
+
245
250
g .lastHeartbeatSent = time .Now ().UTC ()
246
251
conn , rs , err := g .config .Dialer .DialContext (ctx , gatewayURL , nil )
247
252
if err != nil {
@@ -258,33 +263,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
258
263
}
259
264
260
265
g .config .Logger .ErrorContext (ctx , "error connecting to the gateway" , slog .Any ("err" , err ), slog .String ("url" , gatewayURL ), slog .String ("body" , body ))
261
- g .transportMu .Unlock ()
266
+ g .connMux .Unlock ()
262
267
return err
263
268
}
264
269
265
270
conn .SetCloseHandler (func (code int , text string ) error {
266
271
return nil
267
272
})
268
273
269
- var t transport
270
- switch g .config .Compression {
271
- case ZstdStreamCompression :
272
- g .config .Logger .Debug ("using zstd stream compression" )
273
- t = newZstdStreamTransport (conn , g .config .Logger )
274
- case ZlibStreamCompression :
275
- g .config .Logger .Debug ("using zlib stream compression" )
276
- t = newZlibStreamTransport (conn , g .config .Logger )
277
- default :
278
- // zlibPayloadTransport supports both compressed (using zlib)
279
- // and uncompressed payloads
280
- //
281
- // The identify payload will state whether (some) payloads
282
- // will be compressed or not
283
- g .config .Logger .Debug ("using no stream compression" )
284
- t = newZlibPayloadTransport (conn , g .config .Logger )
285
- }
286
- g .transport = t
287
- g .transportMu .Unlock ()
274
+ g .config .Logger .DebugContext (ctx , "using compression" , slog .String ("compressionType" , string (g .config .Compression )))
275
+ g .conn = g .config .Compression .newTransport (conn , g .config .Logger )
276
+ g .connMux .Unlock ()
288
277
289
278
// reset rate limiter when connecting
290
279
g .config .RateLimiter .Reset ()
@@ -295,7 +284,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
295
284
296
285
var readyOnce sync.Once
297
286
readyChan := make (chan error )
298
- go g .listen (t , func (err error ) {
287
+ go g .listen (g . conn , func (err error ) {
299
288
readyOnce .Do (func () {
300
289
readyChan <- err
301
290
close (readyChan )
@@ -330,16 +319,16 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
330
319
g .heartbeatCancel ()
331
320
}
332
321
333
- g .transportMu .Lock ()
334
- defer g .transportMu .Unlock ()
335
- if g .transport != nil {
322
+ g .connMux .Lock ()
323
+ defer g .connMux .Unlock ()
324
+ if g .conn != nil {
336
325
g .config .RateLimiter .Close (ctx )
337
326
g .config .Logger .DebugContext (ctx , "closing gateway connection" , slog .Int ("code" , code ), slog .String ("message" , message ))
338
- if err := g .transport .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (code , message )); err != nil && ! errors .Is (err , websocket .ErrCloseSent ) {
327
+ if err := g .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (code , message )); err != nil && ! errors .Is (err , websocket .ErrCloseSent ) {
339
328
g .config .Logger .DebugContext (ctx , "error writing close code" , slog .Any ("err" , err ))
340
329
}
341
- _ = g .transport .Close ()
342
- g .transport = nil
330
+ _ = g .conn .Close ()
331
+ g .conn = nil
343
332
344
333
// clear resume data as we closed gracefully
345
334
if code == websocket .CloseNormalClosure || code == websocket .CloseGoingAway {
@@ -381,9 +370,9 @@ func (g *gatewayImpl) sendInternal(ctx context.Context, op Opcode, d MessageData
381
370
}
382
371
383
372
func (g * gatewayImpl ) send (ctx context.Context , messageType int , data []byte ) error {
384
- g .transportMu .Lock ()
385
- defer g .transportMu .Unlock ()
386
- if g .transport == nil {
373
+ g .connMux .Lock ()
374
+ defer g .connMux .Unlock ()
375
+ if g .conn == nil {
387
376
return discord .ErrShardNotConnected
388
377
}
389
378
@@ -393,7 +382,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er
393
382
394
383
defer g .config .RateLimiter .Unlock ()
395
384
g .config .Logger .DebugContext (ctx , "sending gateway command" , slog .String ("data" , string (data )))
396
- return g .transport .WriteMessage (messageType , data )
385
+ return g .conn .WriteMessage (messageType , data )
397
386
}
398
387
399
388
func (g * gatewayImpl ) Latency () time.Duration {
@@ -534,7 +523,7 @@ func (g *gatewayImpl) identify() error {
534
523
Browser : g .config .Browser ,
535
524
Device : g .config .Device ,
536
525
},
537
- Compress : g .config .Compression == ZlibPayloadCompression ,
526
+ Compress : g .config .Compression . isPayloadCompression () ,
538
527
LargeThreshold : g .config .LargeThreshold ,
539
528
Intents : g .config .Intents ,
540
529
Presence : g .config .Presence ,
@@ -588,9 +577,9 @@ func (g *gatewayImpl) listen(transport transport, ready func(error)) {
588
577
return
589
578
}
590
579
g .statusMu .Unlock ()
591
- g .transportMu .Lock ()
592
- sameConn := g .transport == transport
593
- g .transportMu .Unlock ()
580
+ g .connMux .Lock ()
581
+ sameConn := g .conn == transport
582
+ g .connMux .Unlock ()
594
583
595
584
// if sameConnection is false, it means the connection has been closed by the user, and we can just exit
596
585
if ! sameConn {
0 commit comments