Skip to content

Commit d51bcdd

Browse files
berkantJoeTurki
authored andcommitted
Handle stale nonces in ChannelBind
And add timer to renew bindings
1 parent 903cf17 commit d51bcdd

File tree

4 files changed

+225
-73
lines changed

4 files changed

+225
-73
lines changed

internal/client/binding.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ func (b *binding) refreshedAt() time.Time {
6161
return b._refreshedAt
6262
}
6363

64+
func (b *binding) ok() bool {
65+
state := b.state()
66+
67+
return state == bindingStateReady || state == bindingStateRefresh
68+
}
69+
6470
// Thread-safe binding map.
6571
type bindingManager struct {
6672
chanMap map[uint16]*binding
@@ -159,3 +165,15 @@ func (mgr *bindingManager) size() int {
159165

160166
return len(mgr.chanMap)
161167
}
168+
169+
func (mgr *bindingManager) all() []*binding {
170+
mgr.mutex.RLock()
171+
defer mgr.mutex.RUnlock()
172+
173+
list := make([]*binding, 0, len(mgr.chanMap))
174+
for _, b := range mgr.chanMap {
175+
list = append(list, b)
176+
}
177+
178+
return list
179+
}

internal/client/binding_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ func TestBindingManager(t *testing.T) {
4646
assert.Equal(t, b0, b2, "should match")
4747
}
4848

49+
all := bm.all()
50+
for _, b := range all {
51+
found, ok := bm.findByNumber(b.number)
52+
assert.True(t, ok, "should exist")
53+
assert.Equal(t, b, found, "should match")
54+
}
55+
assert.Equal(t, count, len(all), "should match")
4956
assert.Equal(t, count, bm.size(), "should match")
5057
assert.Equal(t, count, len(bm.addrMap), "should match")
5158

@@ -60,6 +67,7 @@ func TestBindingManager(t *testing.T) {
6067

6168
assert.Equal(t, 0, bm.size(), "should match")
6269
assert.Equal(t, 0, len(bm.addrMap), "should match")
70+
assert.Equal(t, 0, len(bm.all()), "should match")
6371
})
6472

6573
t.Run("failure test", func(t *testing.T) {

internal/client/udp_conn.go

Lines changed: 75 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@ import (
1717
)
1818

1919
const (
20-
maxReadQueueSize = 1024
21-
permRefreshInterval = 120 * time.Second
22-
maxRetryAttempts = 3
20+
maxReadQueueSize = 1024
21+
permRefreshInterval = 120 * time.Second
22+
bindingRefreshInterval = 5 * time.Minute
23+
bindingCheckInterval = 30 * time.Second
24+
maxRetryAttempts = 3
2325
)
2426

2527
const (
2628
timerIDRefreshAlloc int = iota
2729
timerIDRefreshPerms
30+
timerIDCheckBindings
2831
)
2932

3033
type inboundData struct {
@@ -35,9 +38,10 @@ type inboundData struct {
3538
// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
3639
// compatible with net.PacketConn and net.Conn.
3740
type UDPConn struct {
38-
bindingMgr *bindingManager // Thread-safe
39-
readCh chan *inboundData // Thread-safe
40-
closeCh chan struct{} // Thread-safe
41+
bindingMgr *bindingManager // Thread-safe
42+
checkBindingsTimer *PeriodicTimer // Thread-safe
43+
readCh chan *inboundData // Thread-safe
44+
closeCh chan struct{} // Thread-safe
4145
allocation
4246
}
4347

@@ -77,12 +81,25 @@ func NewUDPConn(config *AllocationConfig) *UDPConn {
7781
permRefreshInterval,
7882
)
7983

84+
conn.checkBindingsTimer = NewPeriodicTimer(
85+
timerIDCheckBindings,
86+
func(timerID int) {
87+
for _, bound := range conn.bindingMgr.all() {
88+
conn.maybeBind(bound)
89+
}
90+
},
91+
bindingCheckInterval,
92+
)
93+
8094
if conn.refreshAllocTimer.Start() {
8195
conn.log.Debugf("Started refresh allocation timer")
8296
}
8397
if conn.refreshPermsTimer.Start() {
8498
conn.log.Debugf("Started refresh permission timer")
8599
}
100+
if conn.checkBindingsTimer.Start() {
101+
conn.log.Debugf("Started check bindings timer")
102+
}
86103

87104
return conn
88105
}
@@ -185,31 +202,11 @@ func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (int, error) { //nolint
185202
bound = c.bindingMgr.create(addr)
186203
}
187204

188-
bindSt := bound.state()
189-
190205
//nolint:nestif
191-
if bindSt == bindingStateIdle || bindSt == bindingStateRequest || bindSt == bindingStateFailed {
192-
func() {
193-
// Block only callers with the same binding until
194-
// the binding transaction has been complete
195-
bound.muBind.Lock()
196-
defer bound.muBind.Unlock()
197-
198-
// Binding state may have been changed while waiting. check again.
199-
if bound.state() == bindingStateIdle {
200-
bound.setState(bindingStateRequest)
201-
go func() {
202-
err2 := c.bind(bound)
203-
if err2 != nil {
204-
c.log.Warnf("Failed to bind bind(): %s", err2)
205-
bound.setState(bindingStateFailed)
206-
// Keep going...
207-
} else {
208-
bound.setState(bindingStateReady)
209-
}
210-
}()
211-
}
212-
}()
206+
if !bound.ok() {
207+
// Try to establish an initial binding with the server.
208+
// Writes still occur via indications meanwhile.
209+
c.maybeBind(bound)
213210

214211
// Send data using SendIndication
215212
peerAddr := addr2PeerAddress(addr)
@@ -225,34 +222,10 @@ func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (int, error) { //nolint
225222
return 0, err
226223
}
227224

228-
// Indication has no transaction (fire-and-forget)
229-
230225
return c.client.WriteTo(msg.Raw, c.serverAddr)
231226
}
232227

233-
// Binding is either ready
234-
235-
// Check if the binding needs a refresh
236-
func() {
237-
bound.muBind.Lock()
238-
defer bound.muBind.Unlock()
239-
240-
if bound.state() == bindingStateReady && time.Since(bound.refreshedAt()) > 5*time.Minute {
241-
bound.setState(bindingStateRefresh)
242-
go func() {
243-
if bindErr := c.bind(bound); bindErr != nil {
244-
c.log.Warnf("Failed to bind() for refresh: %s", bindErr)
245-
bound.setState(bindingStateFailed)
246-
// Keep going...
247-
} else {
248-
bound.setRefreshedAt(time.Now())
249-
bound.setState(bindingStateReady)
250-
}
251-
}()
252-
}
253-
}()
254-
255-
// Send via ChannelData
228+
// Binding is ready beyond this point, so send over it.
256229
_, err = c.sendChannelData(payload, bound.number)
257230
if err != nil {
258231
return 0, err
@@ -266,6 +239,7 @@ func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (int, error) { //nolint
266239
func (c *UDPConn) Close() error {
267240
c.refreshAllocTimer.Stop()
268241
c.refreshPermsTimer.Stop()
242+
c.checkBindingsTimer.Stop()
269243

270244
select {
271245
case <-c.closeCh:
@@ -420,6 +394,44 @@ func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) {
420394
return b.addr, true
421395
}
422396

397+
func (c *UDPConn) maybeBind(bound *binding) {
398+
bind := func() {
399+
var err error
400+
for i := 0; i < maxRetryAttempts; i++ {
401+
if err = c.bind(bound); !errors.Is(err, errTryAgain) {
402+
break
403+
}
404+
}
405+
if err != nil {
406+
c.log.Warnf("Failed to bind channel %d: %s", bound.number, err)
407+
bound.setState(bindingStateFailed)
408+
409+
return
410+
}
411+
bound.setRefreshedAt(time.Now())
412+
bound.setState(bindingStateReady)
413+
}
414+
415+
// Block only callers with the same binding until
416+
// the binding transaction has been complete
417+
bound.muBind.Lock()
418+
defer bound.muBind.Unlock()
419+
420+
state := bound.state()
421+
switch {
422+
case state == bindingStateIdle:
423+
bound.setState(bindingStateRequest)
424+
case state == bindingStateReady && time.Since(bound.refreshedAt()) > bindingRefreshInterval:
425+
bound.setState(bindingStateRefresh)
426+
default:
427+
return
428+
}
429+
430+
// Establish binding with the server if eligible
431+
// with regard to cases right above.
432+
go bind()
433+
}
434+
423435
func (c *UDPConn) bind(bound *binding) error {
424436
setters := []stun.Setter{
425437
stun.TransactionID,
@@ -446,8 +458,15 @@ func (c *UDPConn) bind(bound *binding) error {
446458
}
447459

448460
res := trRes.Msg
461+
if res.Type.Class == stun.ClassErrorResponse {
462+
var code stun.ErrorCodeAttribute
463+
if err = code.GetFrom(res); err == nil {
464+
if code.Code == stun.CodeStaleNonce {
465+
c.setNonceFromMsg(res)
449466

450-
if res.Type != stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse) {
467+
return errTryAgain
468+
}
469+
}
451470
return fmt.Errorf("unexpected response type %s", res.Type) //nolint // dynamic errors
452471
}
453472

0 commit comments

Comments
 (0)