Skip to content

Commit a632dc4

Browse files
authored
fix(dot/network): Fix notification handshake and reuse stream. (#1545)
1 parent 6fd2501 commit a632dc4

File tree

12 files changed

+158
-117
lines changed

12 files changed

+158
-117
lines changed

dot/network/block_announce.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,15 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err
222222
// `createNotificationsMessageHandler` which locks the map beforehand.
223223
data, ok := np.getHandshakeData(peer)
224224
if !ok {
225-
np.handshakeData.Store(peer, &handshakeData{
225+
np.handshakeData.Store(peer, handshakeData{
226226
received: true,
227227
validated: true,
228228
})
229229
data, _ = np.getHandshakeData(peer)
230230
}
231231

232232
data.handshake = hs
233+
np.handshakeData.Store(peer, data)
233234

234235
// if peer has higher best block than us, begin syncing
235236
latestHeader, err := s.blockState.BestBlockHeader()

dot/network/block_announce_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
120120
handshakeData: new(sync.Map),
121121
}
122122
testPeerID := peer.ID("noot")
123-
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})
123+
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, handshakeData{})
124124

125125
err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
126126
BestBlockNumber: 100,

dot/network/gossip_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func TestGossip(t *testing.T) {
101101
}
102102
require.NoError(t, err)
103103

104-
err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
104+
_, err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
105105
require.NoError(t, err)
106106

107107
time.Sleep(TestMessageTimeout)

dot/network/host.go

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -262,32 +262,26 @@ func (h *host) bootstrap() {
262262
}
263263
}
264264

265-
// send writes the given message to the outbound message stream for the given
266-
// peer (gets the already opened outbound message stream or opens a new one).
267-
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
268-
// get outbound stream for given peer
269-
s := h.getOutboundStream(p, pid)
270-
271-
// check if stream needs to be opened
272-
if s == nil {
273-
// open outbound stream with host protocol id
274-
s, err = h.h.NewStream(h.ctx, p, pid)
275-
if err != nil {
276-
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
277-
return err
278-
}
279-
280-
logger.Trace(
281-
"Opened stream",
282-
"host", h.id(),
283-
"peer", p,
284-
"protocol", pid,
285-
)
265+
// send creates a new outbound stream with the given peer and writes the message. It also returns
266+
// the newly created stream.
267+
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (libp2pnetwork.Stream, error) {
268+
// open outbound stream with host protocol id
269+
stream, err := h.h.NewStream(h.ctx, p, pid)
270+
if err != nil {
271+
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
272+
return nil, err
286273
}
287274

288-
err = h.writeToStream(s, msg)
275+
logger.Trace(
276+
"Opened stream",
277+
"host", h.id(),
278+
"peer", p,
279+
"protocol", pid,
280+
)
281+
282+
err = h.writeToStream(stream, msg)
289283
if err != nil {
290-
return err
284+
return nil, err
291285
}
292286

293287
logger.Trace(
@@ -298,7 +292,7 @@ func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
298292
"message", msg.String(),
299293
)
300294

301-
return nil
295+
return stream, nil
302296
}
303297

304298
func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error {

dot/network/host_test.go

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func TestSend(t *testing.T) {
218218
}
219219
require.NoError(t, err)
220220

221-
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
221+
_, err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
222222
require.NoError(t, err)
223223

224224
time.Sleep(TestMessageTimeout)
@@ -273,44 +273,29 @@ func TestExistingStream(t *testing.T) {
273273
}
274274
require.NoError(t, err)
275275

276-
stream := nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
277-
require.Nil(t, stream, "node A should not have an outbound stream")
278-
279276
// node A opens the stream to send the first message
280-
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
277+
stream, err := nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
281278
require.NoError(t, err)
282279

283280
time.Sleep(TestMessageTimeout)
284281
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")
285282

286-
stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
287-
require.NotNil(t, stream, "node A should have an outbound stream")
288-
289283
// node A uses the stream to send a second message
290-
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
284+
err = nodeA.host.writeToStream(stream, testBlockRequestMessage)
291285
require.NoError(t, err)
292286
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")
293287

294-
stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
295-
require.NotNil(t, stream, "node B should have an outbound stream")
296-
297288
// node B opens the stream to send the first message
298-
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
289+
stream, err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
299290
require.NoError(t, err)
300291

301292
time.Sleep(TestMessageTimeout)
302293
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")
303294

304-
stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
305-
require.NotNil(t, stream, "node B should have an outbound stream")
306-
307295
// node B uses the stream to send a second message
308-
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
296+
err = nodeB.host.writeToStream(stream, testBlockRequestMessage)
309297
require.NoError(t, err)
310298
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")
311-
312-
stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
313-
require.NotNil(t, stream, "node B should have an outbound stream")
314299
}
315300

316301
func TestStreamCloseMetadataCleanup(t *testing.T) {
@@ -361,13 +346,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
361346
}
362347

363348
// node A opens the stream to send the first message
364-
err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
349+
_, err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
365350
require.NoError(t, err)
366351

367352
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]
368353

369354
// Set handshake data to received
370-
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
355+
info.handshakeData.Store(nodeB.host.id(), handshakeData{
371356
received: true,
372357
validated: true,
373358
})

dot/network/notifications.go

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -49,29 +49,33 @@ type (
4949

5050
// NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream.
5151
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error
52+
53+
streamHandler = func(libp2pnetwork.Stream, peer.ID)
5254
)
5355

5456
type notificationsProtocol struct {
5557
protocolID protocol.ID
5658
getHandshake HandshakeGetter
5759
handshakeData *sync.Map //map[peer.ID]*handshakeData
60+
streamHandler streamHandler
5861
mapMu sync.RWMutex
5962
}
6063

61-
func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
64+
func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bool) {
6265
data, has := n.handshakeData.Load(pid)
6366
if !has {
64-
return nil, false
67+
return handshakeData{}, false
6568
}
6669

67-
return data.(*handshakeData), true
70+
return data.(handshakeData), true
6871
}
6972

7073
type handshakeData struct {
7174
received bool
7275
validated bool
7376
handshake Handshake
7477
outboundMsg NotificationsMessage
78+
stream libp2pnetwork.Stream
7579
}
7680

7781
func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder, messageDecoder MessageDecoder) messageDecoder {
@@ -123,19 +127,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
123127
// if we are the receiver and haven't received the handshake already, validate it
124128
if _, has := info.getHandshakeData(peer); !has {
125129
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
126-
info.handshakeData.Store(peer, &handshakeData{
130+
hsData := handshakeData{
127131
validated: false,
128132
received: true,
129-
})
133+
stream: stream,
134+
}
135+
info.handshakeData.Store(peer, hsData)
130136

131137
err := handshakeValidator(peer, hs)
132138
if err != nil {
133139
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
134140
return errCannotValidateHandshake
135141
}
136142

137-
data, _ := info.getHandshakeData(peer)
138-
data.validated = true
143+
hsData.validated = true
144+
info.handshakeData.Store(peer, hsData)
139145

140146
// once validated, send back a handshake
141147
resp, err := info.getHandshake()
@@ -144,7 +150,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
144150
return err
145151
}
146152

147-
err = s.host.writeToStream(stream, resp)
153+
err = s.host.writeToStream(hsData.stream, resp)
148154
if err != nil {
149155
logger.Trace("failed to send handshake", "protocol", info.protocolID, "peer", peer, "error", err)
150156
return err
@@ -160,20 +166,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
160166
if err != nil {
161167
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
162168
hsData.validated = false
169+
info.handshakeData.Store(peer, hsData)
163170
return errCannotValidateHandshake
164171
}
165172

166173
hsData.validated = true
167174
hsData.received = true
175+
info.handshakeData.Store(peer, hsData)
176+
168177
logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
169-
} else if hsData.received {
170-
return nil
171178
}
172179

173180
// if we are the initiator, send the message
174181
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
175182
logger.Trace("sender: sending message", "protocol", info.protocolID)
176-
err := s.host.writeToStream(stream, hsData.outboundMsg)
183+
err := s.host.writeToStream(hsData.stream, hsData.outboundMsg)
177184
if err != nil {
178185
logger.Debug("failed to send message", "protocol", info.protocolID, "peer", peer, "error", err)
179186
return err
@@ -209,6 +216,61 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
209216
}
210217
}
211218

219+
func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) {
220+
hsData, has := info.getHandshakeData(peer)
221+
if !has || !hsData.received {
222+
hsData = handshakeData{
223+
validated: false,
224+
received: false,
225+
outboundMsg: msg,
226+
}
227+
228+
info.handshakeData.Store(peer, hsData)
229+
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
230+
231+
stream, err := s.host.send(peer, info.protocolID, hs)
232+
if err != nil {
233+
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
234+
return
235+
}
236+
237+
hsData.stream = stream
238+
info.handshakeData.Store(peer, hsData)
239+
240+
if info.streamHandler == nil {
241+
return
242+
}
243+
244+
go info.streamHandler(stream, peer)
245+
return
246+
}
247+
248+
if s.host.messageCache != nil {
249+
added, err := s.host.messageCache.put(peer, msg)
250+
if err != nil {
251+
logger.Error("failed to add message to cache", "peer", peer, "error", err)
252+
return
253+
}
254+
255+
if !added {
256+
return
257+
}
258+
}
259+
260+
if hsData.stream == nil {
261+
logger.Error("trying to send data through empty stream", "protocol", info.protocolID, "peer", peer, "message", msg)
262+
return
263+
}
264+
265+
// we've already completed the handshake with the peer, send message directly
266+
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)
267+
268+
err := s.host.writeToStream(hsData.stream, msg)
269+
if err != nil {
270+
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
271+
}
272+
}
273+
212274
// gossipExcluding sends a message to each connected peer except the given peer
213275
// Used for notifications sub-protocols to gossip a message
214276
func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer.ID, msg NotificationsMessage) {
@@ -234,35 +296,6 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
234296
continue
235297
}
236298

237-
if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
238-
info.handshakeData.Store(peer, &handshakeData{
239-
validated: false,
240-
outboundMsg: msg,
241-
})
242-
243-
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
244-
err = s.host.send(peer, info.protocolID, hs)
245-
} else {
246-
if s.host.messageCache != nil {
247-
var added bool
248-
added, err = s.host.messageCache.put(peer, msg)
249-
if err != nil {
250-
logger.Error("failed to add message to cache", "peer", peer, "error", err)
251-
continue
252-
}
253-
254-
if !added {
255-
continue
256-
}
257-
}
258-
259-
// we've already completed the handshake with the peer, send message directly
260-
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)
261-
err = s.host.send(peer, info.protocolID, msg)
262-
}
263-
264-
if err != nil {
265-
logger.Debug("failed to send message to peer", "peer", peer, "error", err)
266-
}
299+
go s.sendData(peer, hs, info, msg)
267300
}
268301
}

dot/network/notifications_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
5353

5454
// haven't received handshake from peer
5555
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
56-
info.handshakeData.Store(testPeerID, &handshakeData{
56+
info.handshakeData.Store(testPeerID, handshakeData{
5757
received: false,
5858
})
5959

@@ -85,6 +85,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
8585
// set handshake data to received
8686
hsData, _ := info.getHandshakeData(testPeerID)
8787
hsData.received = true
88+
info.handshakeData.Store(testPeerID, hsData)
8889
msg, err = decoder(enc, testPeerID)
8990
require.NoError(t, err)
9091
require.Equal(t, testBlockAnnounce, msg)
@@ -139,7 +140,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
139140
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)
140141

141142
// set handshake data to received
142-
info.handshakeData.Store(testPeerID, &handshakeData{
143+
info.handshakeData.Store(testPeerID, handshakeData{
143144
received: true,
144145
validated: true,
145146
})

0 commit comments

Comments
 (0)