Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2297,3 +2297,49 @@ func Benchmark_Client_Request_Parallel(b *testing.B) {
require.NoError(b, err)
})
}

func Benchmark_Client_Request_Send_ContextCancel(b *testing.B) {
app, ln, start := createHelperServer(b)

startedCh := make(chan struct{})
errCh := make(chan error)
respCh := make(chan *Response)

app.Post("/", func(c fiber.Ctx) error {
startedCh <- struct{}{}
time.Sleep(time.Millisecond) // let cancel be called
return c.Status(fiber.StatusOK).SendString("post")
})

go start()

client := New().SetDial(ln)

b.ReportAllocs()
b.ResetTimer()

for b.Loop() {
ctx, cancel := context.WithCancel(context.Background())

req := AcquireRequest().
SetClient(client).
SetURL("http://example.com").
SetMethod(fiber.MethodPost).
SetContext(ctx)

go func(r *Request) {
defer ReleaseRequest(r)

resp, err := r.Send()

respCh <- resp
errCh <- err
}(req)

<-startedCh // request is made, we can cancel the context now
cancel()

require.Nil(b, <-respCh)
require.ErrorIs(b, <-errCh, ErrTimeoutOrCancel)
}
}
102 changes: 65 additions & 37 deletions client/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/addon/retry"
Expand Down Expand Up @@ -73,25 +72,25 @@ func (c *core) getRetryConfig() *RetryConfig {
// execFunc is the core logic to send the request and receive the response.
// It leverages the fasthttp client, optionally with retries or redirects.
func (c *core) execFunc() (*Response, error) {
resp := AcquireResponse()
resp.setClient(c.client)
resp.setRequest(c.req)
// do not close, these will be returned to the pool
errChan := acquireErrChan()
respChan := acquireResponseChan()

done := int32(0)
errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest()
defer releaseErrChan(errCh)

c.req.RawRequest.CopyTo(reqv)
cfg := c.getRetryConfig()

var err error
go func() {
// retain both channels until they are drained
defer releaseErrChan(errChan)
defer releaseResponseChan(respChan)

reqv := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(reqv)

respv := fasthttp.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(reqv)
fasthttp.ReleaseResponse(respv)
}()
defer fasthttp.ReleaseResponse(respv)

c.req.RawRequest.CopyTo(reqv)

var err error
if cfg != nil {
// Use an exponential backoff retry strategy.
err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
Expand All @@ -108,27 +107,31 @@ func (c *core) execFunc() (*Response, error) {
}
}

if atomic.CompareAndSwapInt32(&done, 0, 1) {
if err != nil {
errCh <- err
return
}
respv.CopyTo(resp.RawResponse)
errCh <- nil
if err != nil {
errChan <- err
return
}

resp := AcquireResponse()
resp.setClient(c.client)
resp.setRequest(c.req)
respv.CopyTo(resp.RawResponse)
respChan <- resp
}()

select {
case err := <-errCh:
if err != nil {
// Release the response if an error occurs.
ReleaseResponse(resp)
return nil, err
}
case err := <-errChan:
return nil, err
case resp := <-respChan:
return resp, nil
case <-c.ctx.Done():
atomic.SwapInt32(&done, 1)
ReleaseResponse(resp)
go func() { // drain the channels and release the response
select {
case resp := <-respChan:
ReleaseResponse(resp)
case <-errChan:
}
}()
return nil, ErrTimeoutOrCancel
}
}
Expand Down Expand Up @@ -219,16 +222,39 @@ func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Resp
return resp, nil
}

var responseChanPool = &sync.Pool{
New: func() any {
return make(chan *Response)
},
}

// acquireResponseChan returns an empty, non-closed *Response channel from the pool.
// The returned channel may be returned to the pool with releaseResponseChan
func acquireResponseChan() chan *Response {
ch, ok := responseChanPool.Get().(chan *Response)
if !ok {
panic(errors.New("failed to type-assert to *Response"))
}
return ch
}

// releaseResponseChan returns the *Response channel to the pool.
// It's the caller's responsibility to ensure that:
// - the channel is not closed
// - the channel is drained before returning it
// - the channel is not reused after returning it
func releaseResponseChan(ch chan *Response) {
responseChanPool.Put(ch)
}

var errChanPool = &sync.Pool{
New: func() any {
return make(chan error, 1)
return make(chan error)
},
}

// acquireErrChan returns an empty error channel from the pool.
//
// The returned channel may be returned to the pool with releaseErrChan when no longer needed,
// reducing GC load.
// acquireErrChan returns an empty, non-closed error channel from the pool.
// The returned channel may be returned to the pool with releaseErrChan
func acquireErrChan() chan error {
ch, ok := errChanPool.Get().(chan error)
if !ok {
Expand All @@ -238,8 +264,10 @@ func acquireErrChan() chan error {
}

// releaseErrChan returns the error channel to the pool.
//
// Do not use the released channel afterward to avoid data races.
// It's caller's responsibility to ensure that:
// - the channel is not closed
// - the channel is drained before returning it
// - the channel is not reused after returning it
func releaseErrChan(ch chan error) {
errChanPool.Put(ch)
}
Expand Down
140 changes: 140 additions & 0 deletions client/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package client

import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"

"github.com/gofiber/fiber/v3"
Expand Down Expand Up @@ -70,6 +73,10 @@ func Test_Exec_Func(t *testing.T) {
return errors.New("the request is error")
})

app.Get("/redirect", func(c fiber.Ctx) error {
return c.Redirect().Status(fiber.StatusFound).To("/normal")
})

app.Get("/hang-up", func(c fiber.Ctx) error {
time.Sleep(time.Second)
return c.SendString(c.Hostname() + " hang up")
Expand Down Expand Up @@ -97,6 +104,25 @@ func Test_Exec_Func(t *testing.T) {
require.Equal(t, "example.com", string(resp.RawResponse.Body()))
})

t.Run("follow redirect with retry config", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), New(), AcquireRequest()
core.ctx = context.Background()
core.client = client
core.req = req

client.SetRetryConfig(&RetryConfig{MaxRetryCount: 1})
client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() })
req.SetMaxRedirects(1)
req.RawRequest.Header.SetMethod(fiber.MethodGet)
req.RawRequest.SetRequestURI("http://example.com/redirect")

resp, err := core.execFunc()
require.NoError(t, err)
require.Equal(t, 200, resp.RawResponse.StatusCode())
require.Equal(t, "example.com", string(resp.RawResponse.Body()))
})

t.Run("the request return an error", func(t *testing.T) {
t.Parallel()
core, client, req := newCore(), New(), AcquireRequest()
Expand Down Expand Up @@ -131,6 +157,59 @@ func Test_Exec_Func(t *testing.T) {

require.Equal(t, ErrTimeoutOrCancel, err)
})

t.Run("cancel drains errChan", func(t *testing.T) {
core, client, req := newCore(), New(), AcquireRequest()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

core.ctx = ctx
core.client = client
core.req = req

req.RawRequest.SetRequestURI("http://example.com/drain-err")

blockingTransport := newBlockingErrTransport(errors.New("upstream failure"))
client.transport = blockingTransport
defer blockingTransport.release()

type execResult struct {
resp *Response
err error
}

resultCh := make(chan execResult, 1)
go func() {
resp, err := core.execFunc()
resultCh <- execResult{resp: resp, err: err}
}()

select {
case <-blockingTransport.called:
case <-time.After(time.Second):
t.Fatal("transport Do was not invoked")
}

cancel()

var result execResult
select {
case result = <-resultCh:
case <-time.After(time.Second):
t.Fatal("execFunc did not return")
}

require.Nil(t, result.resp)
require.ErrorIs(t, result.err, ErrTimeoutOrCancel)

blockingTransport.release()

select {
case <-blockingTransport.finished:
case <-time.After(time.Second):
t.Fatal("transport Do did not finish")
}
})
}

func Test_Execute(t *testing.T) {
Expand Down Expand Up @@ -246,3 +325,64 @@ func Test_Execute(t *testing.T) {
require.Equal(t, "example.com hang up", string(resp.RawResponse.Body()))
})
}

type blockingErrTransport struct {
err error

called chan struct{}
unblock chan struct{}
finished chan struct{}

calledOnce sync.Once
releaseOnce sync.Once
finishedOnce sync.Once
}

func newBlockingErrTransport(err error) *blockingErrTransport {
return &blockingErrTransport{
err: err,
called: make(chan struct{}),
unblock: make(chan struct{}),
finished: make(chan struct{}),
}
}

func (b *blockingErrTransport) Do(_ *fasthttp.Request, _ *fasthttp.Response) error {
b.calledOnce.Do(func() { close(b.called) })
<-b.unblock
b.finishedOnce.Do(func() { close(b.finished) })
return b.err
}

func (b *blockingErrTransport) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, _ time.Duration) error {
return b.Do(req, resp)
}

func (b *blockingErrTransport) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, _ time.Time) error {
return b.Do(req, resp)
}

func (b *blockingErrTransport) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, _ int) error {
return b.Do(req, resp)
}

func (*blockingErrTransport) CloseIdleConnections() {
}

func (*blockingErrTransport) TLSConfig() *tls.Config {
return nil
}

func (*blockingErrTransport) SetTLSConfig(_ *tls.Config) {
}

func (*blockingErrTransport) SetDial(_ fasthttp.DialFunc) {
}

func (*blockingErrTransport) Client() any {
return nil
}

func (b *blockingErrTransport) release() {
b.releaseOnce.Do(func() { close(b.unblock) })
}
Loading