Skip to content

Commit dd69839

Browse files
authored
feat(driver): add CallActionWithContext (#107)
Signed-off-by: Me0wo <[email protected]>
1 parent ccc03e3 commit dd69839

File tree

6 files changed

+37
-23
lines changed

6 files changed

+37
-23
lines changed

api.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package zero
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/md5"
67
"encoding/base64"
78
"encoding/hex"
@@ -10,6 +11,7 @@ import (
1011
"fmt"
1112
"regexp"
1213
"strconv"
14+
"time"
1315

1416
log "github.com/sirupsen/logrus"
1517
"github.com/tidwall/gjson"
@@ -51,11 +53,17 @@ func formatMessage(msg interface{}) string {
5153

5254
// CallAction 调用 cqhttp API
5355
func (ctx *Ctx) CallAction(action string, params Params) APIResponse {
54-
req := APIRequest{
56+
c, cancel := context.WithTimeout(context.Background(), time.Minute)
57+
defer cancel()
58+
return ctx.CallActionWithContext(c, action, params)
59+
}
60+
61+
// CallActionWithContext 使用 context 调用 cqhttp API
62+
func (ctx *Ctx) CallActionWithContext(c context.Context, action string, params Params) APIResponse {
63+
rsp, err := ctx.caller.CallAPI(c, APIRequest{
5564
Action: action,
5665
Params: params,
57-
}
58-
rsp, err := ctx.caller.CallAPI(req)
66+
})
5967
if err != nil {
6068
log.Errorln("[api] 调用", action, "时出现错误: ", err)
6169
}

bot.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package zero
22

33
import (
4+
"context"
45
"encoding/json"
56
"hash/crc64"
67
"runtime/debug"
@@ -37,7 +38,7 @@ var APICallers callerMap
3738

3839
// APICaller is the interface of CallAPI
3940
type APICaller interface {
40-
CallAPI(request APIRequest) (APIResponse, error)
41+
CallAPI(c context.Context, request APIRequest) (APIResponse, error)
4142
}
4243

4344
// Driver 与OneBot通信的驱动,使用driver.DefaultWebSocketDriver
@@ -137,14 +138,14 @@ type messageLogger struct {
137138
}
138139

139140
// CallAPI 记录被触发的回复消息
140-
func (m *messageLogger) CallAPI(request APIRequest) (rsp APIResponse, err error) {
141+
func (m *messageLogger) CallAPI(ctx context.Context, request APIRequest) (rsp APIResponse, err error) {
141142
noLog := false
142143
b, ok := request.Params["__zerobot_no_log_mseeage_id__"].(bool)
143144
if ok {
144145
noLog = b
145146
delete(request.Params, "__zerobot_no_log_mseeage_id__")
146147
}
147-
rsp, err = m.caller.CallAPI(request)
148+
rsp, err = m.caller.CallAPI(ctx, request)
148149
if err != nil {
149150
return
150151
}

driver/http.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package driver
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/hmac"
67
"crypto/sha1"
78
"encoding/hex"
@@ -30,7 +31,9 @@ type HTTP struct {
3031

3132
func (h *HTTP) Connect() {
3233
log.Infof("[httpcaller] 正在尝试与服务器握手: %s", h.caller.URL)
33-
rsp, err := h.caller.CallAPI(zero.APIRequest{Action: "get_login_info", Params: nil})
34+
c, cancel := context.WithTimeout(context.Background(), time.Minute)
35+
defer cancel()
36+
rsp, err := h.caller.CallAPI(c, zero.APIRequest{Action: "get_login_info", Params: nil})
3437
if err != nil {
3538
log.Warningf("[httpcaller] 与服务器握手失败: %s\n%v", h.caller.URL, err)
3639
return
@@ -148,8 +151,8 @@ func (h *HTTP) Listen(handler func([]byte, zero.APICaller)) {
148151

149152
// httpCaller 对 api 进行调用
150153
// 不关闭body会导致资源泄漏!
151-
func (c *HTTPCaller) httpCaller(action string, payload []byte) (*http.Response, error) {
152-
req, err := http.NewRequest(http.MethodPost, c.URL+"/"+action, bytes.NewBuffer(payload))
154+
func (c *HTTPCaller) httpCaller(ctx context.Context, action string, payload []byte) (*http.Response, error) {
155+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.URL+"/"+action, bytes.NewBuffer(payload))
153156
if err != nil {
154157
return nil, err
155158
}
@@ -169,13 +172,13 @@ func (c *HTTPCaller) httpCaller(action string, payload []byte) (*http.Response,
169172
return resp, nil
170173
}
171174

172-
func (c *HTTPCaller) CallAPI(request zero.APIRequest) (zero.APIResponse, error) {
173-
p, err := json.Marshal(request.Params)
175+
func (c *HTTPCaller) CallAPI(ctx context.Context, req zero.APIRequest) (zero.APIResponse, error) {
176+
p, err := json.Marshal(req.Params)
174177
if err != nil {
175178
return nullResponse, err
176179
}
177180

178-
resp, err := c.httpCaller(request.Action, p)
181+
resp, err := c.httpCaller(ctx, req.Action, p)
179182
if err != nil {
180183
return nullResponse, err
181184
}

driver/wsclient.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package driver
22

33
import (
4+
"context"
45
"encoding/base64"
56
"io"
67
"net"
78
"net/http"
8-
"os"
99
"strings"
1010
"sync"
1111
"sync/atomic"
@@ -142,7 +142,7 @@ func (ws *WSClient) nextSeq() uint64 {
142142
}
143143

144144
// CallAPI 发送ws请求
145-
func (ws *WSClient) CallAPI(req zero.APIRequest) (zero.APIResponse, error) {
145+
func (ws *WSClient) CallAPI(c context.Context, req zero.APIRequest) (zero.APIResponse, error) {
146146
ch := make(chan zero.APIResponse, 1)
147147
req.Echo = ws.nextSeq()
148148
ws.seqMap.Store(req.Echo, ch)
@@ -163,7 +163,7 @@ func (ws *WSClient) CallAPI(req zero.APIRequest) (zero.APIResponse, error) {
163163
return nullResponse, io.ErrClosedPipe
164164
}
165165
return rsp, nil
166-
case <-time.After(time.Minute):
167-
return nullResponse, os.ErrDeadlineExceeded
166+
case <-c.Done():
167+
return nullResponse, c.Err()
168168
}
169169
}

driver/wsserver.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package driver
22

33
import (
4+
"context"
45
"encoding/json"
56
"io"
67
"net"
78
"net/http"
89
"net/url"
9-
"os"
1010
"strings"
1111
"sync"
1212
"sync/atomic"
@@ -213,7 +213,7 @@ func (wssc *WSSCaller) nextSeq() uint64 {
213213
}
214214

215215
// CallAPI 发送ws请求
216-
func (wssc *WSSCaller) CallAPI(req zero.APIRequest) (zero.APIResponse, error) {
216+
func (wssc *WSSCaller) CallAPI(c context.Context, req zero.APIRequest) (zero.APIResponse, error) {
217217
ch := make(chan zero.APIResponse, 1)
218218
req.Echo = wssc.nextSeq()
219219
wssc.seqMap.Store(req.Echo, ch)
@@ -234,7 +234,7 @@ func (wssc *WSSCaller) CallAPI(req zero.APIRequest) (zero.APIResponse, error) {
234234
return nullResponse, io.ErrClosedPipe
235235
}
236236
return rsp, nil
237-
case <-time.After(time.Minute):
238-
return nullResponse, os.ErrDeadlineExceeded
237+
case <-c.Done():
238+
return nullResponse, c.Err()
239239
}
240240
}

pattern_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
package zero
22

33
import (
4+
"context"
45
"fmt"
6+
"strconv"
7+
"testing"
8+
59
"github.com/stretchr/testify/assert"
610
"github.com/tidwall/gjson"
711
"github.com/wdvxdr1123/ZeroBot/message"
8-
"strconv"
9-
"testing"
1012
)
1113

1214
type mockAPICaller struct{}
1315

14-
func (m mockAPICaller) CallAPI(_ APIRequest) (APIResponse, error) {
16+
func (m mockAPICaller) CallAPI(_ context.Context, _ APIRequest) (APIResponse, error) {
1517
return APIResponse{
1618
Status: "",
1719
Data: gjson.Parse(`{"message_id":"12345","sender":{"user_id":12345}}`), // just for reply cleaner

0 commit comments

Comments
 (0)