Skip to content

Commit 7de24e2

Browse files
committed
fix: StreamGunWithConn not synchronously close the incoming net.Conn
1 parent 622d99d commit 7de24e2

File tree

6 files changed

+42
-23
lines changed

6 files changed

+42
-23
lines changed

adapter/outbound/base.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
281281
epc := N.NewEnhancePacketConn(pc)
282282
if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn
283283
epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly
284-
epc = N.NewRefPacketConn(epc, a) // add ref for autoCloseProxyAdapter
284+
epc = N.NewRefPacketConn(epc, a) // add ref for autoCloseProxyAdapter
285285
}
286286
return &packetConn{epc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), parseRemoteDestination(a.Addr())}
287287
}

adapter/outbound/trojan.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
313313
}
314314

315315
if option.Network == "grpc" {
316-
dialFn := func(network, addr string) (net.Conn, error) {
316+
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
317317
var err error
318318
var cDialer C.Dialer = dialer.NewDialer(t.Base.DialOptions()...)
319319
if len(t.option.DialerProxy) > 0 {
@@ -322,7 +322,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
322322
return nil, err
323323
}
324324
}
325-
c, err := cDialer.DialContext(context.Background(), "tcp", t.addr)
325+
c, err := cDialer.DialContext(ctx, "tcp", t.addr)
326326
if err != nil {
327327
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
328328
}

adapter/outbound/vless.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ func NewVless(option VlessOption) (*Vless, error) {
571571
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
572572
}
573573
case "grpc":
574-
dialFn := func(network, addr string) (net.Conn, error) {
574+
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
575575
var err error
576576
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
577577
if len(v.option.DialerProxy) > 0 {
@@ -580,7 +580,7 @@ func NewVless(option VlessOption) (*Vless, error) {
580580
return nil, err
581581
}
582582
}
583-
c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
583+
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
584584
if err != nil {
585585
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
586586
}

adapter/outbound/vmess.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
478478
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
479479
}
480480
case "grpc":
481-
dialFn := func(network, addr string) (net.Conn, error) {
481+
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
482482
var err error
483483
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
484484
if len(v.option.DialerProxy) > 0 {
@@ -487,7 +487,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
487487
return nil, err
488488
}
489489
}
490-
c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
490+
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
491491
if err != nil {
492492
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
493493
}

transport/gun/gun.go

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ var defaultHeader = http.Header{
3636
"user-agent": []string{"grpc-go/1.36.0"},
3737
}
3838

39-
type DialFn = func(network, addr string) (net.Conn, error)
39+
type DialFn = func(ctx context.Context, network, addr string) (net.Conn, error)
4040

4141
type Conn struct {
42-
initFn func() (io.ReadCloser, netAddr, error)
43-
writer io.Writer
44-
flusher http.Flusher
42+
initFn func() (io.ReadCloser, netAddr, error)
43+
writer io.Writer
44+
closer io.Closer
4545
netAddr
4646

4747
reader io.ReadCloser
@@ -149,8 +149,8 @@ func (g *Conn) Write(b []byte) (n int, err error) {
149149
err = g.err
150150
}
151151

152-
if g.flusher != nil {
153-
g.flusher.Flush()
152+
if flusher, ok := g.writer.(http.Flusher); ok {
153+
flusher.Flush()
154154
}
155155

156156
return len(b), err
@@ -172,8 +172,8 @@ func (g *Conn) WriteBuffer(buffer *buf.Buffer) error {
172172
err = g.err
173173
}
174174

175-
if g.flusher != nil {
176-
g.flusher.Flush()
175+
if flusher, ok := g.writer.(http.Flusher); ok {
176+
flusher.Flush()
177177
}
178178

179179
return err
@@ -185,14 +185,27 @@ func (g *Conn) FrontHeadroom() int {
185185

186186
func (g *Conn) Close() error {
187187
g.close.Store(true)
188+
var errorArr []error
189+
188190
if reader := g.reader; reader != nil {
189-
reader.Close()
191+
if err := reader.Close(); err != nil {
192+
errorArr = append(errorArr, err)
193+
}
190194
}
191195

192196
if closer, ok := g.writer.(io.Closer); ok {
193-
return closer.Close()
197+
if err := closer.Close(); err != nil {
198+
errorArr = append(errorArr, err)
199+
}
194200
}
195-
return nil
201+
202+
if closer := g.closer; closer != nil {
203+
if err := closer.Close(); err != nil {
204+
errorArr = append(errorArr, err)
205+
}
206+
}
207+
208+
return errors.Join(errorArr...)
196209
}
197210

198211
func (g *Conn) SetReadDeadline(t time.Time) error { return g.SetDeadline(t) }
@@ -212,7 +225,7 @@ func (g *Conn) SetDeadline(t time.Time) error {
212225

213226
func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap {
214227
dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
215-
pconn, err := dialFn(network, addr)
228+
pconn, err := dialFn(ctx, network, addr)
216229
if err != nil {
217230
return nil, err
218231
}
@@ -327,10 +340,17 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
327340
}
328341

329342
func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config, realityConfig *tlsC.RealityConfig) (net.Conn, error) {
330-
dialFn := func(network, addr string) (net.Conn, error) {
343+
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
331344
return conn, nil
332345
}
333346

334347
transport := NewHTTP2Client(dialFn, tlsConfig, cfg.ClientFingerprint, realityConfig)
335-
return StreamGunWithTransport(transport, cfg)
348+
c, err := StreamGunWithTransport(transport, cfg)
349+
if err != nil {
350+
return nil, err
351+
}
352+
if c, ok := c.(*Conn); ok { // The incoming net.Conn should be closed synchronously with the generated gun.Conn
353+
c.closer = conn
354+
}
355+
return c, nil
336356
}

transport/gun/server.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ func NewServerHandler(options ServerOption) http.Handler {
5656
}
5757
return request.Body, nAddr, nil
5858
},
59-
writer: writer,
60-
flusher: writer.(http.Flusher),
59+
writer: writer,
6160
}
6261

6362
wrapper := &h2ConnWrapper{

0 commit comments

Comments
 (0)