Skip to content

Commit db962e1

Browse files
authored
Add DialOption to set ConnState handler (#67)
1 parent 7975b1a commit db962e1

File tree

4 files changed

+48
-54
lines changed

4 files changed

+48
-54
lines changed

client_integration_test.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,22 @@ func TestIntegration_PublishSubscribe(t *testing.T) {
134134
t.Run(name, func(t *testing.T) {
135135
for _, qos := range []QoS{QoS0, QoS1, QoS2} {
136136
t.Run(fmt.Sprintf("QoS%d", int(qos)), func(t *testing.T) {
137-
cli, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
137+
chReceived := make(chan *Message, 100)
138+
139+
cli, err := Dial(url,
140+
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
141+
WithConnStateHandler(func(s ConnState, err error) {
142+
switch s {
143+
case StateClosed:
144+
close(chReceived)
145+
t.Errorf("Connection is expected to be disconnected, but closed.")
146+
}
147+
}),
148+
)
138149
if err != nil {
139150
t.Fatalf("Unexpected error: '%v'", err)
140151
}
141152

142-
chReceived := make(chan *Message, 100)
143-
cli.ConnState = func(s ConnState, err error) {
144-
switch s {
145-
case StateActive:
146-
case StateClosed:
147-
close(chReceived)
148-
t.Errorf("Connection is expected to be disconnected, but closed.")
149-
case StateDisconnected:
150-
}
151-
}
152-
153153
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
154154
defer cancel()
155155
if _, err := cli.Connect(ctx, "PubSubClient"+name); err != nil {
@@ -254,21 +254,21 @@ func TestIntegration_Ping(t *testing.T) {
254254
func BenchmarkPublishSubscribe(b *testing.B) {
255255
for name, url := range urls {
256256
b.Run(name, func(b *testing.B) {
257-
cli, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
257+
chReceived := make(chan *Message, 100)
258+
259+
cli, err := Dial(url,
260+
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
261+
WithConnStateHandler(func(s ConnState, err error) {
262+
switch s {
263+
case StateClosed:
264+
close(chReceived)
265+
}
266+
}),
267+
)
258268
if err != nil {
259269
b.Fatalf("Unexpected error: '%v'", err)
260270
}
261271

262-
chReceived := make(chan *Message, 100)
263-
cli.ConnState = func(s ConnState, err error) {
264-
switch s {
265-
case StateActive:
266-
case StateClosed:
267-
close(chReceived)
268-
case StateDisconnected:
269-
}
270-
}
271-
272272
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
273273
defer cancel()
274274
if _, err := cli.Connect(ctx, "PubSubBenchClient"+name); err != nil {

conn.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ type DialOption func(*DialOptions) error
7474
type DialOptions struct {
7575
Dialer *net.Dialer
7676
TLSConfig *tls.Config
77+
ConnState func(ConnState, error)
7778
}
7879

7980
// WithDialer sets dialer.
@@ -92,8 +93,18 @@ func WithTLSConfig(config *tls.Config) DialOption {
9293
}
9394
}
9495

96+
// WithConnStateHandler sets connection state change handler.
97+
func WithConnStateHandler(handler func(ConnState, error)) DialOption {
98+
return func(o *DialOptions) error {
99+
o.ConnState = handler
100+
return nil
101+
}
102+
}
103+
95104
func (d *DialOptions) dial(urlStr string) (*BaseClient, error) {
96-
c := &BaseClient{}
105+
c := &BaseClient{
106+
ConnState: d.ConnState,
107+
}
97108

98109
u, err := url.Parse(urlStr)
99110
if err != nil {

examples/mqtts-client-cert/main.go

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,28 +52,16 @@ func main() {
5252

5353
cli, err := mqtt.NewReconnectClient(
5454
// Dialer to connect/reconnect to the server.
55-
mqtt.DialerFunc(func() (mqtt.ClientCloser, error) {
56-
cli, err := mqtt.Dial(
57-
fmt.Sprintf("mqtts://%s:8883", host),
55+
&mqtt.URLDialer{
56+
URL: fmt.Sprintf("mqtts://%s:8883", host),
57+
Options: []mqtt.DialOption{
5858
mqtt.WithTLSConfig(tlsConfig),
59-
)
60-
if err != nil {
61-
return nil, err
62-
}
63-
// Register ConnState callback to low level client
64-
cli.ConnState = func(s mqtt.ConnState, err error) {
65-
fmt.Printf("State changed to %s (err: %v)\n", s, err)
66-
}
67-
return cli, nil
68-
}),
69-
// If you don't need customized (with state callback) low layer client,
70-
// just use mqtt.URLDialer:
71-
// &mqtt.URLDialer{
72-
// URL: fmt.Sprintf("mqtts://%s:8883", host),
73-
// Options: []mqtt.DialOption{
74-
// mqtt.WithTLSConfig(tlsConfig),
75-
// },
76-
// },
59+
mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) {
60+
// Register ConnState callback to low level client
61+
fmt.Printf("State changed to %s (err: %v)\n", s, err)
62+
}),
63+
},
64+
},
7765
mqtt.WithPingInterval(10*time.Second),
7866
mqtt.WithTimeout(5*time.Second),
7967
mqtt.WithReconnectWait(1*time.Second, 15*time.Second),

examples/wss-presign-url/main.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,12 @@ func main() {
4444
host, time.Now().UnixNano(),
4545
)
4646
println("New URL:", url)
47-
cli, err := mqtt.Dial(url,
47+
return mqtt.Dial(url,
4848
mqtt.WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
49+
mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) {
50+
fmt.Printf("State changed to %s (err: %v)\n", s, err)
51+
}),
4952
)
50-
if err != nil {
51-
return nil, err
52-
}
53-
// Register ConnState callback to low level client
54-
cli.ConnState = func(s mqtt.ConnState, err error) {
55-
fmt.Printf("State changed to %s (err: %v)\n", s, err)
56-
}
57-
return cli, nil
5853
}),
5954
mqtt.WithPingInterval(10*time.Second),
6055
mqtt.WithTimeout(5*time.Second),

0 commit comments

Comments
 (0)