Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _examples/proxy/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func main() {
client, err := disgo.New(token,
bot.WithShardManagerConfigOpts(
sharding.WithGatewayConfigOpts( // gateway intents are set in the proxy not here
gateway.WithURL(gatewayURL), // set the custom gateway url
gateway.WithCompress(false), // we don't want compression as that would be additional overhead
gateway.WithURL(gatewayURL), // set the custom gateway url
gateway.WithCompression(gateway.CompressionNone), // we don't want compression as that would be additional overhead
),
sharding.WithIdentifyRateLimiter(gateway.NewNoopIdentifyRateLimiter()), // disable sharding rate limiter as the proxy handles it
),
Expand Down
2 changes: 1 addition & 1 deletion _examples/sharding/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
sharding.WithAutoScaling(true),
sharding.WithGatewayConfigOpts(
gateway.WithIntents(gateway.IntentGuilds, gateway.IntentGuildMessages, gateway.IntentDirectMessages),
gateway.WithCompress(true),
gateway.WithCompression(gateway.CompressionZstdStream),
),
),
bot.WithEventListeners(&events.ListenerAdapter{
Expand Down
77 changes: 26 additions & 51 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ package gateway

import (
"bytes"
"compress/zlib"
"context"
"errors"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net"
"net/url"
"strconv"
"sync"
"syscall"
"time"

"github.com/disgoorg/json/v2"
"github.com/gorilla/websocket"

"github.com/disgoorg/disgo/discord"
Expand Down Expand Up @@ -171,7 +171,7 @@ type gatewayImpl struct {
closeHandlerFunc CloseHandlerFunc
token string

conn *websocket.Conn
conn transport
connMu sync.Mutex
heartbeatCancel context.CancelFunc
status Status
Expand Down Expand Up @@ -211,7 +211,7 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
}

func (g *gatewayImpl) open(ctx context.Context) error {
g.config.Logger.DebugContext(ctx, "opening gateway connection")
g.config.Logger.DebugContext(ctx, "opening gateway connection", slog.String("compression", g.config.Compression.String()))

g.connMu.Lock()
if g.conn != nil {
Expand All @@ -235,7 +235,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
if g.config.ResumeURL != nil && g.config.EnableResumeURL {
wsURL = *g.config.ResumeURL
}
gatewayURL := fmt.Sprintf("%s?v=%d&encoding=json", wsURL, Version)

values := url.Values{}
values.Set("v", strconv.Itoa(Version))
values.Set("encoding", "json")

if g.config.Compression.IsStreamCompression() {
values.Set("compress", string(g.config.Compression))
}

gatewayURL := wsURL + "?" + values.Encode()

g.lastHeartbeatSent = time.Now().UTC()
conn, rs, err := g.config.Dialer.DialContext(ctx, gatewayURL, nil)
if err != nil {
Expand All @@ -260,7 +270,8 @@ func (g *gatewayImpl) open(ctx context.Context) error {
return nil
})

g.conn = conn
t := newTransport(g.config.Compression, conn, g.config.Logger)
g.conn = t
g.connMu.Unlock()

// reset rate limiter when connecting
Expand All @@ -272,7 +283,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {

var readyOnce sync.Once
readyChan := make(chan error)
go g.listen(conn, func(err error) {
go g.listen(t, func(err error) {
readyOnce.Do(func() {
readyChan <- err
close(readyChan)
Expand Down Expand Up @@ -312,7 +323,7 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
if g.conn != nil {
g.config.RateLimiter.Close(ctx)
g.config.Logger.DebugContext(ctx, "closing gateway connection", slog.Int("code", code), slog.String("message", message))
if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
if err := g.conn.WriteClose(code, message); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
g.config.Logger.DebugContext(ctx, "error writing close code", slog.Any("err", err))
}
_ = g.conn.Close()
Expand Down Expand Up @@ -347,17 +358,11 @@ func (g *gatewayImpl) Send(ctx context.Context, op Opcode, d MessageData) error
}

func (g *gatewayImpl) sendInternal(ctx context.Context, op Opcode, d MessageData) error {
data, err := json.Marshal(Message{
data := Message{
Op: op,
D: d,
})
if err != nil {
return err
}
return g.send(ctx, websocket.TextMessage, data)
}

func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) error {
g.connMu.Lock()
defer g.connMu.Unlock()
if g.conn == nil {
Expand All @@ -369,8 +374,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er
}

defer g.config.RateLimiter.Unlock()
g.config.Logger.DebugContext(ctx, "sending gateway command", slog.String("data", string(data)))
return g.conn.WriteMessage(messageType, data)
return g.conn.WriteMessage(data)
}

func (g *gatewayImpl) Latency() time.Duration {
Expand Down Expand Up @@ -511,7 +515,7 @@ func (g *gatewayImpl) identify() error {
Browser: g.config.Browser,
Device: g.config.Device,
},
Compress: g.config.Compress,
Compress: g.config.Compression.IsPayloadCompression(),
LargeThreshold: g.config.LargeThreshold,
Intents: g.config.Intents,
Presence: g.config.Presence,
Expand Down Expand Up @@ -549,14 +553,14 @@ func (g *gatewayImpl) resume() error {
return nil
}

func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
func (g *gatewayImpl) listen(conn transport, ready func(error)) {
defer g.config.Logger.Debug("exiting listen goroutine")

// Ensure that we never leave this function without calling ready
defer ready(nil)

for {
mt, r, err := conn.NextReader()
message, err := conn.ReceiveMessage()
if err != nil {
g.statusMu.Lock()
if g.status != StatusReady {
Expand Down Expand Up @@ -615,10 +619,8 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {

return
}

message, err := g.parseMessage(mt, r)
if err != nil {
g.config.Logger.Error("error while parsing gateway message", slog.Any("err", err))
if message == nil {
// No message (probably parsing error), just continue as the transport already logged it
continue
}

Expand Down Expand Up @@ -741,30 +743,3 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
}
}
}

func (g *gatewayImpl) parseMessage(mt int, r io.Reader) (Message, error) {
if mt == websocket.BinaryMessage {
g.config.Logger.Debug("binary message received. decompressing")

reader, err := zlib.NewReader(r)
if err != nil {
return Message{}, fmt.Errorf("failed to decompress zlib: %w", err)
}
defer reader.Close()
r = reader
}

if g.config.Logger.Enabled(context.Background(), slog.LevelDebug) {
buff := new(bytes.Buffer)
tr := io.TeeReader(r, buff)
data, err := io.ReadAll(tr)
if err != nil {
return Message{}, fmt.Errorf("failed to read message: %w", err)
}
g.config.Logger.Debug("received gateway message", slog.String("data", string(data)))
r = buff
}

var message Message
return message, json.NewDecoder(r).Decode(&message)
}
23 changes: 23 additions & 0 deletions gateway/gateway_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func defaultConfig() config {
LargeThreshold: 50,
Intents: IntentsDefault,
Compress: true,
Compression: CompressionZstdStream,
URL: "wss://gateway.discord.gg",
ShardID: 0,
ShardCount: 1,
Expand All @@ -33,7 +34,10 @@ type config struct {
// Intents is the Intents for the Gateway. Defaults to IntentsNone.
Intents Intents
// Compress is whether the Gateway should compress payloads. Defaults to true.
// Deprecated: Use Compression instead
Compress bool
// Compression is the compression type to use for the gateway. Defaults to ZstdCompression.
Compression CompressionType
// URL is the URL of the Gateway. Defaults to fetch from Discord.
URL string
// ShardID is the shardID of the Gateway. Defaults to 0.
Expand Down Expand Up @@ -118,12 +122,31 @@ func WithIntents(intents ...Intents) ConfigOpt {

// WithCompress sets whether this Gateway supports compression.
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
// Deprecated: Use WithCompression instead
func WithCompress(compress bool) ConfigOpt {
return func(config *config) {
if compress {
config.Compression = CompressionZlibPayload
} else {
config.Compression = CompressionNone
}

// Set the deprecated field too
config.Compress = compress
}
}

// WithCompression sets the compression mechanism to use.
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
func WithCompression(compression CompressionType) ConfigOpt {
return func(config *config) {
config.Compression = compression

// Set the deprecated field too
config.Compress = compression == CompressionZlibPayload
}
}

// WithURL sets the Gateway URL for the Gateway.
func WithURL(url string) ConfigOpt {
return func(config *config) {
Expand Down
Loading