Skip to content

Commit 51aa0ec

Browse files
authored
[gzhttp] Add supported decompress request body (#1002)
* [gzhttp] Add supported decompress request body
1 parent 13c1244 commit 51aa0ec

File tree

2 files changed

+67
-11
lines changed

2 files changed

+67
-11
lines changed

gzhttp/compress.go

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,11 @@ func NewWrapper(opts ...option) (func(http.Handler) http.HandlerFunc, error) {
464464
return func(h http.Handler) http.HandlerFunc {
465465
return func(w http.ResponseWriter, r *http.Request) {
466466
w.Header().Add(vary, acceptEncoding)
467+
if c.allowCompressedRequests && contentGzip(r) {
468+
r.Header.Del(contentEncoding)
469+
r.Body = &gzipReader{body: r.Body}
470+
}
471+
467472
if acceptsGzip(r) {
468473
gw := grwPool.Get().(*GzipResponseWriter)
469474
*gw = GzipResponseWriter{
@@ -536,17 +541,18 @@ func (pct parsedContentType) equals(mediaType string, params map[string]string)
536541

537542
// Used for functional configuration.
538543
type config struct {
539-
minSize int
540-
level int
541-
writer writer.GzipWriterFactory
542-
contentTypes func(ct string) bool
543-
keepAcceptRanges bool
544-
setContentType bool
545-
suffixETag string
546-
dropETag bool
547-
jitterBuffer int
548-
randomJitter string
549-
sha256Jitter bool
544+
minSize int
545+
level int
546+
writer writer.GzipWriterFactory
547+
contentTypes func(ct string) bool
548+
keepAcceptRanges bool
549+
setContentType bool
550+
suffixETag string
551+
dropETag bool
552+
jitterBuffer int
553+
randomJitter string
554+
sha256Jitter bool
555+
allowCompressedRequests bool
550556
}
551557

552558
func (c *config) validate() error {
@@ -579,6 +585,15 @@ func MinSize(size int) option {
579585
}
580586
}
581587

588+
// AllowCompressedRequests will enable or disable RFC 7694 compressed requests.
589+
// By default this is Disabled.
590+
// See https://datatracker.ietf.org/doc/html/rfc7694
591+
func AllowCompressedRequests(b bool) option {
592+
return func(c *config) {
593+
c.allowCompressedRequests = b
594+
}
595+
}
596+
582597
// CompressionLevel sets the compression level
583598
func CompressionLevel(level int) option {
584599
return func(c *config) {
@@ -752,6 +767,12 @@ func RandomJitter(n, buffer int, paranoid bool) option {
752767
}
753768
}
754769

770+
// contentGzip returns true if the given HTTP request indicates that it gzipped.
771+
func contentGzip(r *http.Request) bool {
772+
// See more detail in `acceptsGzip`
773+
return r.Method != http.MethodHead && r.Body != nil && parseEncodingGzip(r.Header.Get(contentEncoding)) > 0
774+
}
775+
755776
// acceptsGzip returns true if the given HTTP request indicates that it will
756777
// accept a gzipped response.
757778
func acceptsGzip(r *http.Request) bool {

gzhttp/compress_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,41 @@ func TestMustNewGzipHandler(t *testing.T) {
8989
handler.ServeHTTP(res3, req3)
9090

9191
assertEqual(t, http.DetectContentType([]byte(testBody)), res3.Header().Get("Content-Type"))
92+
93+
// send compress request body with `AllowCompressedRequests`
94+
handler = newTestHandlerLevel(testBody, AllowCompressedRequests(true))
95+
96+
var b bytes.Buffer
97+
writerGzip := gzip.NewWriter(&b)
98+
writerGzip.Write(testBody)
99+
writerGzip.Close()
100+
101+
req5, _ := http.NewRequest("POST", "/whatever", &b)
102+
req5.Header.Set("Content-Encoding", "gzip")
103+
resp5 := httptest.NewRecorder()
104+
handler.ServeHTTP(resp5, req5)
105+
res5 := resp5.Result()
106+
107+
assertEqual(t, 200, res5.StatusCode)
108+
109+
body, _ := io.ReadAll(res5.Body)
110+
assertEqual(t, len(testBody), len(body))
111+
112+
// send compress request body without `AllowCompressedRequests`
113+
writerGzip = gzip.NewWriter(&b)
114+
writerGzip.Write(testBody)
115+
writerGzip.Close()
116+
117+
handler = newTestHandlerLevel(b.Bytes())
118+
119+
req6, _ := http.NewRequest("POST", "/whatever", &b)
120+
resp6 := httptest.NewRecorder()
121+
handler.ServeHTTP(resp6, req6)
122+
res6 := resp6.Result()
123+
124+
assertEqual(t, 200, res6.StatusCode)
125+
body, _ = io.ReadAll(res6.Body)
126+
assertEqual(t, b.Len(), len(body))
92127
}
93128

94129
func TestGzipHandlerSmallBodyNoCompression(t *testing.T) {

0 commit comments

Comments
 (0)