Skip to content

Commit 11b1b9d

Browse files
authored
Add ReconnectClient (#18)
* Add ReconnectClient * Temporary implement incremental wait * Implement auto resubscribe * Add Client.Handle
1 parent 58be7b1 commit 11b1b9d

File tree

9 files changed

+284
-36
lines changed

9 files changed

+284
-36
lines changed

client.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ import (
1010
type BaseClient struct {
1111
// Transport is an underlying connection. Typically net.Conn.
1212
Transport io.ReadWriteCloser
13-
// Handler of incoming messages.
14-
Handler Handler
1513
// ConnState is called if the connection state is changed.
1614
ConnState func(ConnState, error)
1715

16+
handler Handler
1817
sig *signaller
1918
mu sync.RWMutex
2019
connState ConnState
@@ -24,6 +23,13 @@ type BaseClient struct {
2423
idLast uint32
2524
}
2625

26+
// Handle registers the message handler.
27+
func (c *BaseClient) Handle(handler Handler) {
28+
c.mu.Lock()
29+
defer c.mu.Unlock()
30+
c.handler = handler
31+
}
32+
2733
// WithUserNamePassword sets plain text auth information used in Connect.
2834
func WithUserNamePassword(userName, password string) ConnectOption {
2935
return func(o *ConnectOptions) error {

client_integration_test.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@ func ExampleClient() {
2727
if err != nil {
2828
panic(err)
2929
}
30-
baseCli.Handler = HandlerFunc(func(msg *Message) {
31-
fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload))
32-
close(done)
33-
})
3430

3531
// store as Client to make it easy to enable high level wrapper later
3632
var cli Client = baseCli
3733
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
3834
defer cancel()
3935

36+
cli.Handle(HandlerFunc(func(msg *Message) {
37+
fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload))
38+
close(done)
39+
}))
40+
4041
if _, err := cli.Connect(ctx, "TestClient", WithCleanSession(true)); err != nil {
4142
panic(err)
4243
}
@@ -123,9 +124,6 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
123124
}
124125

125126
chReceived := make(chan *Message, 100)
126-
cli.Handler = HandlerFunc(func(msg *Message) {
127-
chReceived <- msg
128-
})
129127
cli.ConnState = func(s ConnState, err error) {
130128
switch s {
131129
case StateActive:
@@ -142,6 +140,10 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
142140
t.Fatalf("Unexpected error: '%v'", err)
143141
}
144142

143+
cli.Handle(HandlerFunc(func(msg *Message) {
144+
chReceived <- msg
145+
}))
146+
145147
if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
146148
t.Fatalf("Unexpected error: '%v'", err)
147149
}
@@ -155,6 +157,8 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
155157
}
156158

157159
select {
160+
case <-ctx.Done():
161+
t.Fatalf("Unexpected error: '%v'", ctx.Err())
158162
case msg, ok := <-chReceived:
159163
if !ok {
160164
t.Errorf("Connection closed unexpectedly")
@@ -238,9 +242,6 @@ func BenchmarkPublishSubscribe(b *testing.B) {
238242
}
239243

240244
chReceived := make(chan *Message, 100)
241-
cli.Handler = HandlerFunc(func(msg *Message) {
242-
chReceived <- msg
243-
})
244245
cli.ConnState = func(s ConnState, err error) {
245246
switch s {
246247
case StateActive:
@@ -256,6 +257,10 @@ func BenchmarkPublishSubscribe(b *testing.B) {
256257
b.Fatalf("Unexpected error: '%v'", err)
257258
}
258259

260+
cli.Handle(HandlerFunc(func(msg *Message) {
261+
chReceived <- msg
262+
}))
263+
259264
if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
260265
b.Fatalf("Unexpected error: '%v'", err)
261266
}

conn.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ var ErrUnsupportedProtocol = errors.New("unsupported protocol")
1616
// ErrClosedTransport means that the underlying connection is closed.
1717
var ErrClosedTransport = errors.New("read/write on closed transport")
1818

19+
// URLDialer is a Dialer using URL string.
20+
type URLDialer struct {
21+
URL string
22+
Options []DialOption
23+
}
24+
25+
// Dialer is an interface to create connection.
26+
type Dialer interface {
27+
Dial() (ClientCloser, error)
28+
}
29+
30+
// Dial creates connection using its values.
31+
func (d *URLDialer) Dial() (ClientCloser, error) {
32+
return Dial(d.URL, d.Options...)
33+
}
34+
1935
// Dial creates MQTT client using URL string.
2036
func Dial(urlStr string, opts ...DialOption) (*BaseClient, error) {
2137
o := &DialOptions{
@@ -113,3 +129,15 @@ func (c *BaseClient) connStateUpdate(newState ConnState) {
113129
func (c *BaseClient) Close() error {
114130
return c.Transport.Close()
115131
}
132+
133+
// Done is a channel to signal connection close.
134+
func (c *BaseClient) Done() <-chan struct{} {
135+
return c.connClosed
136+
}
137+
138+
// Err returns connection error.
139+
func (c *BaseClient) Err() error {
140+
c.mu.Lock()
141+
defer c.mu.Unlock()
142+
return c.err
143+
}

mqtt.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ type Client interface {
5151
Subscribe(ctx context.Context, subs ...Subscription) error
5252
Unsubscribe(ctx context.Context, subs ...string) error
5353
Ping(ctx context.Context) error
54+
Handle(Handler)
5455
}
5556

5657
// Closer is the interface of connection closer.
5758
type Closer interface {
5859
Close() error
60+
Done() <-chan struct{}
61+
Err() error
5962
}
6063

6164
// ClientCloser groups Client and Closer interface

reconnclient.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package mqtt
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
)
8+
9+
type reconnectClient struct {
10+
Client
11+
}
12+
13+
// NewReconnectClient creates a MQTT client with re-connect/re-publish/re-subscribe features.
14+
func NewReconnectClient(ctx context.Context, dialer Dialer, clientID string, opts ...ConnectOption) Client {
15+
rc := &RetryClient{}
16+
cli := &reconnectClient{
17+
Client: rc,
18+
}
19+
done := make(chan struct{})
20+
var doneOnce sync.Once
21+
go func() {
22+
clean := true
23+
reconnWaitBase := 50 * time.Millisecond
24+
reconnWaitMax := 10 * time.Second
25+
reconnWait := reconnWaitBase
26+
for {
27+
if c, err := dialer.Dial(); err == nil {
28+
optsCurr := append([]ConnectOption{}, opts...)
29+
optsCurr = append(optsCurr, WithCleanSession(clean))
30+
clean = false // Clean only first time.
31+
reconnWait = reconnWaitBase // Reset reconnect wait.
32+
rc.SetClient(ctx, c)
33+
34+
if present, err := rc.Connect(ctx, clientID, optsCurr...); err == nil {
35+
doneOnce.Do(func() { close(done) })
36+
if present {
37+
rc.Resubscribe(ctx)
38+
}
39+
// Start keep alive.
40+
go func() {
41+
_ = KeepAlive(ctx, c, time.Second, time.Second)
42+
}()
43+
select {
44+
case <-c.Done():
45+
if err := c.Err(); err == nil {
46+
// Disconnected as expected; don't restart.
47+
return
48+
}
49+
case <-ctx.Done():
50+
// User cancelled; don't restart.
51+
return
52+
}
53+
}
54+
}
55+
select {
56+
case <-time.After(reconnWait):
57+
case <-ctx.Done():
58+
// User cancelled; don't restart.
59+
return
60+
}
61+
reconnWait *= 2
62+
if reconnWait > reconnWaitMax {
63+
reconnWait = reconnWaitMax
64+
}
65+
}
66+
}()
67+
select {
68+
case <-done:
69+
case <-ctx.Done():
70+
}
71+
return cli
72+
}

reconnclient_integration_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// +build integration
2+
3+
package mqtt
4+
5+
import (
6+
"context"
7+
"crypto/tls"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestIntegration_ReconnectClient(t *testing.T) {
13+
for name, url := range urls {
14+
t.Run(name, func(t *testing.T) {
15+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
16+
defer cancel()
17+
18+
chReceived := make(chan *Message, 100)
19+
cli := NewReconnectClient(
20+
ctx,
21+
&URLDialer{
22+
URL: url,
23+
Options: []DialOption{
24+
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
25+
},
26+
},
27+
"ReconnectClient",
28+
)
29+
cli.Handle(HandlerFunc(func(msg *Message) {
30+
chReceived <- msg
31+
}))
32+
33+
// Close underlying client.
34+
cli.(*reconnectClient).Client.(*RetryClient).Client.(ClientCloser).Close()
35+
36+
if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil {
37+
t.Fatalf("Unexpected error: '%v'", err)
38+
}
39+
if err := cli.Publish(ctx, &Message{
40+
Topic: "test",
41+
QoS: QoS1,
42+
Retain: true,
43+
Payload: []byte("message"),
44+
}); err != nil {
45+
t.Fatalf("Unexpected error: '%v'", err)
46+
}
47+
48+
time.Sleep(time.Second)
49+
50+
select {
51+
case <-ctx.Done():
52+
t.Fatalf("Unexpected error: '%v'", ctx.Err())
53+
case <-chReceived:
54+
}
55+
})
56+
}
57+
58+
}

0 commit comments

Comments
 (0)