Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 252 additions & 5 deletions middleware/idempotency/idempotency_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package idempotency_test
package idempotency

import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
Expand All @@ -11,13 +13,14 @@ import (
"time"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/idempotency"
"github.com/valyala/fasthttp"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const validKey = "00000000-0000-0000-0000-000000000000"

// go test -run Test_Idempotency
func Test_Idempotency(t *testing.T) {
t.Parallel()
Expand All @@ -29,7 +32,7 @@ func Test_Idempotency(t *testing.T) {
}

isMethodSafe := fiber.IsMethodSafe(c.Method())
isIdempotent := idempotency.IsFromCache(c) || idempotency.WasPutToCache(c)
isIdempotent := IsFromCache(c) || WasPutToCache(c)
hasReqHeader := c.Get("X-Idempotency-Key") != ""

if isMethodSafe {
Expand All @@ -53,7 +56,7 @@ func Test_Idempotency(t *testing.T) {
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 2 * time.Second

app.Use(idempotency.New(idempotency.Config{
app.Use(New(Config{
Lifetime: lifetime,
}))

Expand Down Expand Up @@ -136,7 +139,7 @@ func Benchmark_Idempotency(b *testing.B) {
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 1 * time.Second

app.Use(idempotency.New(idempotency.Config{
app.Use(New(Config{
Lifetime: lifetime,
}))

Expand Down Expand Up @@ -169,3 +172,247 @@ func Benchmark_Idempotency(b *testing.B) {
}
})
}

func Test_configDefault_defaults(t *testing.T) {
t.Parallel()

cfg := configDefault()
require.NotNil(t, cfg.Lock)
require.NotNil(t, cfg.Storage)
require.Equal(t, ConfigDefault.Lifetime, cfg.Lifetime)
require.Equal(t, ConfigDefault.KeyHeader, cfg.KeyHeader)
require.Nil(t, cfg.KeepResponseHeaders)

app := fiber.New()

fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
ctx := app.AcquireCtx(fctx)
require.True(t, cfg.Next(ctx))
app.ReleaseCtx(ctx)

fctx = &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodPost)
ctx = app.AcquireCtx(fctx)
require.False(t, cfg.Next(ctx))
app.ReleaseCtx(ctx)

require.NoError(t, cfg.KeyHeaderValidate(validKey))
require.Error(t, cfg.KeyHeaderValidate("short"))
}

func Test_configDefault_override(t *testing.T) {
t.Parallel()

l := &stubLock{}
s := &stubStorage{}

cfg := configDefault(Config{
Lifetime: 42 * time.Second,
KeyHeader: "Foo",
KeepResponseHeaders: []string{},
Lock: l,
Storage: s,
})

require.Equal(t, 42*time.Second, cfg.Lifetime)
require.Equal(t, "Foo", cfg.KeyHeader)
require.Nil(t, cfg.KeepResponseHeaders)
require.Equal(t, l, cfg.Lock)
require.Equal(t, s, cfg.Storage)
require.NotNil(t, cfg.Next)
require.NotNil(t, cfg.KeyHeaderValidate)
}

// helper to perform request
func do(app *fiber.App, req *http.Request) (*http.Response, string) {
resp, err := app.Test(req, fiber.TestConfig{Timeout: 5 * time.Second})
if err != nil {
panic(err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return resp, string(body)
}

func Test_New_NextSkip(t *testing.T) {
t.Parallel()
app := fiber.New()
var count int

app.Use(New(Config{Next: func(_ fiber.Ctx) bool { return true }}))

app.Post("/", func(c fiber.Ctx) error {
count++
return c.SendString(strconv.Itoa(count))
})

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
_, body1 := do(app, req)

req2 := httptest.NewRequest(http.MethodPost, "/", nil)
req2.Header.Set(ConfigDefault.KeyHeader, validKey)
_, body2 := do(app, req2)

require.Equal(t, "1", body1)
require.Equal(t, "2", body2)
}

func Test_New_InvalidKey(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Post("/", func(_ fiber.Ctx) error { return nil })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, "bad")
resp, body := do(app, req)

require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "invalid length")
}

func Test_New_StorageGetError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{getErr: errors.New("boom")}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)

require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response at fastpath")
}

func Test_New_UnmarshalError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{data: map[string][]byte{validKey: []byte("bad")}}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)

require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response at fastpath")
}

func Test_New_StoreRetrieve_FilterHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
app.Use(New(Config{
Storage: s,
Lock: &stubLock{},
KeepResponseHeaders: []string{"Foo"},
}))

var count int
app.Post("/", func(c fiber.Ctx) error {
count++
c.Set("Foo", "foo")
c.Set("Bar", "bar")
return c.SendString(fmt.Sprintf("resp%d", count))
})

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, "resp1", body)
require.Equal(t, "foo", resp.Header.Get("Foo"))
require.Equal(t, "bar", resp.Header.Get("Bar"))

req2 := httptest.NewRequest(http.MethodPost, "/", nil)
req2.Header.Set(ConfigDefault.KeyHeader, validKey)
resp2, body2 := do(app, req2)
require.Equal(t, "resp1", body2)
require.Equal(t, "foo", resp2.Header.Get("Foo"))
require.Empty(t, resp2.Header.Get("Bar"))
require.Equal(t, 1, count)
require.Equal(t, 1, s.setCount)
}

func Test_New_HandlerError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(_ fiber.Ctx) error { return errors.New("boom") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Equal(t, "boom", body)
require.Equal(t, 0, s.setCount)

resp2, body2 := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp2.StatusCode)
require.Equal(t, "boom", body2)
require.Equal(t, 0, s.setCount)
}

func Test_New_LockError(t *testing.T) {
t.Parallel()
app := fiber.New()
l := &stubLock{lockErr: errors.New("fail")}
app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to lock")
}

func Test_New_StorageSetError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{setErr: errors.New("nope")}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to save response")
}

func Test_New_UnlockError(t *testing.T) {
t.Parallel()
app := fiber.New()
l := &stubLock{unlockErr: errors.New("u")}
app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "ok", body)
}

func Test_New_SecondPassReadError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
l := &stubLock{afterLock: func() { s.getErr = errors.New("g") }}
app.Use(New(Config{Lock: l, Storage: s}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response while locked")
}
64 changes: 64 additions & 0 deletions middleware/idempotency/stub_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package idempotency

import (
"time"
)

// stubLock implements Locker for testing purposes.
type stubLock struct {
lockErr error
unlockErr error
afterLock func()
}

func (s *stubLock) Lock(string) error {
if s.afterLock != nil {
s.afterLock()
}
return s.lockErr
}
func (s *stubLock) Unlock(string) error { return s.unlockErr }

// stubStorage implements fiber.Storage for testing.
type stubStorage struct {
data map[string][]byte
getErr error
setErr error
setCount int
}

func (s *stubStorage) Get(key string) ([]byte, error) {
if s.getErr != nil {
return nil, s.getErr
}
if s.data == nil {
return nil, nil
}
return s.data[key], nil
}

func (s *stubStorage) Set(key string, val []byte, _ time.Duration) error {
if s.setErr != nil {
return s.setErr
}
if s.data == nil {
s.data = make(map[string][]byte)
}
s.data[key] = val
s.setCount++
return nil
}

func (s *stubStorage) Delete(key string) error {
if s.data != nil {
delete(s.data, key)
}
return nil
}

func (s *stubStorage) Reset() error {
s.data = make(map[string][]byte)
return nil
}

func (_ *stubStorage) Close() error { return nil }

Check failure on line 64 in middleware/idempotency/stub_test.go

View workflow job for this annotation

GitHub Actions / lint

receiver-naming: receiver name should not be an underscore, omit the name if it is unused (revive)
Loading