Skip to content

Commit 4aa2d26

Browse files
authored
refactor: enhance error handling with contextual messages in client and server methods
1 parent cc59f0e commit 4aa2d26

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

internal/client.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"bufio"
66
"bytes"
77
"context"
8+
"fmt"
89
"io"
910
"net"
1011
"net/url"
@@ -95,7 +96,7 @@ func (c *Client) start() error {
9596
switch c.runMode {
9697
case "1": // 单端模式
9798
if err := c.initTunnelListener(); err != nil {
98-
return err
99+
return fmt.Errorf("start: initTunnelListener failed: %w", err)
99100
}
100101
return c.singleStart()
101102
case "2": // 双端模式
@@ -113,14 +114,17 @@ func (c *Client) start() error {
113114

114115
// singleStart 启动单端转发模式
115116
func (c *Client) singleStart() error {
116-
return c.singleControl()
117+
if err := c.singleControl(); err != nil {
118+
return fmt.Errorf("singleStart: singleControl failed: %w", err)
119+
}
120+
return nil
117121
}
118122

119123
// commonStart 启动双端握手模式
120124
func (c *Client) commonStart() error {
121125
// 与隧道服务端进行握手
122126
if err := c.tunnelHandshake(); err != nil {
123-
return err
127+
return fmt.Errorf("commonStart: tunnelHandshake failed: %w", err)
124128
}
125129

126130
// 初始化连接池
@@ -140,19 +144,22 @@ func (c *Client) commonStart() error {
140144
if c.dataFlow == "+" {
141145
// 初始化目标监听器
142146
if err := c.initTargetListener(); err != nil {
143-
return err
147+
return fmt.Errorf("commonStart: initTargetListener failed: %w", err)
144148
}
145149
go c.commonLoop()
146150
}
147-
return c.commonControl()
151+
if err := c.commonControl(); err != nil {
152+
return fmt.Errorf("commonStart: commonControl failed: %w", err)
153+
}
154+
return nil
148155
}
149156

150157
// tunnelHandshake 与隧道服务端进行握手
151158
func (c *Client) tunnelHandshake() error {
152159
// 建立隧道TCP连接
153160
tunnelTCPConn, err := net.DialTimeout("tcp", c.tunnelTCPAddr.String(), tcpDialTimeout)
154161
if err != nil {
155-
return err
162+
return fmt.Errorf("tunnelHandshake: dialTimeout failed: %w", err)
156163
}
157164

158165
c.tunnelTCPConn = tunnelTCPConn.(*net.TCPConn)
@@ -163,27 +170,27 @@ func (c *Client) tunnelHandshake() error {
163170
// 发送隧道密钥
164171
_, err = c.tunnelTCPConn.Write(append(c.xor([]byte(c.tunnelKey)), '\n'))
165172
if err != nil {
166-
return err
173+
return fmt.Errorf("tunnelHandshake: write tunnel key failed: %w", err)
167174
}
168175

169176
// 读取隧道URL
170177
rawTunnelURL, err := c.bufReader.ReadBytes('\n')
171178
if err != nil {
172-
return err
179+
return fmt.Errorf("tunnelHandshake: readBytes failed: %w", err)
173180
}
174181

175182
// 解析隧道URL
176183
tunnelURL, err := url.Parse(string(c.xor(bytes.TrimSuffix(rawTunnelURL, []byte{'\n'}))))
177184
if err != nil {
178-
return err
185+
return fmt.Errorf("tunnelHandshake: parse tunnel URL failed: %w", err)
179186
}
180187

181188
// 更新客户端配置
182189
if tunnelURL.Host == "" || tunnelURL.Path == "" || tunnelURL.Fragment == "" {
183190
return net.UnknownNetworkError(tunnelURL.String())
184191
}
185192
if max, err := strconv.Atoi(tunnelURL.Host); err != nil {
186-
return err
193+
return fmt.Errorf("tunnelHandshake: parse max pool capacity failed: %w", err)
187194
} else {
188195
c.maxPoolCapacity = max
189196
}

internal/server.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"bytes"
77
"context"
88
"crypto/tls"
9+
"fmt"
910
"io"
1011
"net"
1112
"net/url"
@@ -95,14 +96,14 @@ func (s *Server) start() error {
9596

9697
// 初始化隧道监听器
9798
if err := s.initTunnelListener(); err != nil {
98-
return err
99+
return fmt.Errorf("start: initTunnelListener failed: %w", err)
99100
}
100101

101102
// 运行模式判断
102103
switch s.runMode {
103104
case "1": // 反向模式
104105
if err := s.initTargetListener(); err != nil {
105-
return err
106+
return fmt.Errorf("start: initTargetListener failed: %w", err)
106107
}
107108
s.dataFlow = "-"
108109
case "2": // 正向模式
@@ -119,7 +120,7 @@ func (s *Server) start() error {
119120

120121
// 与客户端进行握手
121122
if err := s.tunnelHandshake(); err != nil {
122-
return err
123+
return fmt.Errorf("start: tunnelHandshake failed: %w", err)
123124
}
124125

125126
// 握手之后把UDP监听关掉
@@ -139,23 +140,26 @@ func (s *Server) start() error {
139140
if s.dataFlow == "-" {
140141
go s.commonLoop()
141142
}
142-
return s.commonControl()
143+
if err := s.commonControl(); err != nil {
144+
return fmt.Errorf("start: commonControl failed: %w", err)
145+
}
146+
return nil
143147
}
144148

145149
// tunnelHandshake 与客户端进行握手
146150
func (s *Server) tunnelHandshake() error {
147151
// 接受隧道连接
148152
for {
149153
if s.ctx.Err() != nil {
150-
return s.ctx.Err()
154+
return fmt.Errorf("tunnelHandshake: context error: %w", s.ctx.Err())
151155
}
152156

153157
tunnelTCPConn, err := s.tunnelListener.Accept()
154158
if err != nil {
155-
s.logger.Error("Accept error: %v", err)
159+
s.logger.Error("tunnelHandshake: accept error: %v", err)
156160
select {
157161
case <-s.ctx.Done():
158-
return s.ctx.Err()
162+
return fmt.Errorf("tunnelHandshake: context error: %w", s.ctx.Err())
159163
case <-time.After(serviceCooldown):
160164
}
161165
continue
@@ -166,11 +170,11 @@ func (s *Server) tunnelHandshake() error {
166170
bufReader := bufio.NewReader(tunnelTCPConn)
167171
rawTunnelKey, err := bufReader.ReadString('\n')
168172
if err != nil {
169-
s.logger.Warn("Handshake timeout: %v", tunnelTCPConn.RemoteAddr())
173+
s.logger.Warn("tunnelHandshake: handshake timeout: %v", tunnelTCPConn.RemoteAddr())
170174
tunnelTCPConn.Close()
171175
select {
172176
case <-s.ctx.Done():
173-
return s.ctx.Err()
177+
return fmt.Errorf("tunnelHandshake: context error: %w", s.ctx.Err())
174178
case <-time.After(serviceCooldown):
175179
}
176180
continue
@@ -180,11 +184,11 @@ func (s *Server) tunnelHandshake() error {
180184
tunnelKey := string(s.xor(bytes.TrimSuffix([]byte(rawTunnelKey), []byte{'\n'})))
181185

182186
if tunnelKey != s.tunnelKey {
183-
s.logger.Warn("Access denied: %v", tunnelTCPConn.RemoteAddr())
187+
s.logger.Warn("tunnelHandshake: access denied: %v", tunnelTCPConn.RemoteAddr())
184188
tunnelTCPConn.Close()
185189
select {
186190
case <-s.ctx.Done():
187-
return s.ctx.Err()
191+
return fmt.Errorf("tunnelHandshake: context error: %w", s.ctx.Err())
188192
case <-time.After(serviceCooldown):
189193
}
190194
continue
@@ -210,7 +214,7 @@ func (s *Server) tunnelHandshake() error {
210214

211215
_, err := s.tunnelTCPConn.Write(append(s.xor([]byte(tunnelURL.String())), '\n'))
212216
if err != nil {
213-
return err
217+
return fmt.Errorf("tunnelHandshake: write tunnel config failed: %w", err)
214218
}
215219

216220
s.logger.Info("Tunnel signal -> : %v -> %v", tunnelURL.String(), s.tunnelTCPConn.RemoteAddr())

0 commit comments

Comments
 (0)