Skip to content

Commit 1cc671c

Browse files
authored
refactor: replace bufio.Reader with conn.TimeoutReader in tunnelHandshake for improved timeout handling
1 parent 52569d7 commit 1cc671c

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

internal/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"syscall"
1313
"time"
1414

15+
"github.com/NodePassProject/conn"
1516
"github.com/NodePassProject/logs"
1617
"github.com/NodePassProject/pool"
1718
)
@@ -136,7 +137,7 @@ func (c *Client) tunnelHandshake() error {
136137
}
137138

138139
c.tunnelTCPConn = tunnelTCPConn.(*net.TCPConn)
139-
c.bufReader = bufio.NewReader(c.tunnelTCPConn)
140+
c.bufReader = bufio.NewReader(&conn.TimeoutReader{Conn: c.tunnelTCPConn, Timeout: tcpReadTimeout})
140141
c.tunnelTCPConn.SetKeepAlive(true)
141142
c.tunnelTCPConn.SetKeepAlivePeriod(reportInterval)
142143

internal/common.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,11 @@ func (c *Common) shutdown(ctx context.Context, stopFunc func()) error {
312312

313313
// commonControl 共用控制逻辑
314314
func (c *Common) commonControl() error {
315-
go c.commonOnce()
315+
errChan := make(chan error, 3)
316316

317-
errChan := make(chan error, 2)
318-
319-
go func() {
320-
errChan <- c.commonQueue()
321-
}()
322-
go func() {
323-
errChan <- c.healthCheck()
324-
}()
317+
go func() { errChan <- c.commonOnce() }()
318+
go func() { errChan <- c.commonQueue() }()
319+
go func() { errChan <- c.healthCheck() }()
325320

326321
select {
327322
case <-c.ctx.Done():
@@ -358,15 +353,13 @@ func (c *Common) commonQueue() error {
358353
// healthCheck 共用健康度检查
359354
func (c *Common) healthCheck() error {
360355
flushURL := &url.URL{Fragment: "0"} // 连接池刷新信号
361-
checkURL := &url.URL{Fragment: "f"} // 健康检查信号
356+
pingURL := &url.URL{Fragment: "i"} // PING信号
362357
for {
363358
select {
364359
case <-c.ctx.Done():
365360
return c.ctx.Err()
366361
default:
367-
if !c.mu.TryLock() {
368-
continue
369-
}
362+
c.mu.Lock()
370363

371364
// 连接池健康度检查
372365
if c.tunnelPool.ErrorCount() > c.tunnelPool.Active()/2 {
@@ -382,7 +375,7 @@ func (c *Common) healthCheck() error {
382375
} else {
383376
// 发送普通心跳包
384377
c.checkPoint = time.Now()
385-
_, err := c.tunnelTCPConn.Write(append(c.xor([]byte(checkURL.String())), '\n'))
378+
_, err := c.tunnelTCPConn.Write(append(c.xor([]byte(pingURL.String())), '\n'))
386379
if err != nil {
387380
c.mu.Unlock()
388381
return err
@@ -609,7 +602,8 @@ func (c *Common) commonUDPLoop() {
609602
}
610603

611604
// commonOnce 共用处理单个请求
612-
func (c *Common) commonOnce() {
605+
func (c *Common) commonOnce() error {
606+
pongURL := &url.URL{Fragment: "o"} // PONG信号
613607
for {
614608
// 等待连接池准备就绪
615609
if !c.tunnelPool.Ready() {
@@ -619,13 +613,12 @@ func (c *Common) commonOnce() {
619613

620614
select {
621615
case <-c.ctx.Done():
622-
return
616+
return c.ctx.Err()
623617
case signal := <-c.signalChan:
624618
// 解析信号URL
625619
signalURL, err := url.Parse(signal)
626620
if err != nil {
627-
c.logger.Error("Parse failed: %v", err)
628-
continue
621+
return err
629622
}
630623

631624
// 处理信号
@@ -640,7 +633,14 @@ func (c *Common) commonOnce() {
640633
go c.commonTCPOnce(signalURL.Host)
641634
case "2": // UDP
642635
go c.commonUDPOnce(signalURL)
643-
case "f": // 健康检查
636+
case "i": // PING
637+
c.mu.Lock()
638+
_, err := c.tunnelTCPConn.Write(append(c.xor([]byte(pongURL.String())), '\n'))
639+
c.mu.Unlock()
640+
if err != nil {
641+
return err
642+
}
643+
case "o": // PONG
644644
c.logger.Event("HEALTH_CHECKS|POOL=%v|PING=%vms", c.tunnelPool.Active(), time.Since(c.checkPoint).Milliseconds())
645645
default:
646646
// 无效信号

internal/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"syscall"
1414
"time"
1515

16+
"github.com/NodePassProject/conn"
1617
"github.com/NodePassProject/logs"
1718
"github.com/NodePassProject/pool"
1819
)
@@ -147,7 +148,7 @@ func (s *Server) tunnelHandshake() error {
147148
continue
148149
} else {
149150
s.tunnelTCPConn = tunnelTCPConn.(*net.TCPConn)
150-
s.bufReader = bufio.NewReader(s.tunnelTCPConn)
151+
s.bufReader = bufio.NewReader(&conn.TimeoutReader{Conn: s.tunnelTCPConn, Timeout: tcpReadTimeout})
151152
s.tunnelTCPConn.SetKeepAlive(true)
152153
s.tunnelTCPConn.SetKeepAlivePeriod(reportInterval)
153154

0 commit comments

Comments
 (0)