Skip to content

Commit 426ffaf

Browse files
authored
feat:Progress Tracking for Uploads/Downloads
1 parent 2d2e034 commit 426ffaf

File tree

3 files changed

+177
-5
lines changed

3 files changed

+177
-5
lines changed

axios4go_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package axios4go
22

33
import (
4+
"bytes"
45
"encoding/json"
6+
"io"
57
"net/http"
68
"net/http/httptest"
9+
"strings"
710
"testing"
811
)
912

@@ -863,3 +866,55 @@ func TestGetByProxy(t *testing.T) {
863866
}
864867
})
865868
}
869+
870+
func TestProgressCallbacks(t *testing.T) {
871+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
872+
// Read the request body to trigger upload progress
873+
_, err := io.Copy(io.Discard, r.Body)
874+
if err != nil {
875+
t.Fatalf("Failed to read request body: %v", err)
876+
}
877+
878+
// Simulate a large file for download
879+
w.Header().Set("Content-Length", "1000000")
880+
for i := 0; i < 1000000; i++ {
881+
_, err := w.Write([]byte("a"))
882+
if err != nil {
883+
t.Fatalf("Failed to write response: %v", err)
884+
}
885+
}
886+
}))
887+
defer server.Close()
888+
889+
uploadCalled := false
890+
downloadCalled := false
891+
892+
body := bytes.NewReader([]byte(strings.Repeat("b", 500000))) // 500KB upload
893+
894+
_, err := Post(server.URL, body, &RequestOptions{
895+
OnUploadProgress: func(bytesRead, totalBytes int64) {
896+
uploadCalled = true
897+
if bytesRead > totalBytes {
898+
t.Errorf("Upload progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
899+
}
900+
},
901+
OnDownloadProgress: func(bytesRead, totalBytes int64) {
902+
downloadCalled = true
903+
if bytesRead > totalBytes {
904+
t.Errorf("Download progress: bytesRead (%d) > totalBytes (%d)", bytesRead, totalBytes)
905+
}
906+
},
907+
MaxContentLength: 2000000, // Set this to allow our 1MB response
908+
})
909+
910+
if err != nil {
911+
t.Fatalf("Expected no error, got %v", err)
912+
}
913+
914+
if !uploadCalled {
915+
t.Error("Upload progress callback was not called")
916+
}
917+
if !downloadCalled {
918+
t.Error("Download progress callback was not called")
919+
}
920+
}

client.go

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ type RequestOptions struct {
5454
ResponseType string
5555
ResponseEncoding string
5656
MaxRedirects int
57-
MaxContentLength int
58-
MaxBodyLength int
57+
MaxContentLength int64
58+
MaxBodyLength int64
5959
Decompress bool
6060
ValidateStatus func(int) bool
6161
InterceptorOptions InterceptorOptions
6262
Proxy *Proxy
63+
OnUploadProgress func(bytesRead, totalBytes int64)
64+
OnDownloadProgress func(bytesRead, totalBytes int64)
6365
}
6466

6567
type Proxy struct {
@@ -74,6 +76,38 @@ type Auth struct {
7476
Password string
7577
}
7678

79+
type ProgressReader struct {
80+
reader io.Reader
81+
total int64
82+
read int64
83+
onProgress func(bytesRead, totalBytes int64)
84+
}
85+
86+
type ProgressWriter struct {
87+
writer io.Writer
88+
total int64
89+
written int64
90+
onProgress func(bytesWritten, totalBytes int64)
91+
}
92+
93+
func (pr *ProgressReader) Read(p []byte) (int, error) {
94+
n, err := pr.reader.Read(p)
95+
pr.read += int64(n)
96+
if pr.onProgress != nil {
97+
pr.onProgress(pr.read, pr.total)
98+
}
99+
return n, err
100+
}
101+
102+
func (pw *ProgressWriter) Write(p []byte) (int, error) {
103+
n, err := pw.writer.Write(p)
104+
pw.written += int64(n)
105+
if pw.onProgress != nil {
106+
pw.onProgress(pw.written, pw.total)
107+
}
108+
return n, err
109+
}
110+
77111
var defaultClient = &Client{HTTPClient: &http.Client{}}
78112

79113
func (r *Response) JSON(v interface{}) error {
@@ -345,6 +379,14 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
345379
if options.MaxBodyLength > 0 && bodyLength > int64(options.MaxBodyLength) {
346380
return nil, errors.New("request body length exceeded maxBodyLength")
347381
}
382+
383+
if options.Body != nil && options.OnUploadProgress != nil {
384+
bodyReader = &ProgressReader{
385+
reader: bodyReader,
386+
total: bodyLength,
387+
onProgress: options.OnUploadProgress,
388+
}
389+
}
348390
}
349391

350392
req, err := http.NewRequest(options.Method, fullURL, bodyReader)
@@ -428,9 +470,24 @@ func (c *Client) Request(options *RequestOptions) (*Response, error) {
428470
}
429471
}()
430472

431-
responseBody, err := io.ReadAll(resp.Body)
432-
if err != nil {
433-
return nil, err
473+
var responseBody []byte
474+
if options.OnDownloadProgress != nil {
475+
buf := &bytes.Buffer{}
476+
progressWriter := &ProgressWriter{
477+
writer: buf,
478+
total: resp.ContentLength,
479+
onProgress: options.OnDownloadProgress,
480+
}
481+
_, err = io.Copy(progressWriter, resp.Body)
482+
if err != nil {
483+
return nil, err
484+
}
485+
responseBody = buf.Bytes()
486+
} else {
487+
responseBody, err = io.ReadAll(resp.Body)
488+
if err != nil {
489+
return nil, err
490+
}
434491
}
435492

436493
if int64(len(responseBody)) > int64(options.MaxContentLength) {
@@ -504,6 +561,12 @@ func mergeOptions(dst, src *RequestOptions) {
504561
if src.InterceptorOptions.ResponseInterceptors != nil {
505562
dst.InterceptorOptions.ResponseInterceptors = src.InterceptorOptions.ResponseInterceptors
506563
}
564+
if src.OnUploadProgress != nil {
565+
dst.OnUploadProgress = src.OnUploadProgress
566+
}
567+
if src.OnDownloadProgress != nil {
568+
dst.OnDownloadProgress = src.OnDownloadProgress
569+
}
507570
if src.Proxy != nil {
508571
dst.Proxy = src.Proxy
509572
}

examples/download.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"time"
7+
8+
"github.com/rezmoss/axios4go"
9+
)
10+
11+
func main() {
12+
url := "https://ash-speed.hetzner.com/1GB.bin"
13+
outputPath := "1GB.bin"
14+
15+
startTime := time.Now()
16+
lastPrintTime := startTime
17+
18+
resp, err := axios4go.Get(url, &axios4go.RequestOptions{
19+
MaxContentLength: 5 * 1024 * 1024 * 1024, // 5GB
20+
Timeout: 60000 * 5,
21+
OnDownloadProgress: func(bytesRead, totalBytes int64) {
22+
currentTime := time.Now()
23+
if currentTime.Sub(lastPrintTime) >= time.Second || bytesRead == totalBytes {
24+
percentage := float64(bytesRead) / float64(totalBytes) * 100
25+
downloadedMB := float64(bytesRead) / 1024 / 1024
26+
totalMB := float64(totalBytes) / 1024 / 1024
27+
elapsedTime := currentTime.Sub(startTime)
28+
speed := float64(bytesRead) / elapsedTime.Seconds() / 1024 / 1024 // MB/s
29+
30+
fmt.Printf("\rDownloaded %.2f%% (%.2f MB / %.2f MB) - Speed: %.2f MB/s",
31+
percentage, downloadedMB, totalMB, speed)
32+
33+
lastPrintTime = currentTime
34+
}
35+
},
36+
})
37+
38+
if err != nil {
39+
fmt.Printf("\nError downloading file: %v\n", err)
40+
return
41+
}
42+
43+
err = writeResponseToFile(resp, outputPath)
44+
if err != nil {
45+
fmt.Printf("\nError writing file: %v\n", err)
46+
return
47+
}
48+
49+
fmt.Println("\nDownload completed successfully!")
50+
}
51+
52+
func writeResponseToFile(resp *axios4go.Response, outputPath string) error {
53+
return os.WriteFile(outputPath, resp.Body, 0644)
54+
}

0 commit comments

Comments
 (0)