Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _examples/proxy/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func main() {
client, err := disgo.New(token,
bot.WithShardManagerConfigOpts(
sharding.WithGatewayConfigOpts( // gateway intents are set in the proxy not here
gateway.WithURL(gatewayURL), // set the custom gateway url
gateway.WithCompress(false), // we don't want compression as that would be additional overhead
gateway.WithURL(gatewayURL), // set the custom gateway url
gateway.WithCompression(gateway.NoCompression), // we don't want compression as that would be additional overhead
),
sharding.WithIdentifyRateLimiter(gateway.NewNoopIdentifyRateLimiter()), // disable sharding rate limiter as the proxy handles it
),
Expand Down
2 changes: 1 addition & 1 deletion _examples/sharding/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
sharding.WithAutoScaling(true),
sharding.WithGatewayConfigOpts(
gateway.WithIntents(gateway.IntentGuilds, gateway.IntentGuildMessages, gateway.IntentDirectMessages),
gateway.WithCompress(true),
gateway.WithCompression(gateway.ZstdStreamCompression),
),
),
bot.WithEventListeners(&events.ListenerAdapter{
Expand Down
134 changes: 64 additions & 70 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"bytes"
"compress/zlib"
"context"
"errors"
"fmt"
Expand All @@ -14,10 +13,9 @@ import (
"syscall"
"time"

"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/json/v2"
"github.com/gorilla/websocket"

"github.com/disgoorg/disgo/discord"
)

// Version defines which discord API version disgo should use to connect to discord.
Expand Down Expand Up @@ -171,8 +169,8 @@ type gatewayImpl struct {
closeHandlerFunc CloseHandlerFunc
token string

conn *websocket.Conn
connMu sync.Mutex
transport transport
transportMu sync.Mutex
heartbeatCancel context.CancelFunc
status Status
statusMu sync.Mutex
Expand Down Expand Up @@ -213,9 +211,9 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
func (g *gatewayImpl) open(ctx context.Context) error {
g.config.Logger.DebugContext(ctx, "opening gateway connection")

g.connMu.Lock()
if g.conn != nil {
g.connMu.Unlock()
g.transportMu.Lock()
if g.transport != nil {
g.transportMu.Unlock()
return discord.ErrGatewayAlreadyConnected
}
g.statusMu.Lock()
Expand All @@ -225,7 +223,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
if g.config.LastSequenceReceived == nil || g.config.SessionID == nil {
if err := g.config.IdentifyRateLimiter.Wait(ctx, g.config.ShardID); err != nil {
g.config.Logger.ErrorContext(ctx, "failed to wait for identify rate limiter", slog.Any("err", err))
g.connMu.Unlock()
g.transportMu.Unlock()
return fmt.Errorf("failed to wait for identify rate limiter: %w", err)
}
defer g.config.IdentifyRateLimiter.Unlock(g.config.ShardID)
Expand All @@ -236,6 +234,14 @@ func (g *gatewayImpl) open(ctx context.Context) error {
wsURL = *g.config.ResumeURL
}
gatewayURL := fmt.Sprintf("%s?v=%d&encoding=json", wsURL, Version)

switch g.config.Compression {
case ZlibStreamCompression:
gatewayURL += "&compress=zlib-stream"
case ZstdStreamCompression:
gatewayURL += "&compress=zstd-stream"
}

g.lastHeartbeatSent = time.Now().UTC()
conn, rs, err := g.config.Dialer.DialContext(ctx, gatewayURL, nil)
if err != nil {
Expand All @@ -252,16 +258,33 @@ func (g *gatewayImpl) open(ctx context.Context) error {
}

g.config.Logger.ErrorContext(ctx, "error connecting to the gateway", slog.Any("err", err), slog.String("url", gatewayURL), slog.String("body", body))
g.connMu.Unlock()
g.transportMu.Unlock()
return err
}

conn.SetCloseHandler(func(code int, text string) error {
return nil
})

g.conn = conn
g.connMu.Unlock()
var t transport
switch g.config.Compression {
case ZstdStreamCompression:
g.config.Logger.Debug("using zstd stream compression")
t = newZstdStreamTransport(conn, g.config.Logger)
case ZlibStreamCompression:
g.config.Logger.Debug("using zlib stream compression")
t = newZlibStreamTransport(conn, g.config.Logger)
default:
// zlibPayloadTransport supports both compressed (using zlib)
// and uncompressed payloads
//
// The identify payload will state whether (some) payloads
// will be compressed or not
g.config.Logger.Debug("using no stream compression")
t = newZlibPayloadTransport(conn, g.config.Logger)
}
g.transport = t
g.transportMu.Unlock()

// reset rate limiter when connecting
g.config.RateLimiter.Reset()
Expand All @@ -272,7 +295,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {

var readyOnce sync.Once
readyChan := make(chan error)
go g.listen(conn, func(err error) {
go g.listen(t, func(err error) {
readyOnce.Do(func() {
readyChan <- err
close(readyChan)
Expand Down Expand Up @@ -307,16 +330,16 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
g.heartbeatCancel()
}

g.connMu.Lock()
defer g.connMu.Unlock()
if g.conn != nil {
g.transportMu.Lock()
defer g.transportMu.Unlock()
if g.transport != nil {
g.config.RateLimiter.Close(ctx)
g.config.Logger.DebugContext(ctx, "closing gateway connection", slog.Int("code", code), slog.String("message", message))
if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
if err := g.transport.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
g.config.Logger.DebugContext(ctx, "error writing close code", slog.Any("err", err))
}
_ = g.conn.Close()
g.conn = nil
_ = g.transport.Close()
g.transport = nil

// clear resume data as we closed gracefully
if code == websocket.CloseNormalClosure || code == websocket.CloseGoingAway {
Expand Down Expand Up @@ -358,9 +381,9 @@ func (g *gatewayImpl) sendInternal(ctx context.Context, op Opcode, d MessageData
}

func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) error {
g.connMu.Lock()
defer g.connMu.Unlock()
if g.conn == nil {
g.transportMu.Lock()
defer g.transportMu.Unlock()
if g.transport == nil {
return discord.ErrShardNotConnected
}

Expand All @@ -370,7 +393,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er

defer g.config.RateLimiter.Unlock()
g.config.Logger.DebugContext(ctx, "sending gateway command", slog.String("data", string(data)))
return g.conn.WriteMessage(messageType, data)
return g.transport.WriteMessage(messageType, data)
}

func (g *gatewayImpl) Latency() time.Duration {
Expand Down Expand Up @@ -511,7 +534,7 @@ func (g *gatewayImpl) identify() error {
Browser: g.config.Browser,
Device: g.config.Device,
},
Compress: g.config.Compress,
Compress: g.config.Compression == ZlibPayloadCompression,
LargeThreshold: g.config.LargeThreshold,
Intents: g.config.Intents,
Presence: g.config.Presence,
Expand Down Expand Up @@ -549,25 +572,25 @@ func (g *gatewayImpl) resume() error {
return nil
}

func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
func (g *gatewayImpl) listen(transport transport, ready func(error)) {
defer g.config.Logger.Debug("exiting listen goroutine")

// Ensure that we never leave this function without calling ready
defer ready(nil)

for {
mt, r, err := conn.NextReader()
if err != nil {
message, connErr, parseErr := transport.ReceiveMessage()
if connErr != nil {
g.statusMu.Lock()
if g.status != StatusReady {
g.statusMu.Unlock()
ready(err)
ready(connErr)
return
}
g.statusMu.Unlock()
g.connMu.Lock()
sameConn := g.conn == conn
g.connMu.Unlock()
g.transportMu.Lock()
sameConn := g.transport == transport
g.transportMu.Unlock()

// if sameConnection is false, it means the connection has been closed by the user, and we can just exit
if !sameConn {
Expand All @@ -576,7 +599,7 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {

reconnect := true
var closeError *websocket.CloseError
if errors.As(err, &closeError) {
if errors.As(connErr, &closeError) {
closeCode := CloseEventCodeByCode(closeError.Code)
reconnect = closeCode.Reconnect

Expand All @@ -596,11 +619,11 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
} else {
g.config.Logger.Error(msg, args...)
}
} else if errors.Is(err, net.ErrClosed) {
} else if errors.Is(connErr, net.ErrClosed) {
// we closed the connection ourselves. Don't try to reconnect here
reconnect = false
} else {
g.config.Logger.Warn("failed to read next message from gateway", slog.Any("err", err))
g.config.Logger.Warn("failed to read next message from gateway", slog.Any("err", connErr))
}

// make sure the connection is properly closed
Expand All @@ -610,15 +633,13 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
if g.config.AutoReconnect && reconnect {
go g.reconnect()
} else if g.closeHandlerFunc != nil {
go g.closeHandlerFunc(g, err, reconnect)
go g.closeHandlerFunc(g, connErr, reconnect)
}

return
}

message, err := g.parseMessage(mt, r)
if err != nil {
g.config.Logger.Error("error while parsing gateway message", slog.Any("err", err))
if parseErr != nil {
g.config.Logger.Error("error while parsing gateway message", slog.Any("err", parseErr))
continue
}

Expand All @@ -629,12 +650,12 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
go g.heartbeat()

if g.config.LastSequenceReceived == nil || g.config.SessionID == nil {
err = g.identify()
connErr = g.identify()
} else {
err = g.resume()
connErr = g.resume()
}
if err != nil {
ready(err)
if connErr != nil {
ready(connErr)
return
}

Expand Down Expand Up @@ -741,30 +762,3 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
}
}
}

func (g *gatewayImpl) parseMessage(mt int, r io.Reader) (Message, error) {
if mt == websocket.BinaryMessage {
g.config.Logger.Debug("binary message received. decompressing")

reader, err := zlib.NewReader(r)
if err != nil {
return Message{}, fmt.Errorf("failed to decompress zlib: %w", err)
}
defer reader.Close()
r = reader
}

if g.config.Logger.Enabled(context.Background(), slog.LevelDebug) {
buff := new(bytes.Buffer)
tr := io.TeeReader(r, buff)
data, err := io.ReadAll(tr)
if err != nil {
return Message{}, fmt.Errorf("failed to read message: %w", err)
}
g.config.Logger.Debug("received gateway message", slog.String("data", string(data)))
r = buff
}

var message Message
return message, json.NewDecoder(r).Decode(&message)
}
23 changes: 23 additions & 0 deletions gateway/gateway_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func defaultConfig() config {
LargeThreshold: 50,
Intents: IntentsDefault,
Compress: true,
Compression: ZlibStreamCompression,
URL: "wss://gateway.discord.gg",
ShardID: 0,
ShardCount: 1,
Expand All @@ -33,7 +34,10 @@ type config struct {
// Intents is the Intents for the Gateway. Defaults to IntentsNone.
Intents Intents
// Compress is whether the Gateway should compress payloads. Defaults to true.
// Deprecated: Use Compression instead
Compress bool
// Compression is the compression type to use for the gateway. Defaults to ZstdCompression.
Compression CompressionType
// URL is the URL of the Gateway. Defaults to fetch from Discord.
URL string
// ShardID is the shardID of the Gateway. Defaults to 0.
Expand Down Expand Up @@ -118,12 +122,31 @@ func WithIntents(intents ...Intents) ConfigOpt {

// WithCompress sets whether this Gateway supports compression.
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
// Deprecated: Use WithCompression instead
func WithCompress(compress bool) ConfigOpt {
return func(config *config) {
if compress {
config.Compression = ZlibPayloadCompression
} else {
config.Compression = NoCompression
}

// Set the deprecated field too
config.Compress = compress
}
}

// WithCompression sets the compression mechanism to use.
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
func WithCompression(compression CompressionType) ConfigOpt {
return func(config *config) {
config.Compression = compression

// Set the deprecated field too
config.Compress = compression == ZlibPayloadCompression
}
}

// WithURL sets the Gateway URL for the Gateway.
func WithURL(url string) ConfigOpt {
return func(config *config) {
Expand Down
Loading