Skip to content

Commit ba861bf

Browse files
authored
feat(dot/network): implement streamManager to cleanup not recently used streams (#1611)
1 parent dd3838c commit ba861bf

File tree

4 files changed

+197
-11
lines changed

4 files changed

+197
-11
lines changed

dot/network/service.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,13 @@ type Service struct {
6767
ctx context.Context
6868
cancel context.CancelFunc
6969

70-
cfg *Config
71-
host *host
72-
mdns *mdns
73-
gossip *gossip
74-
syncQueue *syncQueue
75-
bufPool *sizedBufferPool
70+
cfg *Config
71+
host *host
72+
mdns *mdns
73+
gossip *gossip
74+
syncQueue *syncQueue
75+
bufPool *sizedBufferPool
76+
streamManager *streamManager
7677

7778
notificationsProtocols map[byte]*notificationsProtocol // map of sub-protocol msg ID to protocol info
7879
notificationsMu sync.RWMutex
@@ -162,6 +163,7 @@ func NewService(cfg *Config) (*Service, error) {
162163
telemetryInterval: cfg.telemetryInterval,
163164
closeCh: make(chan interface{}),
164165
bufPool: bufPool,
166+
streamManager: newStreamManager(ctx),
165167
}
166168

167169
network.syncQueue = newSyncQueue(network)
@@ -267,6 +269,7 @@ func (s *Service) Start() error {
267269
go s.logPeerCount()
268270
go s.publishNetworkTelemetry(s.closeCh)
269271
go s.sentBlockIntervalTelemetry()
272+
s.streamManager.start()
270273

271274
return nil
272275
}
@@ -529,6 +532,8 @@ func isInbound(stream libp2pnetwork.Stream) bool {
529532
}
530533

531534
func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) {
535+
s.streamManager.logNewStream(stream)
536+
532537
peer := stream.Conn().RemotePeer()
533538
msgBytes := s.bufPool.get()
534539
defer s.bufPool.put(&msgBytes)
@@ -543,6 +548,8 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder
543548
return
544549
}
545550

551+
s.streamManager.logMessageReceived(stream.ID())
552+
546553
// decode message based on message type
547554
msg, err := decoder(msgBytes[:tot], peer, isInbound(stream))
548555
if err != nil {

dot/network/stream_manager.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package network
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"github.com/libp2p/go-libp2p-core/network"
9+
)
10+
11+
var cleanupStreamInterval = time.Minute
12+
13+
type streamData struct {
14+
lastReceivedMessage time.Time
15+
stream network.Stream
16+
}
17+
18+
// streamManager tracks inbound streams and runs a cleanup goroutine every `cleanupStreamInterval` to close streams that
19+
// we haven't received any data on for the last time period. this prevents keeping stale streams open and continuously trying to
20+
// read from it, which takes up lots of CPU over time.
21+
type streamManager struct {
22+
ctx context.Context
23+
streamDataMap *sync.Map //map[string]*streamData
24+
}
25+
26+
func newStreamManager(ctx context.Context) *streamManager {
27+
return &streamManager{
28+
ctx: ctx,
29+
streamDataMap: new(sync.Map),
30+
}
31+
}
32+
33+
func (sm *streamManager) start() {
34+
go func() {
35+
ticker := time.NewTicker(cleanupStreamInterval)
36+
defer ticker.Stop()
37+
38+
for {
39+
select {
40+
case <-sm.ctx.Done():
41+
return
42+
case <-ticker.C:
43+
sm.cleanupStreams()
44+
}
45+
}
46+
}()
47+
}
48+
49+
func (sm *streamManager) cleanupStreams() {
50+
sm.streamDataMap.Range(func(id, data interface{}) bool {
51+
sdata := data.(*streamData)
52+
lastReceived := sdata.lastReceivedMessage
53+
stream := sdata.stream
54+
55+
if time.Since(lastReceived) > cleanupStreamInterval {
56+
_ = stream.Close()
57+
sm.streamDataMap.Delete(id)
58+
}
59+
60+
return true
61+
})
62+
}
63+
64+
func (sm *streamManager) logNewStream(stream network.Stream) {
65+
data := &streamData{
66+
lastReceivedMessage: time.Now(), // prevents closing just opened streams, in case the cleanup goroutine runs at the same time stream is opened
67+
stream: stream,
68+
}
69+
sm.streamDataMap.Store(stream.ID(), data)
70+
}
71+
72+
func (sm *streamManager) logMessageReceived(streamID string) {
73+
data, has := sm.streamDataMap.Load(streamID)
74+
if !has {
75+
return
76+
}
77+
78+
sdata := data.(*streamData)
79+
sdata.lastReceivedMessage = time.Now()
80+
sm.streamDataMap.Store(streamID, sdata)
81+
}

dot/network/stream_manager_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package network
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/libp2p/go-libp2p"
10+
libp2phost "github.com/libp2p/go-libp2p-core/host"
11+
"github.com/libp2p/go-libp2p-core/network"
12+
"github.com/libp2p/go-libp2p-core/peer"
13+
ma "github.com/multiformats/go-multiaddr"
14+
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func setupStreamManagerTest(t *testing.T) (context.Context, []libp2phost.Host, []*streamManager) {
19+
ctx, cancel := context.WithCancel(context.Background())
20+
21+
cleanupStreamInterval = time.Millisecond * 500
22+
t.Cleanup(func() {
23+
cleanupStreamInterval = time.Minute
24+
cancel()
25+
})
26+
27+
smA := newStreamManager(ctx)
28+
smB := newStreamManager(ctx)
29+
30+
portA := 7001
31+
portB := 7002
32+
addrA, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portA))
33+
require.NoError(t, err)
34+
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portB))
35+
require.NoError(t, err)
36+
37+
ha, err := libp2p.New(
38+
ctx, libp2p.ListenAddrs(addrA),
39+
)
40+
require.NoError(t, err)
41+
42+
hb, err := libp2p.New(
43+
ctx, libp2p.ListenAddrs(addrB),
44+
)
45+
require.NoError(t, err)
46+
47+
err = ha.Connect(ctx, peer.AddrInfo{
48+
ID: hb.ID(),
49+
Addrs: hb.Addrs(),
50+
})
51+
require.NoError(t, err)
52+
53+
hb.SetStreamHandler("", func(stream network.Stream) {
54+
smB.logNewStream(stream)
55+
})
56+
57+
return ctx, []libp2phost.Host{ha, hb}, []*streamManager{smA, smB}
58+
}
59+
60+
func TestStreamManager(t *testing.T) {
61+
ctx, hosts, sms := setupStreamManagerTest(t)
62+
ha, hb := hosts[0], hosts[1]
63+
smA, smB := sms[0], sms[1]
64+
65+
stream, err := ha.NewStream(ctx, hb.ID(), "")
66+
require.NoError(t, err)
67+
68+
smA.logNewStream(stream)
69+
smA.start()
70+
smB.start()
71+
72+
time.Sleep(cleanupStreamInterval * 2)
73+
connsAToB := ha.Network().ConnsToPeer(hb.ID())
74+
require.Equal(t, 1, len(connsAToB))
75+
require.Equal(t, 0, len(connsAToB[0].GetStreams()))
76+
77+
connsBToA := hb.Network().ConnsToPeer(ha.ID())
78+
require.Equal(t, 1, len(connsBToA))
79+
require.Equal(t, 0, len(connsBToA[0].GetStreams()))
80+
}
81+
82+
func TestStreamManager_KeepStream(t *testing.T) {
83+
ctx, hosts, sms := setupStreamManagerTest(t)
84+
ha, hb := hosts[0], hosts[1]
85+
smA, smB := sms[0], sms[1]
86+
87+
stream, err := ha.NewStream(ctx, hb.ID(), "")
88+
require.NoError(t, err)
89+
90+
smA.logNewStream(stream)
91+
smA.start()
92+
smB.start()
93+
94+
time.Sleep(cleanupStreamInterval / 2)
95+
connsAToB := ha.Network().ConnsToPeer(hb.ID())
96+
require.Equal(t, 1, len(connsAToB))
97+
require.Equal(t, 1, len(connsAToB[0].GetStreams()))
98+
99+
connsBToA := hb.Network().ConnsToPeer(ha.ID())
100+
require.Equal(t, 1, len(connsBToA))
101+
require.Equal(t, 1, len(connsBToA[0].GetStreams()))
102+
}

dot/network/utils.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
184184
)
185185

186186
length, err := readLEB128ToUint64(stream, buf[:1])
187-
if err == io.EOF {
188-
return 0, err
189-
} else if err != nil {
187+
if err != nil {
190188
return 0, err // TODO: return bytes read from readLEB128ToUint64
191189
}
192190

@@ -196,13 +194,11 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
196194

197195
if length > uint64(len(buf)) {
198196
logger.Warn("received message with size greater than allocated message buffer", "length", length, "buffer size", len(buf))
199-
_ = stream.Close()
200197
return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length)
201198
}
202199

203200
if length > maxBlockResponseSize {
204201
logger.Warn("received message with size greater than maxBlockResponseSize, closing stream", "length", length)
205-
_ = stream.Close()
206202
return 0, fmt.Errorf("message size greater than maximum: got %d", length)
207203
}
208204

0 commit comments

Comments
 (0)