Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 119 additions & 38 deletions client/cookiejar.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"bytes"
"errors"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -50,18 +51,18 @@
return nil
}

return cj.getByHostAndPath(uri.Host(), uri.Path())
secure := bytes.Equal(uri.Scheme(), []byte("https"))
return cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
}

// getByHostAndPath returns cookies stored for a specific host and path.
func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthttp.Cookie {
if cj.hostCookies == nil {
return nil
}

var (
err error
cookies []*fasthttp.Cookie
hostStr = utils.UnsafeString(host)
)

Expand All @@ -70,19 +71,7 @@
if err != nil {
hostStr = utils.UnsafeString(host)
}
// get cookies deleting expired ones
cookies = cj.getCookiesByHost(hostStr)

newCookies := make([]*fasthttp.Cookie, 0, len(cookies))
for i := 0; i < len(cookies); i++ {
cookie := cookies[i]
if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) {
continue
}
newCookies = append(newCookies, cookie)
}

return newCookies
return cj.cookiesForRequest(hostStr, path, secure)
}

// getCookiesByHost returns cookies stored for a specific host, removing any that have expired.
Expand All @@ -106,6 +95,42 @@
return cookies
}

// cookiesForRequest returns cookies that match the given host, path and security settings.
func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie {

Check failure on line 99 in client/cookiejar.go

View workflow job for this annotation

GitHub Actions / lint

flag-parameter: parameter 'secure' seems to be a control flag, avoid control coupling (revive)
cj.mu.Lock()
defer cj.mu.Unlock()

now := time.Now()
var matched []*fasthttp.Cookie

for domain, cookies := range cj.hostCookies {
if !domainMatch(host, domain) {
continue
}

kept := cookies[:0]
for _, c := range cookies {
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
fasthttp.ReleaseCookie(c)
continue
}
kept = append(kept, c)
if len(path) > 1 && len(c.Path()) > 1 && !bytes.HasPrefix(path, c.Path()) {
continue
}
if c.Secure() && !secure {
continue
}
nc := fasthttp.AcquireCookie()
nc.CopyTo(c)
matched = append(matched, nc)
}
cj.hostCookies[domain] = kept
}

return matched
}

// Set stores the given cookies for the specified URI host. If a cookie key already exists,
// it will be replaced by the new cookie value.
//
Expand All @@ -123,6 +148,10 @@
// CookieJar stores copies of the provided cookies, so they may be safely released after use.
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
hostStr := utils.UnsafeString(host)
if h, _, err := net.SplitHostPort(hostStr); err == nil {
hostStr = h
}
hostStr = utils.ToLower(hostStr)

cj.mu.Lock()
defer cj.mu.Unlock()
Expand All @@ -131,22 +160,30 @@
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}

hostCookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map, make a copy to avoid unsafe usage.
hostStr = string(host)
}

for _, cookie := range cookies {
domain := utils.TrimLeft(cookie.Domain(), '.')
utils.ToLowerBytes(domain)
key := hostStr
if len(domain) == 0 {
cookie.SetDomain(hostStr)
} else {
key = utils.UnsafeString(domain)
cookie.SetDomainBytes(domain)
}

hostCookies, ok := cj.hostCookies[key]
if !ok {
key = string([]byte(key))
}

existing := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies)
if existing == nil {
// If the cookie does not exist, acquire a new one.
existing = fasthttp.AcquireCookie()
hostCookies = append(hostCookies, existing)
}
existing.CopyTo(cookie) // Override cookie properties.
existing.CopyTo(cookie)
cj.hostCookies[key] = hostCookies
}
cj.hostCookies[hostStr] = hostCookies
}

// SetKeyValue sets a cookie for the specified host with the given key and value.
Expand Down Expand Up @@ -174,15 +211,21 @@
// dumpCookiesToReq writes the stored cookies to the given request.
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
uri := req.URI()
cookies := cj.getByHostAndPath(uri.Host(), uri.Path())
secure := bytes.Equal(uri.Scheme(), []byte("https"))
cookies := cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
for _, cookie := range cookies {
req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value())
fasthttp.ReleaseCookie(cookie)
}
}

// parseCookiesFromResp parses the cookies from the response and stores them for the specified host and path.
func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) {

Check failure on line 223 in client/cookiejar.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'path' seems to be unused, consider removing or renaming it as _ (revive)
hostStr := utils.UnsafeString(host)
if h, _, err := net.SplitHostPort(hostStr); err == nil {
hostStr = h
}

Check warning on line 227 in client/cookiejar.go

View check run for this annotation

Codecov / codecov/patch

client/cookiejar.go#L226-L227

Added lines #L226 - L227 were not covered by tests
hostStr = utils.ToLower(hostStr)

cj.mu.Lock()
defer cj.mu.Unlock()
Expand All @@ -191,28 +234,55 @@
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}

cookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map, make a copy to avoid unsafe usage.
hostStr = string(host)
if _, ok := cj.hostCookies[hostStr]; !ok {
hostStr = string([]byte(hostStr))
}

now := time.Now()
for key, value := range resp.Header.Cookies() {
for _, value := range resp.Header.Cookies() {
tmp := fasthttp.AcquireCookie()
_ = tmp.ParseBytes(value) //nolint:errcheck // ignore error

domainBytes := utils.TrimLeft(tmp.Domain(), '.')
utils.ToLowerBytes(domainBytes)
key := hostStr
if len(domainBytes) == 0 {
tmp.SetDomain(hostStr)
} else {
key = utils.UnsafeString(domainBytes)
tmp.SetDomainBytes(domainBytes)
}

Check warning on line 254 in client/cookiejar.go

View check run for this annotation

Codecov / codecov/patch

client/cookiejar.go#L252-L254

Added lines #L252 - L254 were not covered by tests

if _, ok := cj.hostCookies[key]; !ok {
key = string([]byte(key))
}

cookies := cj.hostCookies[key]
c := searchCookieByKeyAndPath(tmp.Key(), tmp.Path(), cookies)
created := false
c := searchCookieByKeyAndPath(key, path, cookies)
if c == nil {
c, created = fasthttp.AcquireCookie(), true
cookies = append(cookies, c)
}

_ = c.ParseBytes(value) //nolint:errcheck // ignore error
c.CopyTo(tmp)
if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
cookies = append(cookies, c)
} else if created {
fasthttp.ReleaseCookie(c)
cj.hostCookies[key] = cookies
} else {
// remove expired cookie from slice
for i := 0; i < len(cookies); i++ {
if cookies[i] == c {
cookies = append(cookies[:i], cookies[i+1:]...)
break
}
}
cj.hostCookies[key] = cookies
if created {
fasthttp.ReleaseCookie(c)
}
}
fasthttp.ReleaseCookie(tmp)
}
cj.hostCookies[hostStr] = cookies
}

// Release releases all stored cookies. After this, the CookieJar is empty.
Expand All @@ -232,10 +302,21 @@
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
for _, c := range cookies {
if bytes.Equal(key, c.Key()) {
if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) {
if len(path) <= 1 || bytes.Equal(c.Path(), path) {
return c
}
}
}
return nil
}

// domainMatch reports whether host domain-matches the given cookie domain.
func domainMatch(host, domain string) bool {
host = utils.ToLower(host)
domain = utils.UnsafeString(utils.TrimLeft(utils.UnsafeBytes(domain), '.'))
domain = utils.ToLower(domain)
if host == domain {
return true
}
return strings.HasSuffix(host, "."+domain)
}
99 changes: 89 additions & 10 deletions client/cookiejar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@ func Test_CookieJarGet(t *testing.T) {
t.Parallel()

url := []byte("http://fasthttp.com/")
url1 := []byte("http://fasthttp.com/make")
url1 := []byte("http://fasthttp.com/make/")
url11 := []byte("http://fasthttp.com/hola")
url2 := []byte("http://fasthttp.com/make/fasthttp")
url3 := []byte("http://fasthttp.com/make/fasthttp/great")
prefix := []byte("/")
prefix1 := []byte("/make")
prefix2 := []byte("/make/fasthttp")
prefix3 := []byte("/make/fasthttp/great")
cj := &CookieJar{}

c1 := &fasthttp.Cookie{}
Expand Down Expand Up @@ -69,9 +66,9 @@ func Test_CookieJarGet(t *testing.T) {
cj.Set(uri1, c1, c2, c3)

cookies := cj.Get(uri1)
require.Len(t, cookies, 3)
require.Len(t, cookies, 1)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix1))
require.True(t, bytes.HasPrefix(uri1.Path(), cookie.Path()))
}

cookies = cj.Get(uri11)
Expand All @@ -80,14 +77,13 @@ func Test_CookieJarGet(t *testing.T) {
cookies = cj.Get(uri2)
require.Len(t, cookies, 2)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix2))
require.True(t, bytes.HasPrefix(uri2.Path(), cookie.Path()))
}

cookies = cj.Get(uri3)
require.Len(t, cookies, 1)

require.Len(t, cookies, 3)
for _, cookie := range cookies {
require.True(t, bytes.HasPrefix(cookie.Path(), prefix3))
require.True(t, bytes.HasPrefix(uri3.Path(), cookie.Path()))
}

cookies = cj.Get(uri)
Expand Down Expand Up @@ -209,4 +205,87 @@ func Test_CookieJarGetFromResponse(t *testing.T) {

cookies := cj.Get(uri)
require.Len(t, cookies, 3)
values := map[string]string{"key": "val", "k": "v", "kk": "vv"}
for _, c := range cookies {
k := string(c.Key())
v, ok := values[k]
require.True(t, ok)
require.Equal(t, v, string(c.Value()))
delete(values, k)
}
require.Empty(t, values)
}

func Test_CookieJar_HostPort(t *testing.T) {
t.Parallel()

jar := &CookieJar{}
uriSet := fasthttp.AcquireURI()
require.NoError(t, uriSet.Parse(nil, []byte("http://fasthttp.com:80/path")))

c := &fasthttp.Cookie{}
c.SetKey("k")
c.SetValue("v")
jar.Set(uriSet, c)

// retrieve using a different port to ensure port is ignored
uriGet := fasthttp.AcquireURI()
require.NoError(t, uriGet.Parse(nil, []byte("http://fasthttp.com:8080/path")))

cookies := jar.Get(uriGet)
require.Len(t, cookies, 1)
require.Equal(t, "k", string(cookies[0].Key()))
require.Equal(t, "v", string(cookies[0].Value()))
require.Equal(t, "fasthttp.com", string(cookies[0].Domain()))
}

func Test_CookieJar_Domain(t *testing.T) {
t.Parallel()

jar := &CookieJar{}

uri := fasthttp.AcquireURI()
require.NoError(t, uri.Parse(nil, []byte("http://sub.example.com/")))

c := &fasthttp.Cookie{}
c.SetKey("k")
c.SetValue("v")
c.SetDomain("example.com")

jar.Set(uri, c)

uri2 := fasthttp.AcquireURI()
require.NoError(t, uri2.Parse(nil, []byte("http://other.example.com/")))

cookies := jar.Get(uri2)
require.Len(t, cookies, 1)
require.Equal(t, "k", string(cookies[0].Key()))
require.Equal(t, "v", string(cookies[0].Value()))
}

func Test_CookieJar_Secure(t *testing.T) {
t.Parallel()

jar := &CookieJar{}

uriHTTP := fasthttp.AcquireURI()
require.NoError(t, uriHTTP.Parse(nil, []byte("http://example.com/")))

c := &fasthttp.Cookie{}
c.SetKey("k")
c.SetValue("v")
c.SetSecure(true)

jar.Set(uriHTTP, c)

cookies := jar.Get(uriHTTP)
require.Empty(t, cookies)

uriHTTPS := fasthttp.AcquireURI()
require.NoError(t, uriHTTPS.Parse(nil, []byte("https://example.com/")))

cookies = jar.Get(uriHTTPS)
require.Len(t, cookies, 1)
require.Equal(t, "k", string(cookies[0].Key()))
require.Equal(t, "v", string(cookies[0].Value()))
}
Loading