Skip to content

Commit 8f01761

Browse files
authored
refactor: improve UDP session handling and error logging in commonUDPLoop and commonUDPOnce
1 parent 9b9f7f5 commit 8f01761

File tree

1 file changed

+156
-101
lines changed

1 file changed

+156
-101
lines changed

internal/common.go

Lines changed: 156 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -467,36 +467,85 @@ func (c *Common) commonUDPLoop() {
467467
case <-c.ctx.Done():
468468
return
469469
default:
470-
// 读取来自目标的UDP数据
471470
buffer := make([]byte, udpDataBufSize)
471+
472+
// 读取来自目标的UDP数据
472473
n, clientAddr, err := c.targetUDPConn.ReadFromUDP(buffer)
473474
if err != nil {
474475
continue
475476
}
476477

477478
c.logger.Debug("Target connection: %v <-> %v", c.targetUDPConn.LocalAddr(), clientAddr)
478479

479-
// 从连接池获取连接
480-
id, remoteConn := c.tunnelPool.ServerGet()
481-
if remoteConn == nil {
482-
c.logger.Error("Get failed: %v not found", id)
483-
c.tunnelPool.AddError()
484-
continue
485-
}
480+
var id string
481+
var remoteConn net.Conn
482+
sessionKey := clientAddr.String()
486483

487-
c.logger.Debug("Tunnel connection: get %v <- pool active %v", id, c.tunnelPool.Active())
484+
// 获取或创建UDP会话
485+
if session, ok := c.targetUDPSession.Load(sessionKey); ok {
486+
// 复用现有会话
487+
remoteConn = session.(net.Conn)
488+
c.logger.Debug("Using UDP session: %v <-> %v", remoteConn.LocalAddr(), remoteConn.RemoteAddr())
489+
} else {
490+
// 获取池连接
491+
id, remoteConn = c.tunnelPool.ServerGet()
492+
if remoteConn == nil {
493+
c.logger.Error("Get failed: %v not found", id)
494+
c.tunnelPool.AddError()
495+
continue
496+
}
497+
c.targetUDPSession.Store(sessionKey, remoteConn)
498+
c.logger.Debug("Tunnel connection: get %v <- pool active %v", id, c.tunnelPool.Active())
499+
c.logger.Debug("Tunnel connection: %v <-> %v", remoteConn.LocalAddr(), remoteConn.RemoteAddr())
488500

489-
c.logger.Debug("Tunnel connection: %v <-> %v", remoteConn.LocalAddr(), remoteConn.RemoteAddr())
501+
// 使用信号量限制并发数
502+
c.semaphore <- struct{}{}
490503

491-
// 使用信号量限制并发数
492-
c.semaphore <- struct{}{}
504+
go func(remoteConn net.Conn, clientAddr *net.UDPAddr, sessionKey, id string) {
505+
defer func() {
506+
c.tunnelPool.Put(id, remoteConn)
507+
c.logger.Debug("Tunnel connection: put %v -> pool active %v", id, c.tunnelPool.Active())
508+
c.targetUDPSession.Delete(sessionKey)
509+
<-c.semaphore
510+
}()
493511

494-
go func(buffer []byte, n int, clientAddr *net.UDPAddr, remoteConn net.Conn, id string) {
495-
defer func() {
496-
c.tunnelPool.Put(id, remoteConn)
497-
c.logger.Debug("Tunnel connection: put %v -> pool active %v", id, c.tunnelPool.Active())
498-
<-c.semaphore
499-
}()
512+
buffer := make([]byte, udpDataBufSize)
513+
514+
for {
515+
select {
516+
case <-c.ctx.Done():
517+
return
518+
default:
519+
// 设置TCP读取超时
520+
if err := remoteConn.SetReadDeadline(time.Now().Add(udpReadTimeout)); err != nil {
521+
c.logger.Error("SetReadDeadline failed: %v", err)
522+
return
523+
}
524+
525+
// 从池连接读取数据
526+
x, err := remoteConn.Read(buffer)
527+
if err != nil {
528+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
529+
c.logger.Debug("UDP session abort: %v", err)
530+
} else if strings.Contains(err.Error(), "use of closed network connection") {
531+
c.logger.Debug("Read closed: %v", err)
532+
} else {
533+
c.logger.Error("Read failed: %v", err)
534+
}
535+
return
536+
}
537+
538+
// 将数据写入目标UDP连接
539+
tx, err := c.targetUDPConn.WriteToUDP(buffer[:x], clientAddr)
540+
if err != nil {
541+
c.logger.Error("WriteToUDP failed: %v", err)
542+
return
543+
}
544+
// 传输完成,广播统计信息
545+
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=0|UDP_TX=%v", tx)
546+
}
547+
}
548+
}(remoteConn, clientAddr, sessionKey, id)
500549

501550
// 构建并发送启动URL到客户端
502551
launchURL := &url.URL{
@@ -507,46 +556,26 @@ func (c *Common) commonUDPLoop() {
507556
c.mu.Lock()
508557
_, err = c.tunnelTCPConn.Write(append(c.xor([]byte(launchURL.String())), '\n'))
509558
c.mu.Unlock()
510-
511559
if err != nil {
512560
c.logger.Error("Write failed: %v", err)
513-
return
561+
continue
514562
}
515563

516564
c.logger.Debug("UDP launch signal: pid %v -> %v", id, c.tunnelTCPConn.RemoteAddr())
517565
c.logger.Debug("Starting transfer: %v <-> %v", remoteConn.LocalAddr(), c.targetUDPConn.LocalAddr())
566+
}
518567

519-
// 处理UDP/TCP数据传输
520-
rx, err := remoteConn.Write(buffer[:n])
521-
if err != nil {
522-
c.logger.Error("Write failed: %v", err)
523-
return
524-
}
525-
526-
if err := remoteConn.SetReadDeadline(time.Now().Add(tcpReadTimeout)); err != nil {
527-
c.logger.Error("SetReadDeadline failed: %v", err)
528-
return
529-
}
530-
531-
x, err := remoteConn.Read(buffer)
532-
if err != nil {
533-
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
534-
c.logger.Debug("Read timeout: %v", err)
535-
} else {
536-
c.logger.Error("Read failed: %v", err)
537-
}
538-
return
539-
}
540-
541-
tx, err := c.targetUDPConn.WriteToUDP(buffer[:x], clientAddr)
542-
if err != nil {
543-
c.logger.Error("Write failed: %v", err)
544-
return
545-
}
568+
// 将原始数据写入池连接
569+
rx, err := remoteConn.Write(buffer[:n])
570+
if err != nil {
571+
c.logger.Error("Write failed: %v", err)
572+
c.targetUDPSession.Delete(sessionKey)
573+
remoteConn.Close()
574+
continue
575+
}
546576

547-
// 传输完成,广播统计信息
548-
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=%v|UDP_TX=%v", rx, tx)
549-
}(buffer, n, clientAddr, remoteConn, id)
577+
// 传输完成,广播统计信息
578+
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=%v|UDP_TX=0", rx)
550579
}
551580
}
552581
}
@@ -638,20 +667,13 @@ func (c *Common) commonTCPOnce(id string) {
638667
func (c *Common) commonUDPOnce(id string) {
639668
c.logger.Debug("UDP launch signal: pid %v <- %v", id, c.tunnelTCPConn.RemoteAddr())
640669

641-
// 从连接池获取连接
670+
// 先从池获取连接
642671
remoteConn := c.tunnelPool.ClientGet(id)
643672
if remoteConn == nil {
644673
c.logger.Error("Get failed: %v not found", id)
645674
return
646675
}
647-
648676
c.logger.Debug("Tunnel connection: get %v <- pool active %v", id, c.tunnelPool.Active())
649-
650-
defer func() {
651-
c.tunnelPool.Put(id, remoteConn)
652-
c.logger.Debug("Tunnel connection: put %v -> pool active %v", id, c.tunnelPool.Active())
653-
}()
654-
655677
c.logger.Debug("Tunnel connection: %v <-> %v", remoteConn.LocalAddr(), remoteConn.RemoteAddr())
656678

657679
// 连接到目标UDP地址
@@ -660,54 +682,86 @@ func (c *Common) commonUDPOnce(id string) {
660682
c.logger.Error("Dial failed: %v", err)
661683
return
662684
}
663-
664-
defer func() {
665-
if targetConn != nil {
666-
targetConn.Close()
667-
}
668-
}()
669-
670-
c.targetUDPConn = targetConn.(*net.UDPConn)
671685
c.logger.Debug("Target connection: %v <-> %v", targetConn.LocalAddr(), targetConn.RemoteAddr())
672-
c.logger.Debug("Starting transfer: %v <-> %v", remoteConn.LocalAddr(), targetConn.LocalAddr())
673686

674-
// 处理UDP/TCP数据传输
675-
buffer := make([]byte, udpDataBufSize)
676-
n, err := remoteConn.Read(buffer)
677-
if err != nil {
678-
c.logger.Error("Read failed: %v", err)
679-
return
680-
}
687+
done := make(chan struct{}, 2)
681688

682-
rx, err := c.targetUDPConn.Write(buffer[:n])
683-
if err != nil {
684-
c.logger.Error("Write failed: %v", err)
685-
return
686-
}
687-
688-
if err := c.targetUDPConn.SetReadDeadline(time.Now().Add(udpReadTimeout)); err != nil {
689-
c.logger.Error("SetReadDeadline failed: %v", err)
690-
return
691-
}
692-
693-
x, err := c.targetUDPConn.Read(buffer)
694-
if err != nil {
695-
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
696-
c.logger.Debug("Read timeout: %v", err)
697-
} else {
698-
c.logger.Error("Read failed: %v", err)
689+
go func() {
690+
defer func() {
691+
done <- struct{}{}
692+
}()
693+
buffer := make([]byte, udpDataBufSize)
694+
for {
695+
select {
696+
case <-c.ctx.Done():
697+
return
698+
default:
699+
if err := remoteConn.SetReadDeadline(time.Now().Add(udpReadTimeout)); err != nil {
700+
c.logger.Error("SetReadDeadline failed: %v", err)
701+
return
702+
}
703+
x, err := remoteConn.Read(buffer)
704+
if err != nil {
705+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
706+
c.logger.Debug("Read timeout: %v", err)
707+
} else if strings.Contains(err.Error(), "use of closed network connection") {
708+
c.logger.Debug("Read closed: %v", err)
709+
} else {
710+
c.logger.Error("Read failed: %v", err)
711+
}
712+
return
713+
}
714+
rx, err := targetConn.Write(buffer[:x])
715+
if err != nil {
716+
c.logger.Error("Write failed: %v", err)
717+
return
718+
}
719+
// 传输完成,广播统计信息
720+
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=%v|UDP_TX=0", rx)
721+
}
699722
}
700-
return
701-
}
723+
}()
702724

703-
tx, err := remoteConn.Write(buffer[:x])
704-
if err != nil {
705-
c.logger.Error("Write failed: %v", err)
706-
return
707-
}
725+
go func() {
726+
defer func() {
727+
done <- struct{}{}
728+
}()
729+
buffer := make([]byte, udpDataBufSize)
730+
for {
731+
select {
732+
case <-c.ctx.Done():
733+
return
734+
default:
735+
if err := targetConn.SetReadDeadline(time.Now().Add(udpReadTimeout)); err != nil {
736+
c.logger.Error("SetReadDeadline failed: %v", err)
737+
return
738+
}
739+
x, err := targetConn.Read(buffer)
740+
if err != nil {
741+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
742+
c.logger.Debug("Read timeout: %v", err)
743+
} else if strings.Contains(err.Error(), "use of closed network connection") {
744+
c.logger.Debug("Read closed: %v", err)
745+
} else {
746+
c.logger.Error("Read failed: %v", err)
747+
}
748+
return
749+
}
750+
tx, err := remoteConn.Write(buffer[:x])
751+
if err != nil {
752+
c.logger.Error("Write failed: %v", err)
753+
return
754+
}
755+
// 传输完成,广播统计信息
756+
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=0|UDP_TX=%v", tx)
757+
}
758+
}
759+
}()
708760

709-
// 传输完成,广播统计信息
710-
c.logger.Event("Transfer complete: TRAFFIC_STATS|TCP_RX=0|TCP_TX=0|UDP_RX=%v|UDP_TX=%v", rx, tx)
761+
<-done
762+
targetConn.Close()
763+
c.tunnelPool.Put(id, remoteConn)
764+
c.logger.Debug("Tunnel connection: put %v -> pool active %v", id, c.tunnelPool.Active())
711765
}
712766

713767
// singleLoop 单端转发处理循环
@@ -849,9 +903,10 @@ func (c *Common) singleUDPLoop() error {
849903
// 从UDP读取响应
850904
n, err := targetConn.Read(buffer)
851905
if err != nil {
852-
// 检查是否为超时错误
853906
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
854907
c.logger.Debug("UDP session abort: %v", err)
908+
} else if strings.Contains(err.Error(), "use of closed network connection") {
909+
c.logger.Debug("Read closed: %v", err)
855910
} else {
856911
c.logger.Error("Read failed: %v", err)
857912
}

0 commit comments

Comments
 (0)