Skip to content

Commit 4b5ebbd

Browse files
mathetakemissBerg
andauthored
extproc: clear content-encoding only when body is mutated (#685)
**Commit Message** Clearing content-type header is only necessary when the response mutation is actually happening. Otherwise, say when the original response is in gzip and not mutation is happening, the client gets gzip and no content-encoding header. **Related Issues/PRs (if applicable)** Follow up on #683 --------- Signed-off-by: Takeshi Yoneda <[email protected]> Co-authored-by: Erica Hughberg <[email protected]>
1 parent fdeb5aa commit 4b5ebbd

File tree

5 files changed

+65
-28
lines changed

5 files changed

+65
-28
lines changed

internal/extproc/chatcompletion_processor.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,16 +271,14 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context.
271271
c.metrics.RecordRequestCompletion(ctx, err == nil)
272272
}()
273273
var br io.Reader
274-
var removeHeaders []string
274+
var isGzip bool
275275
switch c.responseEncoding {
276276
case "gzip":
277277
br, err = gzip.NewReader(bytes.NewReader(body.Body))
278278
if err != nil {
279279
return nil, fmt.Errorf("failed to decode gzip: %w", err)
280280
}
281-
// TODO: this is a hotfix, we should update this to recompress since its in the header
282-
// If the response was gzipped, ensure we remove the content-encoding header
283-
removeHeaders = append(removeHeaders, "content-encoding")
281+
isGzip = true
284282
default:
285283
br = bytes.NewReader(body.Body)
286284
}
@@ -289,10 +287,18 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context.
289287
if err != nil {
290288
return nil, fmt.Errorf("failed to transform response: %w", err)
291289
}
292-
if headerMutation == nil {
293-
headerMutation = &extprocv3.HeaderMutation{}
290+
if bodyMutation != nil && isGzip {
291+
if headerMutation == nil {
292+
headerMutation = &extprocv3.HeaderMutation{}
293+
}
294+
// TODO: this is a hotfix, we should update this to recompress since its in the header
295+
// If the response was gzipped, ensure we remove the content-encoding header.
296+
//
297+
// This is only needed when the transformation is actually modifying the body. When the backend
298+
// is in OpenAI format (and it's the first try before any retry), the response body is not modified,
299+
// so we don't need to remove the header in that case.
300+
headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, "content-encoding")
294301
}
295-
headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, removeHeaders...)
296302

297303
resp := &extprocv3.ProcessingResponse{
298304
Response: &extprocv3.ProcessingResponse_ResponseBody{

tests/extproc/testupstream_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ func TestWithTestUpstream(t *testing.T) {
164164
expStatus: http.StatusOK,
165165
expResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
166166
},
167+
{
168+
name: "openai - /v1/chat/completions - gzip",
169+
backend: "openai",
170+
responseType: "gzip",
171+
path: "/v1/chat/completions",
172+
method: http.MethodPost,
173+
requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`,
174+
expPath: "/v1/chat/completions",
175+
responseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
176+
expStatus: http.StatusOK,
177+
expResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`,
178+
},
167179
{
168180
name: "azure-openai - /v1/chat/completions",
169181
backend: "azure-openai",

tests/internal/testupstreamlib/testupstream.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const (
1212
// Each line in x-response-body is treated as a separate [data] payload.
1313
// * If this is "aws-event-stream", the response body is expected to be an AWS Event Stream.
1414
// Each line in x-response-body is treated as a separate event payload.
15+
// * If this is gzip, the response body is expected to be a regular JSON response compressed with gzip.
1516
// * If this is empty, the response body is expected to be a regular JSON response.
1617
ResponseTypeKey = "x-response-type"
1718
// ExpectedHeadersKey is the key for the expected headers in the request.

tests/internal/testupstreamlib/testupstream/main.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package main
77

88
import (
99
"bytes"
10+
"compress/gzip"
1011
"encoding/base64"
1112
"fmt"
1213
"io"
@@ -287,8 +288,11 @@ func handler(w http.ResponseWriter, r *http.Request) {
287288
logger.Println("response sent")
288289
r.Context().Done()
289290
default:
290-
w.Header().Set("Content-Type", "application/json")
291-
291+
isGzip := r.Header.Get(testupstreamlib.ResponseTypeKey) == "gzip"
292+
if isGzip {
293+
w.Header().Set("content-encoding", "gzip")
294+
}
295+
w.Header().Set("content-type", "application/json")
292296
var responseBody []byte
293297
if expResponseBody := r.Header.Get(testupstreamlib.ResponseBodyHeaderKey); expResponseBody == "" {
294298
// If the expected response body is not set, get the fake response if the path is known.
@@ -308,6 +312,13 @@ func handler(w http.ResponseWriter, r *http.Request) {
308312
}
309313

310314
w.WriteHeader(status)
315+
if isGzip {
316+
var buf bytes.Buffer
317+
gz := gzip.NewWriter(&buf)
318+
_, _ = gz.Write(responseBody)
319+
_ = gz.Close()
320+
responseBody = buf.Bytes()
321+
}
311322
_, _ = w.Write(responseBody)
312323
logger.Println("response sent:", string(responseBody))
313324
}

tests/internal/testupstreamlib/testupstream/main_test.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -218,30 +218,37 @@ func Test_main(t *testing.T) {
218218

219219
t.Run("fake response", func(t *testing.T) {
220220
t.Parallel()
221-
request, err := http.NewRequest("GET",
222-
"http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body")))
223-
require.NoError(t, err)
221+
for _, isGzip := range []bool{false, true} {
222+
t.Run(fmt.Sprintf("gzip=%t", isGzip), func(t *testing.T) {
223+
request, err := http.NewRequest("GET",
224+
"http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body")))
225+
require.NoError(t, err)
224226

225-
request.Header.Set(testupstreamlib.ExpectedPathHeaderKey,
226-
base64.StdEncoding.EncodeToString([]byte("/v1/chat/completions")))
227-
request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey,
228-
base64.StdEncoding.EncodeToString([]byte("expected request body")))
227+
request.Header.Set(testupstreamlib.ExpectedPathHeaderKey,
228+
base64.StdEncoding.EncodeToString([]byte("/v1/chat/completions")))
229+
request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey,
230+
base64.StdEncoding.EncodeToString([]byte("expected request body")))
231+
if isGzip {
232+
request.Header.Set(testupstreamlib.ResponseTypeKey, "gzip")
233+
}
229234

230-
response, err := http.DefaultClient.Do(request)
231-
require.NoError(t, err)
232-
defer func() {
233-
_ = response.Body.Close()
234-
}()
235+
response, err := http.DefaultClient.Do(request)
236+
require.NoError(t, err)
237+
defer func() {
238+
_ = response.Body.Close()
239+
}()
235240

236-
require.Equal(t, http.StatusOK, response.StatusCode)
241+
require.Equal(t, http.StatusOK, response.StatusCode)
237242

238-
responseBody, err := io.ReadAll(response.Body)
239-
require.NoError(t, err)
243+
responseBody, err := io.ReadAll(response.Body)
244+
require.NoError(t, err)
240245

241-
var chat openai.ChatCompletion
242-
require.NoError(t, chat.UnmarshalJSON(responseBody))
243-
// Ensure that the response is one of the fake responses.
244-
require.Contains(t, chatCompletionFakeResponses, chat.Choices[0].Message.Content)
246+
var chat openai.ChatCompletion
247+
require.NoError(t, chat.UnmarshalJSON(responseBody))
248+
// Ensure that the response is one of the fake responses.
249+
require.Contains(t, chatCompletionFakeResponses, chat.Choices[0].Message.Content)
250+
})
251+
}
245252
})
246253

247254
t.Run("fake response for unknown path", func(t *testing.T) {

0 commit comments

Comments
 (0)