Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions listener/anytls/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ func (l *Listener) HandleConn(conn net.Conn, h *sing.ListenerHandler) {
return
}

// It seems that mihomo does not implement a connection error reporting mechanism, so we report success directly.
err = stream.HandshakeSuccess()
if err != nil {
return
}

h.NewConnection(ctx, stream, M.Metadata{
Source: M.SocksaddrFromNet(conn.RemoteAddr()),
Destination: destination,
Expand Down
6 changes: 3 additions & 3 deletions transport/anytls/padding/padding.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
const CheckMark = -1

var DefaultPaddingScheme = []byte(`stop=8
0=34-120
0=30-30
1=100-400
2=400-500,c,500-1000,c,400-500,c,500-1000,c,500-1000,c,400-500
3=500-1000
2=400-500,c,500-1000,c,500-1000,c,500-1000,c,500-1000
3=9-9,500-1000
4=500-1000
5=500-1000
6=500-1000
Expand Down
10 changes: 3 additions & 7 deletions transport/anytls/session/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
}

stream.dieHook = func() {
if session.IsClosed() {
if session.dieHook != nil {
session.dieHook()
}
} else {
if !session.IsClosed() {
select {
case <-c.die.Done():
// Now client has been closed
Expand Down Expand Up @@ -154,10 +150,10 @@ func (c *Client) Close() error {

c.sessionsLock.Lock()
sessionToClose := make([]*Session, 0, len(c.sessions))
for seq, session := range c.sessions {
for _, session := range c.sessions {
sessionToClose = append(sessionToClose, session)
delete(c.sessions, seq)
}
c.sessions = make(map[uint64]*Session)
c.sessionsLock.Unlock()

for _, session := range sessionToClose {
Expand Down
7 changes: 6 additions & 1 deletion transport/anytls/session/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ const ( // cmds
cmdSYN = 1 // stream open
cmdPSH = 2 // data push
cmdFIN = 3 // stream close, a.k.a EOF mark
cmdSettings = 4 // Settings
cmdSettings = 4 // Settings (Client send to Server)
cmdAlert = 5 // Alert
cmdUpdatePaddingScheme = 6 // update padding scheme
// Since version 2
cmdSYNACK = 7 // Server reports to the client that the stream has been opened
cmdHeartRequest = 8 // Keep alive command
cmdHeartResponse = 9 // Keep alive command
cmdServerSettings = 10 // Settings (Server send to client)
)

const (
Expand Down
104 changes: 93 additions & 11 deletions transport/anytls/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package session
import (
"crypto/md5"
"encoding/binary"
"fmt"
"io"
"net"
"runtime/debug"
"strconv"
"sync"
"time"

Expand All @@ -30,11 +32,16 @@ type Session struct {
die chan struct{}
dieHook func()

synDone func()
synDoneLock sync.Mutex

// pool
seq uint64
idleSince time.Time
padding *atomic.TypedValue[*padding.PaddingFactory]

peerVersion byte

// client
isClient bool
sendPadding bool
Expand Down Expand Up @@ -76,7 +83,7 @@ func (s *Session) Run() {
}

settings := util.StringMap{
"v": "1",
"v": "2",
"client": "mihomo/" + constant.Version,
"padding-md5": s.padding.Load().Md5,
}
Expand Down Expand Up @@ -105,15 +112,16 @@ func (s *Session) Close() error {
close(s.die)
once = true
})

if once {
if s.dieHook != nil {
s.dieHook()
s.dieHook = nil
}
s.streamLock.Lock()
for k := range s.streams {
s.streams[k].sessionClose()
for _, stream := range s.streams {
stream.Close()
}
s.streams = make(map[uint32]*Stream)
s.streamLock.Unlock()
return s.conn.Close()
} else {
Expand All @@ -132,6 +140,17 @@ func (s *Session) OpenStream() (*Stream, error) {

//logrus.Debugln("stream open", sid, s.streams)

if sid >= 2 && s.peerVersion >= 2 {
s.synDoneLock.Lock()
if s.synDone != nil {
s.synDone()
}
s.synDone = util.NewDeadlineWatcher(time.Second*3, func() {
s.Close()
})
s.synDoneLock.Unlock()
}

if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
return nil, err
}
Expand Down Expand Up @@ -195,13 +214,37 @@ func (s *Session) recvLoop() error {
if _, ok := s.streams[sid]; !ok {
stream := newStream(sid, s)
s.streams[sid] = stream
if s.onNewStream != nil {
go s.onNewStream(stream)
} else {
go s.Close()
}
go func() {
if s.onNewStream != nil {
s.onNewStream(stream)
} else {
stream.Close()
}
}()
}
s.streamLock.Unlock()
case cmdSYNACK: // should be client only
s.synDoneLock.Lock()
if s.synDone != nil {
s.synDone()
s.synDone = nil
}
s.synDoneLock.Unlock()
if hdr.Length() > 0 {
buffer := pool.Get(int(hdr.Length()))
if _, err := io.ReadFull(s.conn, buffer); err != nil {
pool.Put(buffer)
return err
}
// report error
s.streamLock.RLock()
stream, ok := s.streams[sid]
s.streamLock.RUnlock()
if ok {
stream.CloseWithError(fmt.Errorf("remote: %s", string(buffer)))
}
pool.Put(buffer)
}
case cmdFIN:
s.streamLock.RLock()
stream, ok := s.streams[sid]
Expand Down Expand Up @@ -240,6 +283,20 @@ func (s *Session) recvLoop() error {
return err
}
}
// check client's version
if v, err := strconv.Atoi(m["v"]); err == nil && v >= 2 {
s.peerVersion = byte(v)
// send cmdServerSettings
f := newFrame(cmdServerSettings, 0)
f.data = util.StringMap{
"v": "2",
}.ToBytes()
_, err = s.writeFrame(f)
if err != nil {
pool.Put(buffer)
return err
}
}
}
pool.Put(buffer)
}
Expand All @@ -265,12 +322,35 @@ func (s *Session) recvLoop() error {
}
if s.isClient {
if padding.UpdatePaddingScheme(rawScheme, s.padding) {
log.Infoln("[Update padding succeed] %x\n", md5.Sum(rawScheme))
log.Debugln("[Update padding succeed] %x\n", md5.Sum(rawScheme))
} else {
log.Warnln("[Update padding failed] %x\n", md5.Sum(rawScheme))
}
}
}
case cmdHeartRequest:
if _, err := s.writeFrame(newFrame(cmdHeartResponse, sid)); err != nil {
return err
}
case cmdHeartResponse:
// Active keepalive checking is not implemented yet
break
case cmdServerSettings:
if hdr.Length() > 0 {
buffer := pool.Get(int(hdr.Length()))
if _, err := io.ReadFull(s.conn, buffer); err != nil {
pool.Put(buffer)
return err
}
if s.isClient {
// check server's version
m := util.StringMapFromBytes(buffer)
if v, err := strconv.Atoi(m["v"]); err == nil {
s.peerVersion = byte(v)
}
}
pool.Put(buffer)
}
default:
// I don't know what command it is (can't have data)
}
Expand All @@ -280,8 +360,10 @@ func (s *Session) recvLoop() error {
}
}

// notify the session that a stream has closed
func (s *Session) streamClosed(sid uint32) error {
if s.IsClosed() {
return io.ErrClosedPipe
}
_, err := s.writeFrame(newFrame(cmdFIN, sid))
s.streamLock.Lock()
delete(s.streams, sid)
Expand Down
62 changes: 51 additions & 11 deletions transport/anytls/session/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Stream struct {

dieOnce sync.Once
dieHook func()
dieErr error

reportOnce sync.Once
}

// newStream initiates a Stream struct
Expand All @@ -36,7 +39,11 @@ func newStream(id uint32, sess *Session) *Stream {

// Read implements net.Conn
func (s *Stream) Read(b []byte) (n int, err error) {
return s.pipeR.Read(b)
n, err = s.pipeR.Read(b)
if s.dieErr != nil {
err = s.dieErr
}
return
}

// Write implements net.Conn
Expand All @@ -54,25 +61,28 @@ func (s *Stream) Write(b []byte) (n int, err error) {

// Close implements net.Conn
func (s *Stream) Close() error {
if s.sessionClose() {
// notify remote
return s.sess.streamClosed(s.id)
} else {
return io.ErrClosedPipe
}
return s.CloseWithError(io.ErrClosedPipe)
}

// sessionClose close stream from session side, do not notify remote
func (s *Stream) sessionClose() (once bool) {
func (s *Stream) CloseWithError(err error) error {
// if err != io.ErrClosedPipe {
// logrus.Debugln(err)
// }
var once bool
s.dieOnce.Do(func() {
s.dieErr = err
s.pipeR.Close()
once = true
})
if once {
if s.dieHook != nil {
s.dieHook()
s.dieHook = nil
}
})
return
return s.sess.streamClosed(s.id)
} else {
return s.dieErr
}
}

func (s *Stream) SetReadDeadline(t time.Time) error {
Expand Down Expand Up @@ -108,3 +118,33 @@ func (s *Stream) RemoteAddr() net.Addr {
}
return nil
}

// HandshakeFailure should be called when Server fail to create outbound proxy
func (s *Stream) HandshakeFailure(err error) error {
var once bool
s.reportOnce.Do(func() {
once = true
})
if once && err != nil && s.sess.peerVersion >= 2 {
f := newFrame(cmdSYNACK, s.id)
f.data = []byte(err.Error())
if _, err := s.sess.writeFrame(f); err != nil {
return err
}
}
return nil
}

// HandshakeSuccess should be called when Server success to create outbound proxy
func (s *Stream) HandshakeSuccess() error {
var once bool
s.reportOnce.Do(func() {
once = true
})
if once && s.sess.peerVersion >= 2 {
if _, err := s.sess.writeFrame(newFrame(cmdSYNACK, s.id)); err != nil {
return err
}
}
return nil
}
25 changes: 25 additions & 0 deletions transport/anytls/util/deadline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package util

import (
"sync"
"time"
)

func NewDeadlineWatcher(ddl time.Duration, timeOut func()) (done func()) {
t := time.NewTimer(ddl)
closeCh := make(chan struct{})
go func() {
defer t.Stop()
select {
case <-closeCh:
case <-t.C:
timeOut()
}
}()
var once sync.Once
return func() {
once.Do(func() {
close(closeCh)
})
}
}