@@ -3,7 +3,9 @@ package lavalink
33import (
44 "context"
55 "encoding/json"
6+ "errors"
67 "fmt"
8+ "net"
79 "net/http"
810 "sync"
911 "time"
@@ -120,37 +122,55 @@ func (n *nodeImpl) Stats() *Stats {
120122 return n .stats
121123}
122124
123- func (n * nodeImpl ) reconnect () error {
125+ func (n * nodeImpl ) reconnect (ctx context.Context ) {
126+ if err := n .reconnectTry (ctx , 0 , time .Second ); err != nil {
127+ n .lavalink .Logger ().Error ("failed to reconnect to node: " , err )
128+ }
129+ }
130+
131+ func (n * nodeImpl ) reconnectTry (ctx context.Context , try int , delay time.Duration ) error {
124132 n .statusMu .Lock ()
125133 defer n .statusMu .Unlock ()
126-
127134 n .status = Reconnecting
128- if err := n .open (context .TODO (), 0 ); err != nil {
135+
136+ timer := time .NewTimer (time .Duration (try ) * delay )
137+ defer timer .Stop ()
138+ select {
139+ case <- ctx .Done ():
140+ timer .Stop ()
141+ return ctx .Err ()
142+ case <- timer .C :
143+ }
144+
145+ n .lavalink .Logger ().Debug ("reconnecting gateway..." )
146+ if err := n .open (ctx ); err != nil {
147+ n .lavalink .Logger ().Error ("failed to reconnect node. error: " , err )
129148 n .status = Disconnected
130- return err
149+ return n . reconnectTry ( ctx , try + 1 , delay )
131150 }
132151 n .status = Connected
133152 return nil
134153}
135154
136155func (n * nodeImpl ) listen () {
137- defer func () {
138- n .lavalink .Logger ().Info ("shut down listen goroutine" )
139- }()
156+ defer n .lavalink .Logger ().Debug ("shutting down listen goroutine" )
157+ loop:
140158 for {
141159 if n .conn == nil {
142160 return
143161 }
144162 _ , data , err := n .conn .ReadMessage ()
145163 if err != nil {
146- if websocket .IsUnexpectedCloseError (err , websocket .CloseGoingAway , websocket .CloseAbnormalClosure ) {
147- n .lavalink .Logger ().Error ("error while reading from lavalink websocket. error: " , err )
148- n .Close ()
149- if err = n .reconnect (); err != nil {
150- n .lavalink .Logger ().Error ("error while reconnecting to lavalink websocket. error: " , err )
151- }
164+ reconnect := true
165+ if errors .Is (err , net .ErrClosed ) {
166+ // we closed the connection manually, so we don't want to reconnect
167+ reconnect = false
152168 }
153- return
169+ n .Close ()
170+ if reconnect {
171+ go n .reconnect (context .TODO ())
172+ }
173+ break loop
154174 }
155175
156176 n .lavalink .Logger ().Trace ("received: " , string (data ))
@@ -287,7 +307,7 @@ func (n *nodeImpl) onStatsEvent(stats StatsOp) {
287307 n .stats = & stats .Stats
288308}
289309
290- func (n * nodeImpl ) open (ctx context.Context , delay time. Duration ) error {
310+ func (n * nodeImpl ) open (ctx context.Context ) error {
291311 select {
292312 case <- ctx .Done ():
293313 return ctx .Err ()
@@ -312,21 +332,7 @@ func (n *nodeImpl) open(ctx context.Context, delay time.Duration) error {
312332 )
313333 n .conn , rs , err = websocket .DefaultDialer .DialContext (ctx , fmt .Sprintf ("%s://%s:%s" , scheme , n .config .Host , n .config .Port ), header )
314334 if err != nil {
315- n .lavalink .Logger ().Warnf ("error while connecting to lavalink websocket, retrying in %f seconds: %s" , delay .Seconds (), err )
316- if delay > 0 {
317- select {
318- case <- ctx .Done ():
319- return ctx .Err ()
320- case <- time .After (delay ):
321- }
322-
323- } else {
324- delay = 1 * time .Second
325- }
326- if delay < 30 * time .Second {
327- delay *= 2
328- }
329- return n .open (ctx , delay )
335+ return err
330336 }
331337 if n .config .ResumingKey != "" {
332338 if rs .Header .Get ("Session-Resumed" ) == "true" {
@@ -352,7 +358,7 @@ func (n *nodeImpl) Open(ctx context.Context) error {
352358 defer n .statusMu .Unlock ()
353359
354360 n .status = Connecting
355- if err := n .open (ctx , 0 ); err != nil {
361+ if err := n .open (ctx ); err != nil {
356362 n .status = Disconnected
357363 return err
358364 }
0 commit comments