@@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) {
266
266
}
267
267
268
268
func TestRoundTripRetryQueryError (t * testing.T ) {
269
- count := 0
270
- ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
271
- if count == 0 {
272
- count ++
273
- w .WriteHeader (http .StatusServiceUnavailable )
274
- return
275
- }
276
- w .WriteHeader (http .StatusOK )
277
- json .NewEncoder (w ).Encode (& stmtResponse {
278
- Error : ErrTrino {
279
- ErrorName : "TEST" ,
280
- },
281
- })
282
- }))
269
+ testcases := []struct {
270
+ Name string
271
+ HttpStatus int
272
+ ExpectedErrorStatus string
273
+ }{
274
+ {
275
+ Name : "Test retry 502 Bad Gateway" ,
276
+ HttpStatus : http .StatusBadGateway ,
277
+ ExpectedErrorStatus : "200 OK" ,
278
+ },
279
+ {
280
+ Name : "Test retry 503 Service Unavailable" ,
281
+ HttpStatus : http .StatusServiceUnavailable ,
282
+ ExpectedErrorStatus : "200 OK" ,
283
+ },
284
+ {
285
+ Name : "Test retry 504 Gateway Timeout" ,
286
+ HttpStatus : http .StatusGatewayTimeout ,
287
+ ExpectedErrorStatus : "200 OK" ,
288
+ },
289
+ {
290
+ Name : "Test no retry 404 Not Found" ,
291
+ HttpStatus : http .StatusNotFound ,
292
+ ExpectedErrorStatus : "404 Not Found" ,
293
+ },
294
+ }
295
+ for _ , tc := range testcases {
296
+ t .Run (tc .Name , func (t * testing.T ) {
297
+ count := 0
298
+ ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
299
+ if count == 0 {
300
+ count ++
301
+ w .WriteHeader (tc .HttpStatus )
302
+ return
303
+ }
304
+ w .WriteHeader (http .StatusOK )
305
+ json .NewEncoder (w ).Encode (& stmtResponse {
306
+ Error : ErrTrino {
307
+ ErrorName : "TEST" ,
308
+ },
309
+ })
310
+ }))
283
311
284
- t .Cleanup (ts .Close )
312
+ t .Cleanup (ts .Close )
285
313
286
- db , err := sql .Open ("trino" , ts .URL )
287
- require .NoError (t , err )
314
+ db , err := sql .Open ("trino" , ts .URL )
315
+ require .NoError (t , err )
288
316
289
- t .Cleanup (func () {
290
- assert .NoError (t , db .Close ())
291
- })
317
+ t .Cleanup (func () {
318
+ assert .NoError (t , db .Close ())
319
+ })
292
320
293
- _ , err = db .Query ("SELECT 1" )
294
- assert .IsTypef (t , new (ErrQueryFailed ), err , "unexpected error: %w" , err )
321
+ _ , err = db .Query ("SELECT 1" )
322
+ assert .ErrorContains (t , err , tc .ExpectedErrorStatus , "unexpected error: %w" , err )
323
+ })
324
+ }
295
325
}
296
326
297
327
func TestRoundTripBogusData (t * testing.T ) {
0 commit comments