@@ -49,29 +49,33 @@ type (
49
49
50
50
// NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream.
51
51
NotificationsMessageHandler = func (peer peer.ID , msg NotificationsMessage ) error
52
+
53
+ streamHandler = func (libp2pnetwork.Stream , peer.ID )
52
54
)
53
55
54
56
type notificationsProtocol struct {
55
57
protocolID protocol.ID
56
58
getHandshake HandshakeGetter
57
59
handshakeData * sync.Map //map[peer.ID]*handshakeData
60
+ streamHandler streamHandler
58
61
mapMu sync.RWMutex
59
62
}
60
63
61
- func (n * notificationsProtocol ) getHandshakeData (pid peer.ID ) (* handshakeData , bool ) {
64
+ func (n * notificationsProtocol ) getHandshakeData (pid peer.ID ) (handshakeData , bool ) {
62
65
data , has := n .handshakeData .Load (pid )
63
66
if ! has {
64
- return nil , false
67
+ return handshakeData {} , false
65
68
}
66
69
67
- return data .(* handshakeData ), true
70
+ return data .(handshakeData ), true
68
71
}
69
72
70
73
type handshakeData struct {
71
74
received bool
72
75
validated bool
73
76
handshake Handshake
74
77
outboundMsg NotificationsMessage
78
+ stream libp2pnetwork.Stream
75
79
}
76
80
77
81
func createDecoder (info * notificationsProtocol , handshakeDecoder HandshakeDecoder , messageDecoder MessageDecoder ) messageDecoder {
@@ -123,19 +127,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
123
127
// if we are the receiver and haven't received the handshake already, validate it
124
128
if _ , has := info .getHandshakeData (peer ); ! has {
125
129
logger .Trace ("receiver: validating handshake" , "protocol" , info .protocolID )
126
- info . handshakeData . Store ( peer , & handshakeData {
130
+ hsData := handshakeData {
127
131
validated : false ,
128
132
received : true ,
129
- })
133
+ stream : stream ,
134
+ }
135
+ info .handshakeData .Store (peer , hsData )
130
136
131
137
err := handshakeValidator (peer , hs )
132
138
if err != nil {
133
139
logger .Trace ("failed to validate handshake" , "protocol" , info .protocolID , "peer" , peer , "error" , err )
134
140
return errCannotValidateHandshake
135
141
}
136
142
137
- data , _ := info . getHandshakeData ( peer )
138
- data . validated = true
143
+ hsData . validated = true
144
+ info . handshakeData . Store ( peer , hsData )
139
145
140
146
// once validated, send back a handshake
141
147
resp , err := info .getHandshake ()
@@ -144,7 +150,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
144
150
return err
145
151
}
146
152
147
- err = s .host .writeToStream (stream , resp )
153
+ err = s .host .writeToStream (hsData . stream , resp )
148
154
if err != nil {
149
155
logger .Trace ("failed to send handshake" , "protocol" , info .protocolID , "peer" , peer , "error" , err )
150
156
return err
@@ -160,20 +166,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
160
166
if err != nil {
161
167
logger .Trace ("failed to validate handshake" , "protocol" , info .protocolID , "peer" , peer , "error" , err )
162
168
hsData .validated = false
169
+ info .handshakeData .Store (peer , hsData )
163
170
return errCannotValidateHandshake
164
171
}
165
172
166
173
hsData .validated = true
167
174
hsData .received = true
175
+ info .handshakeData .Store (peer , hsData )
176
+
168
177
logger .Trace ("sender: validated handshake" , "protocol" , info .protocolID , "peer" , peer )
169
- } else if hsData .received {
170
- return nil
171
178
}
172
179
173
180
// if we are the initiator, send the message
174
181
if hsData , has := info .getHandshakeData (peer ); has && hsData .validated && hsData .received && hsData .outboundMsg != nil {
175
182
logger .Trace ("sender: sending message" , "protocol" , info .protocolID )
176
- err := s .host .writeToStream (stream , hsData .outboundMsg )
183
+ err := s .host .writeToStream (hsData . stream , hsData .outboundMsg )
177
184
if err != nil {
178
185
logger .Debug ("failed to send message" , "protocol" , info .protocolID , "peer" , peer , "error" , err )
179
186
return err
@@ -209,6 +216,61 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
209
216
}
210
217
}
211
218
219
+ func (s * Service ) sendData (peer peer.ID , hs Handshake , info * notificationsProtocol , msg NotificationsMessage ) {
220
+ hsData , has := info .getHandshakeData (peer )
221
+ if ! has || ! hsData .received {
222
+ hsData = handshakeData {
223
+ validated : false ,
224
+ received : false ,
225
+ outboundMsg : msg ,
226
+ }
227
+
228
+ info .handshakeData .Store (peer , hsData )
229
+ logger .Trace ("sending handshake" , "protocol" , info .protocolID , "peer" , peer , "message" , hs )
230
+
231
+ stream , err := s .host .send (peer , info .protocolID , hs )
232
+ if err != nil {
233
+ logger .Trace ("failed to send message to peer" , "peer" , peer , "error" , err )
234
+ return
235
+ }
236
+
237
+ hsData .stream = stream
238
+ info .handshakeData .Store (peer , hsData )
239
+
240
+ if info .streamHandler == nil {
241
+ return
242
+ }
243
+
244
+ go info .streamHandler (stream , peer )
245
+ return
246
+ }
247
+
248
+ if s .host .messageCache != nil {
249
+ added , err := s .host .messageCache .put (peer , msg )
250
+ if err != nil {
251
+ logger .Error ("failed to add message to cache" , "peer" , peer , "error" , err )
252
+ return
253
+ }
254
+
255
+ if ! added {
256
+ return
257
+ }
258
+ }
259
+
260
+ if hsData .stream == nil {
261
+ logger .Error ("trying to send data through empty stream" , "protocol" , info .protocolID , "peer" , peer , "message" , msg )
262
+ return
263
+ }
264
+
265
+ // we've already completed the handshake with the peer, send message directly
266
+ logger .Trace ("sending message" , "protocol" , info .protocolID , "peer" , peer , "message" , msg )
267
+
268
+ err := s .host .writeToStream (hsData .stream , msg )
269
+ if err != nil {
270
+ logger .Trace ("failed to send message to peer" , "peer" , peer , "error" , err )
271
+ }
272
+ }
273
+
212
274
// gossipExcluding sends a message to each connected peer except the given peer
213
275
// Used for notifications sub-protocols to gossip a message
214
276
func (s * Service ) broadcastExcluding (info * notificationsProtocol , excluding peer.ID , msg NotificationsMessage ) {
@@ -234,35 +296,6 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
234
296
continue
235
297
}
236
298
237
- if hsData , has := info .getHandshakeData (peer ); ! has || ! hsData .received {
238
- info .handshakeData .Store (peer , & handshakeData {
239
- validated : false ,
240
- outboundMsg : msg ,
241
- })
242
-
243
- logger .Trace ("sending handshake" , "protocol" , info .protocolID , "peer" , peer , "message" , hs )
244
- err = s .host .send (peer , info .protocolID , hs )
245
- } else {
246
- if s .host .messageCache != nil {
247
- var added bool
248
- added , err = s .host .messageCache .put (peer , msg )
249
- if err != nil {
250
- logger .Error ("failed to add message to cache" , "peer" , peer , "error" , err )
251
- continue
252
- }
253
-
254
- if ! added {
255
- continue
256
- }
257
- }
258
-
259
- // we've already completed the handshake with the peer, send message directly
260
- logger .Trace ("sending message" , "protocol" , info .protocolID , "peer" , peer , "message" , msg )
261
- err = s .host .send (peer , info .protocolID , msg )
262
- }
263
-
264
- if err != nil {
265
- logger .Debug ("failed to send message to peer" , "peer" , peer , "error" , err )
266
- }
299
+ go s .sendData (peer , hs , info , msg )
267
300
}
268
301
}
0 commit comments