Skip to content

Commit 22f7269

Browse files
authored
fix(dot/network): update notificationsProtocol handshakeData to sync.Map (#1492)
1 parent 3b2ad8d commit 22f7269

File tree

7 files changed

+67
-53
lines changed

7 files changed

+67
-53
lines changed

dot/network/block_announce.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,13 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err
228228

229229
// don't need to lock here, since function is always called inside the func returned by
230230
// `createNotificationsMessageHandler` which locks the map beforehand.
231-
data, ok := np.handshakeData[peer]
231+
data, ok := np.getHandshakeData(peer)
232232
if !ok {
233-
np.handshakeData[peer] = &handshakeData{
233+
np.handshakeData.Store(peer, &handshakeData{
234234
received: true,
235235
validated: true,
236-
}
237-
data = np.handshakeData[peer]
236+
})
237+
data, _ = np.getHandshakeData(peer)
238238
}
239239

240240
data.handshake = hs

dot/network/block_announce_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package network
1818

1919
import (
2020
"math/big"
21+
"sync"
2122
"testing"
2223

2324
"github.com/ChainSafe/gossamer/dot/types"
@@ -116,10 +117,10 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
116117
nodeA := createTestService(t, configA)
117118
nodeA.noGossip = true
118119
nodeA.notificationsProtocols[BlockAnnounceMsgType] = &notificationsProtocol{
119-
handshakeData: make(map[peer.ID]*handshakeData),
120+
handshakeData: new(sync.Map),
120121
}
121122
testPeerID := peer.ID("noot")
122-
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData[testPeerID] = &handshakeData{}
123+
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})
123124

124125
err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
125126
BestBlockNumber: 100,

dot/network/host_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
363363
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]
364364

365365
// Set handshake data to received
366-
info.handshakeData[nodeB.host.id()] = &handshakeData{
366+
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
367367
received: true,
368368
validated: true,
369-
}
369+
})
370370

371371
// Verify that handshake data exists.
372-
_, ok := info.handshakeData[nodeB.host.id()]
372+
_, ok := info.getHandshakeData(nodeB.host.id())
373373
require.True(t, ok)
374374

375375
time.Sleep(time.Second)
@@ -379,7 +379,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
379379
time.Sleep(time.Second)
380380

381381
// Verify that handshake data is cleared.
382-
_, ok = info.handshakeData[nodeB.host.id()]
382+
_, ok = info.getHandshakeData(nodeB.host.id())
383383
require.False(t, ok)
384384
}
385385

dot/network/notifications.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,19 @@ type (
5454
type notificationsProtocol struct {
5555
protocolID protocol.ID
5656
getHandshake HandshakeGetter
57-
handshakeData map[peer.ID]*handshakeData
57+
handshakeData *sync.Map //map[peer.ID]*handshakeData
5858
mapMu sync.RWMutex
5959
}
6060

61+
func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
62+
data, has := n.handshakeData.Load(pid)
63+
if !has {
64+
return nil, false
65+
}
66+
67+
return data.(*handshakeData), true
68+
}
69+
6170
type handshakeData struct {
6271
received bool
6372
validated bool
@@ -72,7 +81,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode
7281
info.mapMu.RLock()
7382
defer info.mapMu.RUnlock()
7483

75-
if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
84+
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
7685
return handshakeDecoder(in)
7786
}
7887

@@ -112,12 +121,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
112121
defer info.mapMu.Unlock()
113122

114123
// if we are the receiver and haven't received the handshake already, validate it
115-
if _, has := info.handshakeData[peer]; !has {
124+
if _, has := info.getHandshakeData(peer); !has {
116125
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
117-
info.handshakeData[peer] = &handshakeData{
126+
info.handshakeData.Store(peer, &handshakeData{
118127
validated: false,
119128
received: true,
120-
}
129+
})
121130

122131
err := handshakeValidator(peer, hs)
123132
if err != nil {
@@ -126,7 +135,8 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
126135
return errCannotValidateHandshake
127136
}
128137

129-
info.handshakeData[peer].validated = true
138+
data, _ := info.getHandshakeData(peer)
139+
data.validated = true
130140

131141
// once validated, send back a handshake
132142
resp, err := info.getHandshake()
@@ -145,25 +155,25 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
145155
}
146156

147157
// if we are the initiator and haven't received the handshake already, validate it
148-
if hsData, has := info.handshakeData[peer]; has && !hsData.validated {
158+
if hsData, has := info.getHandshakeData(peer); has && !hsData.validated {
149159
logger.Trace("sender: validating handshake")
150160
err := handshakeValidator(peer, hs)
151161
if err != nil {
152162
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
153-
info.handshakeData[peer].validated = false
163+
hsData.validated = false
154164
_ = stream.Conn().Close()
155165
return errCannotValidateHandshake
156166
}
157167

158-
info.handshakeData[peer].validated = true
159-
info.handshakeData[peer].received = true
168+
hsData.validated = true
169+
hsData.received = true
160170
logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
161171
} else if hsData.received {
162172
return nil
163173
}
164174

165175
// if we are the initiator, send the message
166-
if hsData, has := info.handshakeData[peer]; has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
176+
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
167177
logger.Trace("sender: sending message", "protocol", info.protocolID)
168178
err := s.host.send(peer, info.protocolID, hsData.outboundMsg)
169179
if err != nil {
@@ -223,11 +233,11 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
223233
info.mapMu.RLock()
224234
defer info.mapMu.RUnlock()
225235

226-
if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
227-
info.handshakeData[peer] = &handshakeData{
236+
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
237+
info.handshakeData.Store(peer, &handshakeData{
228238
validated: false,
229239
outboundMsg: msg,
230-
}
240+
})
231241

232242
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
233243
err = s.host.send(peer, info.protocolID, hs)

dot/network/notifications_test.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package network
1818

1919
import (
2020
"math/big"
21+
"sync"
2122
"testing"
2223
"time"
2324

@@ -46,15 +47,15 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
4647
info := &notificationsProtocol{
4748
protocolID: s.host.protocolID + blockAnnounceID,
4849
getHandshake: s.getBlockAnnounceHandshake,
49-
handshakeData: make(map[peer.ID]*handshakeData),
50+
handshakeData: new(sync.Map),
5051
}
5152
decoder := createDecoder(info, decodeBlockAnnounceHandshake, decodeBlockAnnounceMessage)
5253

5354
// haven't received handshake from peer
5455
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
55-
info.handshakeData[testPeerID] = &handshakeData{
56+
info.handshakeData.Store(testPeerID, &handshakeData{
5657
received: false,
57-
}
58+
})
5859

5960
testHandshake := &BlockAnnounceHandshake{
6061
Roles: 4,
@@ -82,7 +83,8 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
8283
require.NoError(t, err)
8384

8485
// set handshake data to received
85-
info.handshakeData[testPeerID].received = true
86+
hsData, _ := info.getHandshakeData(testPeerID)
87+
hsData.received = true
8688
msg, err = decoder(enc, testPeerID)
8789
require.NoError(t, err)
8890
require.Equal(t, testBlockAnnounce, msg)
@@ -132,15 +134,15 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
132134
info := &notificationsProtocol{
133135
protocolID: s.host.protocolID + blockAnnounceID,
134136
getHandshake: s.getBlockAnnounceHandshake,
135-
handshakeData: make(map[peer.ID]*handshakeData),
137+
handshakeData: new(sync.Map),
136138
}
137139
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)
138140

139141
// set handshake data to received
140-
info.handshakeData[testPeerID] = &handshakeData{
142+
info.handshakeData.Store(testPeerID, &handshakeData{
141143
received: true,
142144
validated: true,
143-
}
145+
})
144146
msg := &BlockAnnounceMessage{
145147
Number: big.NewInt(10),
146148
}
@@ -164,7 +166,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
164166
info := &notificationsProtocol{
165167
protocolID: s.host.protocolID + blockAnnounceID,
166168
getHandshake: s.getBlockAnnounceHandshake,
167-
handshakeData: make(map[peer.ID]*handshakeData),
169+
handshakeData: new(sync.Map),
168170
}
169171
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)
170172

@@ -205,8 +207,10 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
205207

206208
err = handler(stream, testHandshake)
207209
require.Equal(t, errCannotValidateHandshake, err)
208-
require.True(t, info.handshakeData[testPeerID].received)
209-
require.False(t, info.handshakeData[testPeerID].validated)
210+
data, has := info.getHandshakeData(testPeerID)
211+
require.True(t, has)
212+
require.True(t, data.received)
213+
require.False(t, data.validated)
210214

211215
// try valid handshake
212216
testHandshake = &BlockAnnounceHandshake{
@@ -218,6 +222,8 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
218222

219223
err = handler(stream, testHandshake)
220224
require.NoError(t, err)
221-
require.True(t, info.handshakeData[testPeerID].received)
222-
require.True(t, info.handshakeData[testPeerID].validated)
225+
data, has = info.getHandshakeData(testPeerID)
226+
require.True(t, has)
227+
require.True(t, data.received)
228+
require.True(t, data.validated)
223229
}

dot/network/service.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,10 @@ func (s *Service) handleConn(conn libp2pnetwork.Conn) {
304304
defer info.mapMu.RUnlock()
305305

306306
peer := conn.RemotePeer()
307-
if hsData, has := info.handshakeData[peer]; !has || !hsData.received {
308-
info.handshakeData[peer] = &handshakeData{
307+
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
308+
info.handshakeData.Store(peer, &handshakeData{
309309
validated: false,
310-
}
310+
})
311311

312312
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
313313
err = s.host.send(peer, info.protocolID, hs)
@@ -407,7 +407,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
407407
np := &notificationsProtocol{
408408
protocolID: protocolID,
409409
getHandshake: handshakeGetter,
410-
handshakeData: make(map[peer.ID]*handshakeData),
410+
handshakeData: new(sync.Map),
411411
}
412412
s.notificationsProtocols[messageID] = np
413413

@@ -416,13 +416,13 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID,
416416
np.mapMu.Lock()
417417
defer np.mapMu.Unlock()
418418

419-
if _, ok := np.handshakeData[peerID]; ok {
419+
if _, ok := np.getHandshakeData(peerID); ok {
420420
logger.Trace(
421421
"Cleaning up handshake data",
422422
"peer", peerID,
423423
"protocol", protocolID,
424424
)
425-
delete(np.handshakeData, peerID)
425+
np.handshakeData.Delete(peerID)
426426
}
427427
})
428428

@@ -625,31 +625,28 @@ func (s *Service) Peers() []common.PeerInfo {
625625
peers := []common.PeerInfo{}
626626

627627
s.notificationsMu.RLock()
628-
defer s.notificationsMu.RUnlock()
628+
np := s.notificationsProtocols[BlockAnnounceMsgType]
629+
s.notificationsMu.RUnlock()
629630

630631
for _, p := range s.host.peers() {
631-
if s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p] == nil {
632+
data, has := np.getHandshakeData(p)
633+
if !has || data.handshake == nil {
632634
peers = append(peers, common.PeerInfo{
633635
PeerID: p.String(),
634636
})
635637

636638
continue
637639
}
638-
peerHandshakeMessage := s.notificationsProtocols[BlockAnnounceMsgType].handshakeData[p].handshake
639-
if peerHandshakeMessage == nil {
640-
peers = append(peers, common.PeerInfo{
641-
PeerID: p.String(),
642-
})
643-
continue
644-
}
645640

641+
peerHandshakeMessage := data.handshake
646642
peers = append(peers, common.PeerInfo{
647643
PeerID: p.String(),
648644
Roles: peerHandshakeMessage.(*BlockAnnounceHandshake).Roles,
649645
BestHash: peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockHash,
650646
BestNumber: uint64(peerHandshakeMessage.(*BlockAnnounceHandshake).BestBlockNumber),
651647
})
652648
}
649+
653650
return peers
654651
}
655652

dot/network/service_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,13 @@ func TestHandleConn(t *testing.T) {
371371
require.Equal(t, 1, aScore)
372372

373373
infoA := nodeA.notificationsProtocols[BlockAnnounceMsgType]
374-
hsDataB, has := infoA.handshakeData[nodeB.host.id()]
374+
hsDataB, has := infoA.getHandshakeData(nodeB.host.id())
375375
require.True(t, has)
376376
require.True(t, hsDataB.received)
377377
require.True(t, hsDataB.validated)
378378

379379
infoB := nodeB.notificationsProtocols[BlockAnnounceMsgType]
380-
hsDataA, has := infoB.handshakeData[nodeA.host.id()]
380+
hsDataA, has := infoB.getHandshakeData(nodeA.host.id())
381381
require.True(t, has)
382382
require.True(t, hsDataA.received)
383383
require.True(t, hsDataA.validated)

0 commit comments

Comments
 (0)