Skip to content

Commit 186e27d

Browse files
authored
fix: update UDP connection handling to use StatConn for improved monitoring and reliability
1 parent f1407eb commit 186e27d

File tree

1 file changed

+20
-29
lines changed

1 file changed

+20
-29
lines changed

internal/common.go

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ type Common struct {
3737
targetListener *net.TCPListener // 目标监听器
3838
tunnelListener net.Listener // 隧道监听器
3939
tunnelTCPConn *net.TCPConn // 隧道TCP连接
40-
tunnelUDPConn *net.UDPConn // 隧道UDP连接
40+
tunnelUDPConn *conn.StatConn // 隧道UDP连接
4141
targetTCPConn *net.TCPConn // 目标TCP连接
42-
targetUDPConn *net.UDPConn // 目标UDP连接
42+
targetUDPConn *conn.StatConn // 目标UDP连接
4343
targetUDPSession sync.Map // 目标UDP会话
4444
tunnelPool *pool.Pool // 隧道连接池
4545
minPoolCapacity int // 最小池容量
@@ -358,22 +358,16 @@ func (c *Common) initTunnelListener() error {
358358
// 初始化隧道TCP监听器
359359
tunnelListener, err := net.ListenTCP("tcp", c.tunnelTCPAddr)
360360
if err != nil {
361-
if tunnelListener != nil {
362-
tunnelListener.Close()
363-
}
364361
return fmt.Errorf("initTunnelListener: listenTCP failed: %w", err)
365362
}
366363
c.tunnelListener = tunnelListener
367364

368365
// 初始化隧道UDP监听器
369366
tunnelUDPConn, err := net.ListenUDP("udp", c.tunnelUDPAddr)
370367
if err != nil {
371-
if tunnelUDPConn != nil {
372-
tunnelUDPConn.Close()
373-
}
374368
return fmt.Errorf("initTunnelListener: listenUDP failed: %w", err)
375369
}
376-
c.tunnelUDPConn = tunnelUDPConn
370+
c.tunnelUDPConn = &conn.StatConn{Conn: tunnelUDPConn, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
377371

378372
return nil
379373
}
@@ -387,21 +381,16 @@ func (c *Common) initTargetListener() error {
387381
// 初始化目标TCP监听器
388382
targetListener, err := net.ListenTCP("tcp", c.targetTCPAddr)
389383
if err != nil {
390-
if targetListener != nil {
391-
targetListener.Close()
392-
}
393384
return fmt.Errorf("initTargetListener: listenTCP failed: %w", err)
394385
}
395386
c.targetListener = targetListener
396387

397388
// 初始化目标UDP监听器
398-
var targetUDPConn net.Conn
399-
targetUDPConn, err = net.ListenUDP("udp", c.targetUDPAddr)
389+
targetUDPConn, err := net.ListenUDP("udp", c.targetUDPAddr)
400390
if err != nil {
401391
return fmt.Errorf("initTargetListener: listenUDP failed: %w", err)
402392
}
403-
targetUDPConn = &conn.StatConn{Conn: targetUDPConn, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
404-
c.targetUDPConn = targetUDPConn.(*net.UDPConn)
393+
c.targetUDPConn = &conn.StatConn{Conn: targetUDPConn, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
405394

406395
return nil
407396
}
@@ -867,7 +856,13 @@ func (c *Common) commonOnce() error {
867856
// 解析信号URL
868857
signalURL, err := url.Parse(signal)
869858
if err != nil {
870-
return fmt.Errorf("commonOnce: parse signal URL failed: %w", err)
859+
c.logger.Error("commonOnce: parse signal failed: %v", err)
860+
select {
861+
case <-c.ctx.Done():
862+
return fmt.Errorf("commonOnce: context error: %w", c.ctx.Err())
863+
case <-time.After(50 * time.Millisecond):
864+
}
865+
continue
871866
}
872867

873868
// 处理信号
@@ -1007,19 +1002,17 @@ func (c *Common) commonUDPOnce(signalURL *url.URL) {
10071002

10081003
// 获取或创建目标UDP会话
10091004
if session, ok := c.targetUDPSession.Load(sessionKey); ok {
1010-
targetConn = session.(*net.UDPConn)
1005+
targetConn = session.(net.Conn)
10111006
c.logger.Debug("Using UDP session: %v <-> %v", targetConn.LocalAddr(), targetConn.RemoteAddr())
10121007
} else {
10131008
// 创建新的会话
1014-
session, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
1009+
newSession, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
10151010
if err != nil {
10161011
c.logger.Error("commonUDPOnce: dialTimeout failed: %v", err)
10171012
return
10181013
}
1019-
c.targetUDPSession.Store(sessionKey, session)
1020-
1021-
targetConn = &conn.StatConn{Conn: targetConn, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
1022-
targetConn = session.(*net.UDPConn)
1014+
targetConn = &conn.StatConn{Conn: newSession, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
1015+
c.targetUDPSession.Store(sessionKey, targetConn)
10231016
c.logger.Debug("Target connection: %v <-> %v", targetConn.LocalAddr(), targetConn.RemoteAddr())
10241017
}
10251018
c.logger.Debug("Starting transfer: %v <-> %v", remoteConn.LocalAddr(), targetConn.LocalAddr())
@@ -1265,7 +1258,7 @@ func (c *Common) singleUDPLoop() error {
12651258
// 获取或创建目标UDP会话
12661259
if session, ok := c.targetUDPSession.Load(sessionKey); ok {
12671260
// 复用现有会话
1268-
targetConn = session.(*net.UDPConn)
1261+
targetConn = session.(net.Conn)
12691262
c.logger.Debug("Using UDP session: %v <-> %v", targetConn.LocalAddr(), targetConn.RemoteAddr())
12701263
} else {
12711264
// 尝试获取UDP连接槽位
@@ -1276,17 +1269,15 @@ func (c *Common) singleUDPLoop() error {
12761269
}
12771270

12781271
// 创建新的会话
1279-
session, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
1272+
newSession, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
12801273
if err != nil {
12811274
c.logger.Error("singleUDPLoop: dialTimeout failed: %v", err)
12821275
c.releaseSlot(true)
12831276
putUDPBuffer(buffer)
12841277
continue
12851278
}
1286-
c.targetUDPSession.Store(sessionKey, session)
1287-
1288-
targetConn = &conn.StatConn{Conn: targetConn, RX: &c.udpRX, TX: &c.udpTX, Rate: c.rateLimiter}
1289-
targetConn = session.(*net.UDPConn)
1279+
targetConn = newSession
1280+
c.targetUDPSession.Store(sessionKey, newSession)
12901281
c.logger.Debug("Target connection: %v <-> %v", targetConn.LocalAddr(), targetConn.RemoteAddr())
12911282

12921283
go func(targetConn net.Conn, clientAddr *net.UDPAddr, sessionKey string) {

0 commit comments

Comments
 (0)