Skip to content

Commit d10d6b7

Browse files
committed
refactor: almost everything
fix: return traffic result when io error occurs feat: allow copy from file
1 parent e54ceb0 commit d10d6b7

File tree

12 files changed

+1479
-507
lines changed

12 files changed

+1479
-507
lines changed

Cargo.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "tokio-splice2"
3-
version = "0.3.0-alpha.10"
3+
version = "0.3.0-alpha.11"
44
edition = "2021"
55
rust-version = "1.70.0"
66

@@ -13,17 +13,23 @@ readme = "README.md"
1313
repository = "https://github.com/hanyu-dev/tokio-splice2"
1414

1515
[dependencies]
16-
pin-project-lite = "0.2.16"
1716
rustix = { version = "1.0.0", features = ["pipe"] }
18-
tokio = { version = "1.39.0", features = ["fs", "net"] }
17+
tokio = { version = "1.46.0", features = ["fs", "net"] }
1918
tracing = { version = "0.1.41", optional = true }
2019

2120
[dev-dependencies]
22-
tokio = { version = "1.39.0", features = ["net", "macros", "rt", "signal", "io-util", "time"] }
21+
human-format-next = "0.2.2"
22+
rand = "0.9.1"
23+
tokio = { version = "1.46.0", features = ["net", "macros", "rt", "signal", "io-util", "time"] }
24+
2325

2426
[features]
25-
# Enable tracing support
27+
default = ["feat-tracing"]
28+
29+
# Enable crate `tracing` support
2630
feat-tracing = ["dep:tracing"]
31+
# Enable crate `tracing` support and enable Level::TRACE log in release profile build.
32+
feat-tracing-trace = ["feat-tracing"]
2733

2834
# Enable nightly Rust feature
2935
feat-nightly = []

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ See [examples](./examples/).
1313
## Changelog
1414

1515
- 0.3.0:
16+
1617
- [BREAKING] use `rustix` instead of `libc`
17-
- [BREAKING] Do not shutdown stream after transfer finished (let user controls).
1818
- [BREAKING] MSRV is changed to 1.70.0
19-
- Add unidirectional splicing.
20-
- Accept `tokio::fs::File` as input.
19+
- Add tracing log support.
20+
- Add unidirectional copy.
21+
- Add blocking unidirectional copy.
22+
- (Experimental) Add `std::fs::File`/ `tokio::fs::File` support to splice from (like `sendfile`) / to (not fully tested).
23+
- Returns `TrafficResult` instead of `io::Result<T>` to have traffic transferred returned when error occurs (e.g. the sender force closes the stream).
2124

2225
- 0.2.1:
2326
- Fix the maximum value of the `size_t` type. Closes: [https://github.com/Hanaasagi/tokio-splice/issues/2](https://github.com/Hanaasagi/tokio-splice/issues/2).
@@ -28,7 +31,7 @@ See [BENCHMARK](./BENCHMARK.md).
2831

2932
## MSRV
3033

31-
1.70.0
34+
1.70.0 (For running the examples, we recommend using the latest stable Rust version)
3235

3336
## LICENSE
3437

examples/go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module proxy
2+
3+
go 1.21

examples/proxy.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"log"
8+
"net"
9+
"os"
10+
"os/signal"
11+
"sync"
12+
"syscall"
13+
"time"
14+
)
15+
16+
func main() {
17+
fmt.Printf("PID is %d\n", os.Getpid())
18+
19+
ctx, cancel := context.WithCancel(context.Background())
20+
defer cancel()
21+
22+
sigChan := make(chan os.Signal, 1)
23+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
24+
25+
go func() {
26+
if err := serve(ctx); err != nil {
27+
log.Printf("Serve failed: %v", err)
28+
}
29+
}()
30+
31+
<-sigChan
32+
fmt.Println("Received Ctrl + C, shutting down")
33+
cancel()
34+
35+
time.Sleep(100 * time.Millisecond)
36+
}
37+
38+
func serve(ctx context.Context) error {
39+
listenAddr := os.Getenv("EXAMPLE_LISTEN_ADDR")
40+
if listenAddr == "" {
41+
listenAddr = "0.0.0.0:5201"
42+
}
43+
44+
listener, err := net.Listen("tcp", listenAddr)
45+
if err != nil {
46+
return fmt.Errorf("failed to listen on %s: %w", listenAddr, err)
47+
}
48+
defer listener.Close()
49+
50+
fmt.Printf("Listening on %s\n", listenAddr)
51+
52+
for {
53+
select {
54+
case <-ctx.Done():
55+
return nil
56+
default:
57+
}
58+
59+
if tcpListener, ok := listener.(*net.TCPListener); ok {
60+
tcpListener.SetDeadline(time.Now().Add(100 * time.Millisecond))
61+
}
62+
63+
conn, err := listener.Accept()
64+
if err != nil {
65+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
66+
continue
67+
}
68+
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
69+
log.Printf("Temporary accept error: %v", err)
70+
continue
71+
}
72+
return fmt.Errorf("failed to accept: %w", err)
73+
}
74+
75+
remoteAddr := conn.RemoteAddr()
76+
fmt.Printf("Process incoming connection from %s\n", remoteAddr)
77+
78+
go forwarding(conn)
79+
}
80+
}
81+
82+
func forwarding(stream1 net.Conn) error {
83+
defer stream1.Close()
84+
85+
remoteAddr := os.Getenv("EXAMPLE_REMOTE_ADDR")
86+
if remoteAddr == "" {
87+
remoteAddr = "127.0.0.1:5202"
88+
}
89+
90+
stream2, err := net.Dial("tcp", remoteAddr)
91+
if err != nil {
92+
log.Printf("Failed to connect to remote server: %v", err)
93+
return err
94+
}
95+
defer stream2.Close()
96+
97+
result, err := copyBidirectional(stream1, stream2)
98+
99+
fmt.Printf("Forwarded traffic: %+v\n", result)
100+
101+
if err != nil {
102+
log.Printf("Failed to copy data: %v", err)
103+
return err
104+
}
105+
106+
return nil
107+
}
108+
109+
type TrafficStats struct {
110+
BytesForward uint64
111+
BytesReverse uint64
112+
}
113+
114+
func (t TrafficStats) String() string {
115+
return fmt.Sprintf("TrafficStats { bytes_forward: %d, bytes_reverse: %d }",
116+
t.BytesForward, t.BytesReverse)
117+
}
118+
119+
func copyBidirectional(conn1, conn2 net.Conn) (*TrafficStats, error) {
120+
var stats TrafficStats
121+
var wg sync.WaitGroup
122+
var err1, err2 error
123+
124+
wg.Add(2)
125+
126+
go func() {
127+
defer wg.Done()
128+
defer conn2.Close()
129+
n, err := io.Copy(conn2, conn1)
130+
stats.BytesForward = uint64(n)
131+
err1 = err
132+
}()
133+
134+
go func() {
135+
defer wg.Done()
136+
defer conn1.Close()
137+
n, err := io.Copy(conn1, conn2)
138+
stats.BytesReverse = uint64(n)
139+
err2 = err
140+
}()
141+
142+
wg.Wait()
143+
144+
if err1 != nil && err1 != io.EOF {
145+
return &stats, err1
146+
}
147+
if err2 != nil && err2 != io.EOF {
148+
return &stats, err2
149+
}
150+
151+
return &stats, nil
152+
}

examples/proxy.rs

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ async fn main() -> io::Result<()> {
2323
}
2424

2525
async fn serve() -> io::Result<()> {
26-
let listener = TcpListener::bind("127.0.0.1:8989").await?;
26+
let listener = TcpListener::bind(
27+
env::var("EXAMPLE_LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:5201".to_string()),
28+
)
29+
.await?;
2730

2831
loop {
2932
let (incoming, remote_addr) = match listener.accept().await {
@@ -33,7 +36,7 @@ async fn serve() -> io::Result<()> {
3336
continue;
3437
}
3538
Err(e) => {
36-
eprintln!("Failed to accept: {:#?}", e);
39+
eprintln!("Failed to accept: {e:#?}");
3740
break Err(e);
3841
}
3942
};
@@ -45,23 +48,46 @@ async fn serve() -> io::Result<()> {
4548
}
4649

4750
async fn forwarding(mut stream1: TcpStream) -> io::Result<()> {
48-
let mut stream2 = TcpStream::connect(
49-
env::var("EXAMPLE_REMOTE_ADDR").unwrap_or_else(|_| "4.ipw.cn:80".to_string()),
51+
let stream2 = TcpStream::connect(
52+
env::var("EXAMPLE_REMOTE_ADDR").unwrap_or_else(|_| "127.0.0.1:5202".to_string()),
5053
)
51-
.await
52-
.inspect_err(|e| {
53-
eprintln!("Failed to connect to remote server: {e}");
54-
})?;
54+
.await;
55+
56+
let mut stream2 = match stream2 {
57+
Ok(s) => s,
58+
Err(e) => {
59+
eprintln!("Failed to connect to remote server: {e}");
60+
return Err(e);
61+
}
62+
};
63+
64+
// let result = tokio_splice::zero_copy_bidirectional(&mut stream1, &mut
65+
// stream2).await;
66+
let instant = std::time::Instant::now();
67+
let result = tokio_splice2::copy_bidirectional(&mut stream1, &mut stream2).await;
68+
// let result = realm_io::bidi_zero_copy(&mut stream1, &mut stream2).await;
5569

56-
tokio_splice2::copy_bidirectional(&mut stream1, &mut stream2)
57-
.await
58-
.inspect(|(r, w)| {
70+
match result {
71+
Ok(traffic) => {
72+
let total = traffic.sum();
73+
let cost = instant.elapsed();
5974
println!(
60-
"Forwarded {r} bytes from stream1 to stream2, {w} bytes from stream2 to stream1"
75+
"Forwarded traffic: {traffic:?}, total: {}, time: {:.2}s, avg: {}",
76+
human_format_next::Formatter::BINARY
77+
.with_custom_unit("B")
78+
.with_decimals::<4>()
79+
.format(total as f64),
80+
cost.as_secs_f64(),
81+
human_format_next::Formatter::BINARY
82+
.with_custom_unit("B/s")
83+
.with_decimals::<4>()
84+
.format(total as f64 / cost.as_secs_f64())
6185
);
62-
})
63-
.inspect_err(|e| {
86+
Ok(())
87+
}
88+
Err(e) => {
6489
eprintln!("Failed to copy data: {e}");
65-
})?;
66-
Ok(())
90+
Err(e)
91+
}
92+
}
6793
}

0 commit comments

Comments
 (0)