Skip to content

Commit 5260464

Browse files
authored
feat: implement dialWithRotation for TCP and UDP connections to enhance fault tolerance and load balancing
1 parent 98a5274 commit 5260464

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

internal/common.go

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,45 @@ func (c *Common) nextTargetAddr() {
283283
}
284284

285285
// 轮询下一个地址
286-
newIdx := max((atomic.AddInt32(&c.targetIdx, 1)-1)%int32(len(c.targetTCPAddrs)), 0)
286+
newIdx := atomic.AddInt32(&c.targetIdx, 1) % int32(len(c.targetTCPAddrs))
287287
c.targetTCPAddr = c.targetTCPAddrs[newIdx]
288288
c.targetUDPAddr = c.targetUDPAddrs[newIdx]
289289
}
290290

291+
// dialWithRotation 轮询拨号到目标地址组
292+
func (c *Common) dialWithRotation(network string, timeout time.Duration) (net.Conn, error) {
293+
maxRetries := max(len(c.targetTCPAddrs), 1)
294+
var lastErr error
295+
296+
for i := range maxRetries {
297+
// 故障转移
298+
if i > 0 {
299+
c.nextTargetAddr()
300+
}
301+
302+
var targetAddr string
303+
if network == "tcp" {
304+
targetAddr = c.targetTCPAddr.String()
305+
} else {
306+
targetAddr = c.targetUDPAddr.String()
307+
}
308+
309+
// 负载均衡
310+
conn, err := net.DialTimeout(network, targetAddr, timeout)
311+
if err == nil {
312+
c.nextTargetAddr()
313+
return conn, nil
314+
}
315+
316+
lastErr = err
317+
}
318+
319+
if maxRetries > 1 {
320+
return nil, fmt.Errorf("dialWithRotation: all %d targets failed: %w", maxRetries, lastErr)
321+
}
322+
return nil, lastErr
323+
}
324+
291325
// getTunnelKey 从URL中获取隧道密钥
292326
func (c *Common) getTunnelKey(parsedURL *url.URL) {
293327
if key := parsedURL.User.Username(); key != "" {
@@ -1090,13 +1124,10 @@ func (c *Common) commonTCPOnce(signalURL *url.URL) {
10901124

10911125
defer c.releaseSlot(false)
10921126

1093-
// 轮询下一个目标地址
1094-
c.nextTargetAddr()
1095-
10961127
// 连接到目标TCP地址
1097-
targetConn, err := net.DialTimeout("tcp", c.targetTCPAddr.String(), tcpDialTimeout)
1128+
targetConn, err := c.dialWithRotation("tcp", tcpDialTimeout)
10981129
if err != nil {
1099-
c.logger.Error("commonTCPOnce: dialTimeout failed: %v", err)
1130+
c.logger.Error("commonTCPOnce: dialWithRotation failed: %v", err)
11001131
return
11011132
}
11021133

@@ -1179,12 +1210,10 @@ func (c *Common) commonUDPOnce(signalURL *url.URL) {
11791210
return
11801211
}
11811212

1182-
// 轮询下一个目标地址
1183-
c.nextTargetAddr()
1184-
1185-
newSession, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
1213+
// 创建新的会话
1214+
newSession, err := c.dialWithRotation("udp", udpDialTimeout)
11861215
if err != nil {
1187-
c.logger.Error("commonUDPOnce: dialTimeout failed: %v", err)
1216+
c.logger.Error("commonUDPOnce: dialWithRotation failed: %v", err)
11881217
c.releaseSlot(true)
11891218
return
11901219
}
@@ -1363,13 +1392,10 @@ func (c *Common) singleTCPLoop() error {
13631392

13641393
defer c.releaseSlot(false)
13651394

1366-
// 轮询下一个目标地址
1367-
c.nextTargetAddr()
1368-
13691395
// 尝试建立目标连接
1370-
targetConn, err := net.DialTimeout("tcp", c.targetTCPAddr.String(), tcpDialTimeout)
1396+
targetConn, err := c.dialWithRotation("tcp", tcpDialTimeout)
13711397
if err != nil {
1372-
c.logger.Error("singleTCPLoop: dialTimeout failed: %v", err)
1398+
c.logger.Error("singleTCPLoop: dialWithRotation failed: %v", err)
13731399
return
13741400
}
13751401

@@ -1445,13 +1471,10 @@ func (c *Common) singleUDPLoop() error {
14451471
continue
14461472
}
14471473

1448-
// 轮询下一个目标地址
1449-
c.nextTargetAddr()
1450-
14511474
// 创建新的会话
1452-
newSession, err := net.DialTimeout("udp", c.targetUDPAddr.String(), udpDialTimeout)
1475+
newSession, err := c.dialWithRotation("udp", udpDialTimeout)
14531476
if err != nil {
1454-
c.logger.Error("singleUDPLoop: dialTimeout failed: %v", err)
1477+
c.logger.Error("singleUDPLoop: dialWithRotation failed: %v", err)
14551478
c.releaseSlot(true)
14561479
c.putUDPBuffer(buffer)
14571480
continue

0 commit comments

Comments
 (0)