Skip to content

Commit bdabe1c

Browse files
nineinchnicklosipiuk
authored andcommitted
Pass parent context to requests
1 parent 38fd110 commit bdabe1c

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

trino/trino.go

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,8 @@ func (c *Conn) Close() error {
448448
return nil
449449
}
450450

451-
func (c *Conn) newRequest(method, url string, body io.Reader, hs http.Header) (*http.Request, error) {
452-
req, err := http.NewRequest(method, url, body)
451+
func (c *Conn) newRequest(ctx context.Context, method, url string, body io.Reader, hs http.Header) (*http.Request, error) {
452+
req, err := http.NewRequestWithContext(ctx, method, url, body)
453453
if err != nil {
454454
return nil, fmt.Errorf("trino: %w", err)
455455
}
@@ -485,14 +485,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
485485
case <-ctx.Done():
486486
return nil, ctx.Err()
487487
case <-timer.C:
488-
timeout := DefaultQueryTimeout
489-
if deadline, ok := ctx.Deadline(); ok {
490-
timeout = time.Until(deadline)
491-
}
492-
client := c.httpClient
493-
client.Timeout = timeout
494-
req.Cancel = ctx.Done()
495-
resp, err := client.Do(req)
488+
resp, err := c.httpClient.Do(req)
496489
if err != nil {
497490
return nil, &ErrQueryFailed{Reason: err}
498491
}
@@ -845,13 +838,19 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
845838
}
846839
}
847840

848-
req, err := st.conn.newRequest("POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs)
841+
var cancel context.CancelFunc = func() {}
842+
if _, ok := ctx.Deadline(); !ok {
843+
ctx, cancel = context.WithTimeout(ctx, DefaultQueryTimeout)
844+
}
845+
req, err := st.conn.newRequest(ctx, "POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs)
849846
if err != nil {
847+
cancel()
850848
return nil, err
851849
}
852850

853851
resp, err := st.conn.roundTrip(ctx, req)
854852
if err != nil {
853+
cancel()
855854
return nil, err
856855
}
857856

@@ -861,6 +860,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
861860
d.UseNumber()
862861
err = d.Decode(&sr)
863862
if err != nil {
863+
cancel()
864864
return nil, fmt.Errorf("trino: %w", err)
865865
}
866866

@@ -879,8 +879,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
879879
}
880880
hs := make(http.Header)
881881
hs.Add(trinoUserHeader, st.user)
882-
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
882+
req, err := st.conn.newRequest(ctx, "GET", nextURI, nil, hs)
883883
if err != nil {
884+
if ctx.Err() == context.Canceled {
885+
st.errors <- context.Canceled
886+
return
887+
}
884888
st.errors <- err
885889
return
886890
}
@@ -905,6 +909,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
905909
}()
906910
go func() {
907911
defer close(st.queryResponses)
912+
defer cancel()
908913
for {
909914
select {
910915
case resp := <-st.httpResponses:
@@ -1011,12 +1016,12 @@ func (qr *driverRows) Close() error {
10111016
if qr.stmt.user != "" {
10121017
hs.Add(trinoUserHeader, qr.stmt.user)
10131018
}
1014-
req, err := qr.stmt.conn.newRequest("DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
1019+
ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout)
1020+
defer cancel()
1021+
req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
10151022
if err != nil {
10161023
return err
10171024
}
1018-
ctx, cancel := context.WithTimeout(context.Background(), DefaultCancelQueryTimeout)
1019-
defer cancel()
10201025
resp, err := qr.stmt.conn.roundTrip(ctx, req)
10211026
if err != nil {
10221027
qferr, ok := err.(*ErrQueryFailed)

0 commit comments

Comments
 (0)