@@ -448,8 +448,8 @@ func (c *Conn) Close() error {
448
448
return nil
449
449
}
450
450
451
- func (c * Conn ) newRequest (method , url string , body io.Reader , hs http.Header ) (* http.Request , error ) {
452
- req , err := http .NewRequest ( method , url , body )
451
+ func (c * Conn ) newRequest (ctx context. Context , method , url string , body io.Reader , hs http.Header ) (* http.Request , error ) {
452
+ req , err := http .NewRequestWithContext ( ctx , method , url , body )
453
453
if err != nil {
454
454
return nil , fmt .Errorf ("trino: %w" , err )
455
455
}
@@ -485,14 +485,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
485
485
case <- ctx .Done ():
486
486
return nil , ctx .Err ()
487
487
case <- timer .C :
488
- timeout := DefaultQueryTimeout
489
- if deadline , ok := ctx .Deadline (); ok {
490
- timeout = time .Until (deadline )
491
- }
492
- client := c .httpClient
493
- client .Timeout = timeout
494
- req .Cancel = ctx .Done ()
495
- resp , err := client .Do (req )
488
+ resp , err := c .httpClient .Do (req )
496
489
if err != nil {
497
490
return nil , & ErrQueryFailed {Reason : err }
498
491
}
@@ -845,13 +838,19 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
845
838
}
846
839
}
847
840
848
- req , err := st .conn .newRequest ("POST" , st .conn .baseURL + "/v1/statement" , strings .NewReader (query ), hs )
841
+ var cancel context.CancelFunc = func () {}
842
+ if _ , ok := ctx .Deadline (); ! ok {
843
+ ctx , cancel = context .WithTimeout (ctx , DefaultQueryTimeout )
844
+ }
845
+ req , err := st .conn .newRequest (ctx , "POST" , st .conn .baseURL + "/v1/statement" , strings .NewReader (query ), hs )
849
846
if err != nil {
847
+ cancel ()
850
848
return nil , err
851
849
}
852
850
853
851
resp , err := st .conn .roundTrip (ctx , req )
854
852
if err != nil {
853
+ cancel ()
855
854
return nil , err
856
855
}
857
856
@@ -861,6 +860,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
861
860
d .UseNumber ()
862
861
err = d .Decode (& sr )
863
862
if err != nil {
863
+ cancel ()
864
864
return nil , fmt .Errorf ("trino: %w" , err )
865
865
}
866
866
@@ -879,8 +879,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
879
879
}
880
880
hs := make (http.Header )
881
881
hs .Add (trinoUserHeader , st .user )
882
- req , err := st .conn .newRequest ("GET" , nextURI , nil , hs )
882
+ req , err := st .conn .newRequest (ctx , "GET" , nextURI , nil , hs )
883
883
if err != nil {
884
+ if ctx .Err () == context .Canceled {
885
+ st .errors <- context .Canceled
886
+ return
887
+ }
884
888
st .errors <- err
885
889
return
886
890
}
@@ -905,6 +909,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
905
909
}()
906
910
go func () {
907
911
defer close (st .queryResponses )
912
+ defer cancel ()
908
913
for {
909
914
select {
910
915
case resp := <- st .httpResponses :
@@ -1011,12 +1016,12 @@ func (qr *driverRows) Close() error {
1011
1016
if qr .stmt .user != "" {
1012
1017
hs .Add (trinoUserHeader , qr .stmt .user )
1013
1018
}
1014
- req , err := qr .stmt .conn .newRequest ("DELETE" , qr .stmt .conn .baseURL + "/v1/query/" + url .PathEscape (qr .queryID ), nil , hs )
1019
+ ctx , cancel := context .WithTimeout (context .WithoutCancel (qr .ctx ), DefaultCancelQueryTimeout )
1020
+ defer cancel ()
1021
+ req , err := qr .stmt .conn .newRequest (ctx , "DELETE" , qr .stmt .conn .baseURL + "/v1/query/" + url .PathEscape (qr .queryID ), nil , hs )
1015
1022
if err != nil {
1016
1023
return err
1017
1024
}
1018
- ctx , cancel := context .WithTimeout (context .Background (), DefaultCancelQueryTimeout )
1019
- defer cancel ()
1020
1025
resp , err := qr .stmt .conn .roundTrip (ctx , req )
1021
1026
if err != nil {
1022
1027
qferr , ok := err .(* ErrQueryFailed )
0 commit comments