Skip to content

Commit bd4508a

Browse files
committed
Improve splice error handling
1 parent 33f2950 commit bd4508a

File tree

5 files changed

+62
-50
lines changed

5 files changed

+62
-50
lines changed

common/bufio/copy_direct.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,12 @@ func copyDirect(source io.Reader, destination io.Writer, readCounters []N.CountF
1313
if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) {
1414
return
1515
}
16-
sourceConn := N.SyscallConnForRead(source)
17-
destinationConn := N.SyscallConnForWrite(destination)
16+
sourceReader, sourceConn := N.SyscallConnForRead(source)
17+
destinationWriter, destinationConn := N.SyscallConnForWrite(destination)
1818
if sourceConn == nil || destinationConn == nil {
1919
return
2020
}
21-
rawSource, err := sourceConn.SyscallConn()
22-
if err != nil {
23-
return
24-
}
25-
rawDestination, err := destinationConn.SyscallConn()
26-
if err != nil {
27-
return
28-
}
29-
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
21+
handed, n, err = splice(sourceConn, sourceReader, destinationConn, destinationWriter, readCounters, writeCounters)
3022
return
3123
}
3224

common/bufio/splice_linux.go

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111

1212
const maxSpliceSize = 1 << 20
1313

14-
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
14+
func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
1515
handed = true
1616
var pipeFDs [2]int
1717
err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK)
@@ -20,12 +20,14 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []
2020
}
2121
defer unix.Close(pipeFDs[0])
2222
defer unix.Close(pipeFDs[1])
23-
2423
_, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize)
25-
var readN int
26-
var readErr error
27-
var writeSize int
28-
var writeErr error
24+
var (
25+
readN int
26+
readErr error
27+
writeSize int
28+
writeErr error
29+
notFirstTime bool
30+
)
2931
readFunc := func(fd uintptr) (done bool) {
3032
p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK)
3133
readN = int(p0)
@@ -46,34 +48,49 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []
4648
}
4749
for {
4850
err = source.Read(readFunc)
49-
if err != nil {
50-
readErr = err
51-
}
5251
if readErr != nil {
53-
if readErr == unix.EINVAL || readErr == unix.ENOSYS {
52+
err = readErr
53+
}
54+
if err != nil {
55+
if sourceReader != nil {
56+
newBuffer, newErr := sourceReader.HandleSyscallReadError(err)
57+
if newErr != nil {
58+
err = newErr
59+
} else {
60+
err = nil
61+
if len(newBuffer) > 0 {
62+
readN, readErr = unix.Write(pipeFDs[1], newBuffer)
63+
if readErr != nil {
64+
err = E.Cause(err, "write handled data")
65+
}
66+
}
67+
}
68+
} else if !notFirstTime && E.IsMulti(err, unix.EINVAL, unix.ENOSYS) {
5469
handed = false
5570
return
5671
}
57-
err = E.Cause(readErr, "splice read")
72+
err = E.Cause(err, "splice read")
5873
return
5974
}
6075
if readN == 0 {
6176
return
6277
}
6378
writeSize = readN
6479
err = destination.Write(writeFunc)
65-
if err != nil {
66-
writeErr = err
67-
}
6880
if writeErr != nil {
69-
err = E.Cause(writeErr, "splice write")
81+
err = writeErr
82+
}
83+
if err != nil {
84+
err = E.Cause(err, "splice write")
7085
return
7186
}
87+
n += int64(readN)
7288
for _, readCounter := range readCounters {
7389
readCounter(int64(readN))
7490
}
7591
for _, writeCounter := range writeCounters {
7692
writeCounter(int64(readN))
7793
}
94+
notFirstTime = true
7895
}
7996
}

common/bufio/splice_stub.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ import (
88
N "github.com/sagernet/sing/common/network"
99
)
1010

11-
func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
11+
func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
1212
return
1313
}

common/network/counter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []Co
3939
return reader, countFunc
4040
}
4141
switch u := reader.(type) {
42-
case ReadWaiter, ReadWaitCreator, syscall.Conn, SyscallReadCreator:
42+
case ReadWaiter, ReadWaitCreator, syscall.Conn, SyscallReader:
4343
// In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter
4444
return reader, countFunc
4545
case WithUpstreamReader:
@@ -60,7 +60,7 @@ func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []Co
6060
return writer, countFunc
6161
}
6262
switch u := writer.(type) {
63-
case syscall.Conn, SyscallWriteCreator:
63+
case syscall.Conn, SyscallWriter:
6464
// In our use cases, counters is always at the top, so we stop when we encounter syscall conn
6565
return writer, countFunc
6666
case WithUpstreamWriter:
@@ -81,7 +81,7 @@ func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (Packet
8181
return reader, countFunc
8282
}
8383
switch u := reader.(type) {
84-
case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn, SyscallWriteCreator:
84+
case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn:
8585
// In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter
8686
return reader, countFunc
8787
case WithUpstreamReader:
@@ -103,7 +103,7 @@ func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (Packet
103103
return writer, countFunc
104104
}
105105
switch u := writer.(type) {
106-
case syscall.Conn, SyscallWriteCreator:
106+
case syscall.Conn:
107107
// In our use cases, counters is always at the top, so we stop when we encounter syscall conn
108108
return writer, countFunc
109109
case WithUpstreamWriter:

common/network/direct.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,16 @@ type VectorisedPacketReadWaitCreator interface {
114114
CreateVectorisedPacketReadWaiter() (VectorisedPacketReadWaiter, bool)
115115
}
116116

117-
type SyscallReadCreator interface {
118-
SyscallConnForRead() syscall.Conn
117+
type SyscallReader interface {
118+
SyscallConnForRead() syscall.RawConn
119+
HandleSyscallReadError(inputErr error) ([]byte, error)
119120
}
120121

121122
func SyscallAvailableForRead(reader io.Reader) bool {
122123
if _, ok := reader.(syscall.Conn); ok {
123124
return true
124125
}
125-
if _, ok := reader.(SyscallReadCreator); ok {
126+
if _, ok := reader.(SyscallReader); ok {
126127
return true
127128
}
128129
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
@@ -137,34 +138,35 @@ func SyscallAvailableForRead(reader io.Reader) bool {
137138
return false
138139
}
139140

140-
func SyscallConnForRead(reader io.Reader) syscall.Conn {
141+
func SyscallConnForRead(reader io.Reader) (SyscallReader, syscall.RawConn) {
141142
if c, ok := reader.(syscall.Conn); ok {
142-
return c
143+
conn, _ := c.SyscallConn()
144+
return nil, conn
143145
}
144-
if c, ok := reader.(SyscallReadCreator); ok {
145-
return c.SyscallConnForRead()
146+
if c, ok := reader.(SyscallReader); ok {
147+
return c, c.SyscallConnForRead()
146148
}
147-
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
148-
return nil
149+
if u, ok := reader.(ReaderWithUpstream); !ok || u.ReaderReplaceable() {
150+
return nil, nil
149151
}
150152
if u, ok := reader.(WithUpstreamReader); ok {
151153
return SyscallConnForRead(u.UpstreamReader().(io.Reader))
152154
}
153155
if u, ok := reader.(common.WithUpstream); ok {
154156
return SyscallConnForRead(u.Upstream().(io.Reader))
155157
}
156-
return nil
158+
return nil, nil
157159
}
158160

159-
type SyscallWriteCreator interface {
160-
SyscallConnForWrite() syscall.Conn
161+
type SyscallWriter interface {
162+
SyscallConnForWrite() syscall.RawConn
161163
}
162164

163165
func SyscallAvailableForWrite(writer io.Writer) bool {
164166
if _, ok := writer.(syscall.Conn); ok {
165167
return true
166168
}
167-
if _, ok := writer.(SyscallWriteCreator); ok {
169+
if _, ok := writer.(SyscallWriter); ok {
168170
return true
169171
}
170172
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
@@ -179,21 +181,22 @@ func SyscallAvailableForWrite(writer io.Writer) bool {
179181
return false
180182
}
181183

182-
func SyscallConnForWrite(writer io.Writer) syscall.Conn {
184+
func SyscallConnForWrite(writer io.Writer) (SyscallWriter, syscall.RawConn) {
183185
if c, ok := writer.(syscall.Conn); ok {
184-
return c
186+
conn, _ := c.SyscallConn()
187+
return nil, conn
185188
}
186-
if c, ok := writer.(SyscallWriteCreator); ok {
187-
return c.SyscallConnForWrite()
189+
if c, ok := writer.(SyscallWriter); ok {
190+
return c, c.SyscallConnForWrite()
188191
}
189192
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
190-
return nil
193+
return nil, nil
191194
}
192195
if u, ok := writer.(WithUpstreamWriter); ok {
193196
return SyscallConnForWrite(u.UpstreamWriter().(io.Writer))
194197
}
195198
if u, ok := writer.(common.WithUpstream); ok {
196199
return SyscallConnForWrite(u.Upstream().(io.Writer))
197200
}
198-
return nil
201+
return nil, nil
199202
}

0 commit comments

Comments
 (0)