Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 140 additions & 47 deletions client/cookiejar.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"errors"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -50,18 +51,18 @@ func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie {
return nil
}

return cj.getByHostAndPath(uri.Host(), uri.Path())
secure := bytes.Equal(uri.Scheme(), []byte("https"))
return cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
}

// getByHostAndPath returns cookies stored for a specific host and path.
func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthttp.Cookie {
if cj.hostCookies == nil {
return nil
}

var (
err error
cookies []*fasthttp.Cookie
hostStr = utils.UnsafeString(host)
)

Expand All @@ -70,19 +71,7 @@ func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie {
if err != nil {
hostStr = utils.UnsafeString(host)
}
// get cookies deleting expired ones
cookies = cj.getCookiesByHost(hostStr)

newCookies := make([]*fasthttp.Cookie, 0, len(cookies))
for i := 0; i < len(cookies); i++ {
cookie := cookies[i]
if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) {
continue
}
newCookies = append(newCookies, cookie)
}

return newCookies
return cj.cookiesForRequest(hostStr, path, secure)
}

// getCookiesByHost returns cookies stored for a specific host, removing any that have expired.
Expand All @@ -93,17 +82,57 @@ func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie {
now := time.Now()
cookies := cj.hostCookies[host]

for i := 0; i < len(cookies); i++ {
c := cookies[i]
kept := cookies[:0]
for _, c := range cookies {
// Remove expired cookies.
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
cookies = append(cookies[:i], cookies[i+1:]...)
fasthttp.ReleaseCookie(c)
i--
continue
}
kept = append(kept, c)
}
cj.hostCookies[host] = kept

return kept
}

// cookiesForRequest returns cookies that match the given host, path and security settings.
//
//nolint:revive // secure is required to filter Secure cookies based on scheme
func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie {
cj.mu.Lock()
defer cj.mu.Unlock()

now := time.Now()
var matched []*fasthttp.Cookie

for domain, cookies := range cj.hostCookies {
if !domainMatch(host, domain) {
continue
}

kept := cookies[:0]
for _, c := range cookies {
if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) {
fasthttp.ReleaseCookie(c)
continue
}
kept = append(kept, c)

if !pathMatch(path, c.Path()) {
continue
}
if c.Secure() && !secure {
continue
}
nc := fasthttp.AcquireCookie()
nc.CopyTo(c)
matched = append(matched, nc)
}
cj.hostCookies[domain] = kept
}

return cookies
return matched
}

// Set stores the given cookies for the specified URI host. If a cookie key already exists,
Expand All @@ -123,6 +152,11 @@ func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) {
// CookieJar stores copies of the provided cookies, so they may be safely released after use.
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
hostStr := utils.UnsafeString(host)
if h, _, err := net.SplitHostPort(hostStr); err == nil {
hostStr = h
}
hostStr = utils.ToLower(hostStr)
hostKey := utils.CopyString(hostStr)

cj.mu.Lock()
defer cj.mu.Unlock()
Expand All @@ -131,22 +165,27 @@ func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) {
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}

hostCookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map, make a copy to avoid unsafe usage.
hostStr = string(host)
}

for _, cookie := range cookies {
domain := utils.TrimLeft(cookie.Domain(), '.')
utils.ToLowerBytes(domain)
key := hostKey
if len(domain) == 0 {
cookie.SetDomain(hostStr)
} else {
key = utils.CopyString(utils.UnsafeString(domain))
cookie.SetDomainBytes(domain)
}

hostCookies := cj.hostCookies[key]

existing := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies)
if existing == nil {
// If the cookie does not exist, acquire a new one.
existing = fasthttp.AcquireCookie()
hostCookies = append(hostCookies, existing)
}
existing.CopyTo(cookie) // Override cookie properties.
existing.CopyTo(cookie)
cj.hostCookies[key] = hostCookies
}
cj.hostCookies[hostStr] = hostCookies
}

// SetKeyValue sets a cookie for the specified host with the given key and value.
Expand Down Expand Up @@ -174,15 +213,22 @@ func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) {
// dumpCookiesToReq writes the stored cookies to the given request.
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) {
uri := req.URI()
cookies := cj.getByHostAndPath(uri.Host(), uri.Path())
secure := bytes.Equal(uri.Scheme(), []byte("https"))
cookies := cj.getByHostAndPath(uri.Host(), uri.Path(), secure)
for _, cookie := range cookies {
req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value())
fasthttp.ReleaseCookie(cookie)
}
}

// parseCookiesFromResp parses the cookies from the response and stores them for the specified host and path.
func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) {
func (cj *CookieJar) parseCookiesFromResp(host, _ []byte, resp *fasthttp.Response) {
hostStr := utils.UnsafeString(host)
if h, _, err := net.SplitHostPort(hostStr); err == nil {
hostStr = h
}
hostStr = utils.ToLower(hostStr)
hostKey := utils.CopyString(hostStr)

cj.mu.Lock()
defer cj.mu.Unlock()
Expand All @@ -191,28 +237,43 @@ func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Resp
cj.hostCookies = make(map[string][]*fasthttp.Cookie)
}

cookies, ok := cj.hostCookies[hostStr]
if !ok {
// If the key does not exist in the map, make a copy to avoid unsafe usage.
hostStr = string(host)
}

now := time.Now()
for key, value := range resp.Header.Cookies() {
created := false
c := searchCookieByKeyAndPath(key, path, cookies)
for _, value := range resp.Header.Cookies() {
tmp := fasthttp.AcquireCookie()
_ = tmp.ParseBytes(value) //nolint:errcheck // ignore error

domainBytes := utils.TrimLeft(tmp.Domain(), '.')
utils.ToLowerBytes(domainBytes)
key := hostKey
if len(domainBytes) == 0 {
tmp.SetDomain(hostStr)
} else {
key = utils.CopyString(utils.UnsafeString(domainBytes))
tmp.SetDomainBytes(domainBytes)
}

cookies := cj.hostCookies[key]
c := searchCookieByKeyAndPath(tmp.Key(), tmp.Path(), cookies)
if c == nil {
c, created = fasthttp.AcquireCookie(), true
c = fasthttp.AcquireCookie()
cookies = append(cookies, c)
}

_ = c.ParseBytes(value) //nolint:errcheck // ignore error
c.CopyTo(tmp)
if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) {
cookies = append(cookies, c)
} else if created {
cj.hostCookies[key] = cookies
} else {
kept := cookies[:0]
for _, v := range cookies {
if v != c {
kept = append(kept, v)
}
}
cj.hostCookies[key] = kept
fasthttp.ReleaseCookie(c)
}
fasthttp.ReleaseCookie(tmp)
}
cj.hostCookies[hostStr] = cookies
}

// Release releases all stored cookies. After this, the CookieJar is empty.
Expand All @@ -232,10 +293,42 @@ func (cj *CookieJar) Release() {
func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie {
for _, c := range cookies {
if bytes.Equal(key, c.Key()) {
if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) {
if pathMatch(path, c.Path()) {
return c
}
}
}
return nil
}

// pathMatch determines whether the request path matches the cookie path
// according to RFC 6265 section 5.1.4.
func pathMatch(reqPath, cookiePath []byte) bool {
if len(reqPath) == 0 {
reqPath = []byte("/")
}
if len(cookiePath) == 0 {
cookiePath = []byte("/")
}
if bytes.Equal(reqPath, cookiePath) {
return true
}
if !bytes.HasPrefix(reqPath, cookiePath) {
return false
}
if cookiePath[len(cookiePath)-1] == '/' {
return true
}
return len(reqPath) > len(cookiePath) && reqPath[len(cookiePath)] == '/'
}

// domainMatch reports whether host domain-matches the given cookie domain.
func domainMatch(host, domain string) bool {
host = utils.ToLower(host)
domain = utils.UnsafeString(utils.TrimLeft(utils.UnsafeBytes(domain), '.'))
domain = utils.ToLower(domain)
if host == domain {
return true
}
return strings.HasSuffix(host, "."+domain)
}
Loading
Loading