Skip to content

Commit 88766c7

Browse files
committed
chore: pull request comments
1 parent ab27fe5 commit 88766c7

File tree

5 files changed

+78
-62
lines changed

5 files changed

+78
-62
lines changed

_examples/proxy/example.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func main() {
5151
client, err := disgo.New(token,
5252
bot.WithShardManagerConfigOpts(
5353
sharding.WithGatewayConfigOpts( // gateway intents are set in the proxy not here
54-
gateway.WithURL(gatewayURL), // set the custom gateway url
55-
gateway.WithCompression(gateway.NoCompression), // we don't want compression as that would be additional overhead
54+
gateway.WithURL(gatewayURL), // set the custom gateway url
55+
gateway.WithCompression(gateway.CompressionNone), // we don't want compression as that would be additional overhead
5656
),
5757
sharding.WithIdentifyRateLimiter(gateway.NewNoopIdentifyRateLimiter()), // disable sharding rate limiter as the proxy handles it
5858
),

_examples/sharding/example.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func main() {
3030
sharding.WithAutoScaling(true),
3131
sharding.WithGatewayConfigOpts(
3232
gateway.WithIntents(gateway.IntentGuilds, gateway.IntentGuildMessages, gateway.IntentDirectMessages),
33-
gateway.WithCompression(gateway.ZstdStreamCompression),
33+
gateway.WithCompression(gateway.CompressionZstdStream),
3434
),
3535
),
3636
bot.WithEventListeners(&events.ListenerAdapter{

gateway/gateway.go

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ import (
99
"log/slog"
1010
"math/rand/v2"
1111
"net"
12+
"net/url"
13+
"strconv"
1214
"sync"
1315
"syscall"
1416
"time"
1517

16-
"github.com/disgoorg/disgo/discord"
1718
"github.com/disgoorg/json/v2"
1819
"github.com/gorilla/websocket"
20+
21+
"github.com/disgoorg/disgo/discord"
1922
)
2023

2124
// Version defines which discord API version disgo should use to connect to discord.
@@ -169,8 +172,8 @@ type gatewayImpl struct {
169172
closeHandlerFunc CloseHandlerFunc
170173
token string
171174

172-
transport transport
173-
transportMu sync.Mutex
175+
conn transport
176+
connMux sync.Mutex
174177
heartbeatCancel context.CancelFunc
175178
status Status
176179
statusMu sync.Mutex
@@ -211,9 +214,9 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
211214
func (g *gatewayImpl) open(ctx context.Context) error {
212215
g.config.Logger.DebugContext(ctx, "opening gateway connection")
213216

214-
g.transportMu.Lock()
215-
if g.transport != nil {
216-
g.transportMu.Unlock()
217+
g.connMux.Lock()
218+
if g.conn != nil {
219+
g.connMux.Unlock()
217220
return discord.ErrGatewayAlreadyConnected
218221
}
219222
g.statusMu.Lock()
@@ -223,7 +226,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
223226
if g.config.LastSequenceReceived == nil || g.config.SessionID == nil {
224227
if err := g.config.IdentifyRateLimiter.Wait(ctx, g.config.ShardID); err != nil {
225228
g.config.Logger.ErrorContext(ctx, "failed to wait for identify rate limiter", slog.Any("err", err))
226-
g.transportMu.Unlock()
229+
g.connMux.Unlock()
227230
return fmt.Errorf("failed to wait for identify rate limiter: %w", err)
228231
}
229232
defer g.config.IdentifyRateLimiter.Unlock(g.config.ShardID)
@@ -233,15 +236,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
233236
if g.config.ResumeURL != nil && g.config.EnableResumeURL {
234237
wsURL = *g.config.ResumeURL
235238
}
236-
gatewayURL := fmt.Sprintf("%s?v=%d&encoding=json", wsURL, Version)
237239

238-
switch g.config.Compression {
239-
case ZlibStreamCompression:
240-
gatewayURL += "&compress=zlib-stream"
241-
case ZstdStreamCompression:
242-
gatewayURL += "&compress=zstd-stream"
240+
values := url.Values{}
241+
values.Set("v", strconv.Itoa(Version))
242+
values.Set("encoding", "json")
243+
244+
if g.config.Compression.isStreamCompression() {
245+
values.Set("compress", string(g.config.Compression))
243246
}
244247

248+
gatewayURL := wsURL + "?" + values.Encode()
249+
245250
g.lastHeartbeatSent = time.Now().UTC()
246251
conn, rs, err := g.config.Dialer.DialContext(ctx, gatewayURL, nil)
247252
if err != nil {
@@ -258,33 +263,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
258263
}
259264

260265
g.config.Logger.ErrorContext(ctx, "error connecting to the gateway", slog.Any("err", err), slog.String("url", gatewayURL), slog.String("body", body))
261-
g.transportMu.Unlock()
266+
g.connMux.Unlock()
262267
return err
263268
}
264269

265270
conn.SetCloseHandler(func(code int, text string) error {
266271
return nil
267272
})
268273

269-
var t transport
270-
switch g.config.Compression {
271-
case ZstdStreamCompression:
272-
g.config.Logger.Debug("using zstd stream compression")
273-
t = newZstdStreamTransport(conn, g.config.Logger)
274-
case ZlibStreamCompression:
275-
g.config.Logger.Debug("using zlib stream compression")
276-
t = newZlibStreamTransport(conn, g.config.Logger)
277-
default:
278-
// zlibPayloadTransport supports both compressed (using zlib)
279-
// and uncompressed payloads
280-
//
281-
// The identify payload will state whether (some) payloads
282-
// will be compressed or not
283-
g.config.Logger.Debug("using no stream compression")
284-
t = newZlibPayloadTransport(conn, g.config.Logger)
285-
}
286-
g.transport = t
287-
g.transportMu.Unlock()
274+
g.config.Logger.DebugContext(ctx, "using compression", slog.String("compressionType", string(g.config.Compression)))
275+
g.conn = g.config.Compression.newTransport(conn, g.config.Logger)
276+
g.connMux.Unlock()
288277

289278
// reset rate limiter when connecting
290279
g.config.RateLimiter.Reset()
@@ -295,7 +284,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
295284

296285
var readyOnce sync.Once
297286
readyChan := make(chan error)
298-
go g.listen(t, func(err error) {
287+
go g.listen(g.conn, func(err error) {
299288
readyOnce.Do(func() {
300289
readyChan <- err
301290
close(readyChan)
@@ -330,16 +319,16 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
330319
g.heartbeatCancel()
331320
}
332321

333-
g.transportMu.Lock()
334-
defer g.transportMu.Unlock()
335-
if g.transport != nil {
322+
g.connMux.Lock()
323+
defer g.connMux.Unlock()
324+
if g.conn != nil {
336325
g.config.RateLimiter.Close(ctx)
337326
g.config.Logger.DebugContext(ctx, "closing gateway connection", slog.Int("code", code), slog.String("message", message))
338-
if err := g.transport.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
327+
if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
339328
g.config.Logger.DebugContext(ctx, "error writing close code", slog.Any("err", err))
340329
}
341-
_ = g.transport.Close()
342-
g.transport = nil
330+
_ = g.conn.Close()
331+
g.conn = nil
343332

344333
// clear resume data as we closed gracefully
345334
if code == websocket.CloseNormalClosure || code == websocket.CloseGoingAway {
@@ -381,9 +370,9 @@ func (g *gatewayImpl) sendInternal(ctx context.Context, op Opcode, d MessageData
381370
}
382371

383372
func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) error {
384-
g.transportMu.Lock()
385-
defer g.transportMu.Unlock()
386-
if g.transport == nil {
373+
g.connMux.Lock()
374+
defer g.connMux.Unlock()
375+
if g.conn == nil {
387376
return discord.ErrShardNotConnected
388377
}
389378

@@ -393,7 +382,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er
393382

394383
defer g.config.RateLimiter.Unlock()
395384
g.config.Logger.DebugContext(ctx, "sending gateway command", slog.String("data", string(data)))
396-
return g.transport.WriteMessage(messageType, data)
385+
return g.conn.WriteMessage(messageType, data)
397386
}
398387

399388
func (g *gatewayImpl) Latency() time.Duration {
@@ -534,7 +523,7 @@ func (g *gatewayImpl) identify() error {
534523
Browser: g.config.Browser,
535524
Device: g.config.Device,
536525
},
537-
Compress: g.config.Compression == ZlibPayloadCompression,
526+
Compress: g.config.Compression.isPayloadCompression(),
538527
LargeThreshold: g.config.LargeThreshold,
539528
Intents: g.config.Intents,
540529
Presence: g.config.Presence,
@@ -588,9 +577,9 @@ func (g *gatewayImpl) listen(transport transport, ready func(error)) {
588577
return
589578
}
590579
g.statusMu.Unlock()
591-
g.transportMu.Lock()
592-
sameConn := g.transport == transport
593-
g.transportMu.Unlock()
580+
g.connMux.Lock()
581+
sameConn := g.conn == transport
582+
g.connMux.Unlock()
594583

595584
// if sameConnection is false, it means the connection has been closed by the user, and we can just exit
596585
if !sameConn {

gateway/gateway_config.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func defaultConfig() config {
1313
LargeThreshold: 50,
1414
Intents: IntentsDefault,
1515
Compress: true,
16-
Compression: ZlibStreamCompression,
16+
Compression: CompressionZstdStream,
1717
URL: "wss://gateway.discord.gg",
1818
ShardID: 0,
1919
ShardCount: 1,
@@ -126,9 +126,9 @@ func WithIntents(intents ...Intents) ConfigOpt {
126126
func WithCompress(compress bool) ConfigOpt {
127127
return func(config *config) {
128128
if compress {
129-
config.Compression = ZlibPayloadCompression
129+
config.Compression = CompressionZlibPayload
130130
} else {
131-
config.Compression = NoCompression
131+
config.Compression = CompressionNone
132132
}
133133

134134
// Set the deprecated field too
@@ -143,7 +143,7 @@ func WithCompression(compression CompressionType) ConfigOpt {
143143
config.Compression = compression
144144

145145
// Set the deprecated field too
146-
config.Compress = compression == ZlibPayloadCompression
146+
config.Compress = compression == CompressionZlibPayload
147147
}
148148
}
149149

gateway/gateway_transports.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,39 @@ import (
1313
"github.com/klauspost/compress/zstd"
1414
)
1515

16-
type CompressionType int
16+
type CompressionType string
1717

1818
const (
19-
NoCompression CompressionType = iota
20-
ZlibPayloadCompression
21-
ZlibStreamCompression
22-
ZstdStreamCompression
19+
CompressionNone CompressionType = "none"
20+
CompressionZlibPayload = "zlib-payload"
21+
CompressionZlibStream = "zlib-stream"
22+
CompressionZstdStream = "zstd-stream"
2323
)
2424

25+
func (t CompressionType) isStreamCompression() bool {
26+
return t == CompressionZstdStream || t == CompressionZlibStream
27+
}
28+
29+
func (t CompressionType) isPayloadCompression() bool {
30+
return t == CompressionZlibPayload
31+
}
32+
33+
func (t CompressionType) newTransport(conn *websocket.Conn, logger *slog.Logger) transport {
34+
switch t {
35+
case CompressionZlibStream:
36+
return newZlibStreamTransport(conn, logger)
37+
case CompressionZstdStream:
38+
return newZstdStreamTransport(conn, logger)
39+
default:
40+
// zlibPayloadTransport supports both compressed (using zlib)
41+
// and uncompressed payloads
42+
//
43+
// The identify payload will state whether (some) payloads
44+
// will be compressed or not
45+
return newZlibPayloadTransport(conn, logger)
46+
}
47+
}
48+
2549
var syncFlush = []byte{0x00, 0x00, 0xff, 0xff}
2650

2751
// [zstdStreamTransport]: for connections using zstd-stream compression
@@ -153,10 +177,13 @@ func (t *zlibStreamTransport) ReceiveMessage() (Message, error, error) {
153177
}
154178
}
155179

156-
var err error
157-
158180
if t.inflator == nil {
181+
var err error
182+
159183
t.inflator, err = zlib.NewReader(t.buffer)
184+
if err != nil {
185+
return Message{}, err, nil
186+
}
160187
}
161188

162189
message, err := parseMessage(t.inflator, t.logger)

0 commit comments

Comments
 (0)