Skip to content

Commit 3360780

Browse files
springjdnineinchnick
authored andcommitted
Add retries for 502, 504 HTTP statuses
Add 502 Bad Gateway and 504 Gateway Timeout HTTP status to client's retry logic Add tests for retries on 502,504 and test that no retry is done on another status (404)
1 parent 70bd4d7 commit 3360780

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

trino/trino.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
617617
}
618618
}
619619
return resp, nil
620-
case http.StatusServiceUnavailable:
620+
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
621621
resp.Body.Close()
622622
timer.Reset(delay)
623623
delay = time.Duration(math.Min(

trino/trino_test.go

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) {
266266
}
267267

268268
func TestRoundTripRetryQueryError(t *testing.T) {
269-
count := 0
270-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
271-
if count == 0 {
272-
count++
273-
w.WriteHeader(http.StatusServiceUnavailable)
274-
return
275-
}
276-
w.WriteHeader(http.StatusOK)
277-
json.NewEncoder(w).Encode(&stmtResponse{
278-
Error: ErrTrino{
279-
ErrorName: "TEST",
280-
},
281-
})
282-
}))
269+
testcases := []struct {
270+
Name string
271+
HttpStatus int
272+
ExpectedErrorStatus string
273+
}{
274+
{
275+
Name: "Test retry 502 Bad Gateway",
276+
HttpStatus: http.StatusBadGateway,
277+
ExpectedErrorStatus: "200 OK",
278+
},
279+
{
280+
Name: "Test retry 503 Service Unavailable",
281+
HttpStatus: http.StatusServiceUnavailable,
282+
ExpectedErrorStatus: "200 OK",
283+
},
284+
{
285+
Name: "Test retry 504 Gateway Timeout",
286+
HttpStatus: http.StatusGatewayTimeout,
287+
ExpectedErrorStatus: "200 OK",
288+
},
289+
{
290+
Name: "Test no retry 404 Not Found",
291+
HttpStatus: http.StatusNotFound,
292+
ExpectedErrorStatus: "404 Not Found",
293+
},
294+
}
295+
for _, tc := range testcases {
296+
t.Run(tc.Name, func(t *testing.T) {
297+
count := 0
298+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
299+
if count == 0 {
300+
count++
301+
w.WriteHeader(tc.HttpStatus)
302+
return
303+
}
304+
w.WriteHeader(http.StatusOK)
305+
json.NewEncoder(w).Encode(&stmtResponse{
306+
Error: ErrTrino{
307+
ErrorName: "TEST",
308+
},
309+
})
310+
}))
283311

284-
t.Cleanup(ts.Close)
312+
t.Cleanup(ts.Close)
285313

286-
db, err := sql.Open("trino", ts.URL)
287-
require.NoError(t, err)
314+
db, err := sql.Open("trino", ts.URL)
315+
require.NoError(t, err)
288316

289-
t.Cleanup(func() {
290-
assert.NoError(t, db.Close())
291-
})
317+
t.Cleanup(func() {
318+
assert.NoError(t, db.Close())
319+
})
292320

293-
_, err = db.Query("SELECT 1")
294-
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
321+
_, err = db.Query("SELECT 1")
322+
assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err)
323+
})
324+
}
295325
}
296326

297327
func TestRoundTripBogusData(t *testing.T) {

0 commit comments

Comments
 (0)