Skip to content

Commit 1f20d98

Browse files
fix(dot/subscription): check websocket message from untrusted data (#2527)
* fix: websocket message checks from untrusted data
1 parent e29f90c commit 1f20d98

File tree

3 files changed

+95
-44
lines changed

3 files changed

+95
-44
lines changed

dot/rpc/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
233233
wsc := NewWSConn(ws, h.serverConfig)
234234
h.wsConns = append(h.wsConns, wsc)
235235

236-
go wsc.HandleComm()
236+
go wsc.HandleConn()
237237
}
238238

239239
// NewWSConn to create new WebSocket Connection struct

dot/rpc/subscription/websocket.go

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,21 @@ import (
2222
"github.com/gorilla/websocket"
2323
)
2424

25+
type websocketMessage struct {
26+
ID float64 `json:"id"`
27+
Method string `json:"method"`
28+
Params any `json:"params"`
29+
}
30+
2531
type httpclient interface {
2632
Do(*http.Request) (*http.Response, error)
2733
}
2834

29-
var errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
30-
var errCannotUnmarshalMessage = errors.New("cannot unmarshal webasocket message data")
35+
var (
36+
errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
37+
errEmptyMethod = errors.New("empty method")
38+
)
39+
3140
var logger = log.NewFromGlobal(log.AddContext("pkg", "rpc/subscription"))
3241

3342
// WSConn struct to hold WebSocket Connection references
@@ -46,87 +55,82 @@ type WSConn struct {
4655
}
4756

4857
// readWebsocketMessage will read and parse the message data to a string->interface{} data
49-
func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) {
50-
_, mbytes, err := c.Wsconn.ReadMessage()
58+
func (c *WSConn) readWebsocketMessage() (rawBytes []byte, wsMessage *websocketMessage, err error) {
59+
_, rawBytes, err = c.Wsconn.ReadMessage()
5160
if err != nil {
52-
logger.Debugf("websocket failed to read message: %s", err)
53-
return nil, nil, errCannotReadFromWebsocket
61+
return nil, nil, fmt.Errorf("%w: %s", errCannotReadFromWebsocket, err.Error())
5462
}
5563

56-
logger.Tracef("websocket message received: %s", string(mbytes))
57-
58-
// determine if request is for subscribe method type
59-
var msg map[string]interface{}
60-
err = json.Unmarshal(mbytes, &msg)
61-
64+
wsMessage = new(websocketMessage)
65+
err = json.Unmarshal(rawBytes, wsMessage)
6266
if err != nil {
63-
logger.Debugf("websocket failed to unmarshal request message: %s", err)
64-
return nil, nil, errCannotUnmarshalMessage
67+
return nil, nil, err
6568
}
6669

67-
return mbytes, msg, nil
70+
if wsMessage.Method == "" {
71+
return nil, nil, errEmptyMethod
72+
}
73+
74+
return rawBytes, wsMessage, nil
6875
}
6976

70-
//HandleComm handles messages received on websocket connections
71-
func (c *WSConn) HandleComm() {
77+
// HandleConn handles messages received on websocket connections
78+
func (c *WSConn) HandleConn() {
7279
for {
73-
mbytes, msg, err := c.readWebsocketMessage()
74-
if errors.Is(err, errCannotReadFromWebsocket) {
75-
return
76-
}
80+
rawBytes, wsMessage, err := c.readWebsocketMessage()
81+
if err != nil {
82+
logger.Debugf("websocket failed to read message: %s", err)
83+
if errors.Is(err, errCannotReadFromWebsocket) {
84+
return
85+
}
7786

78-
if errors.Is(err, errCannotUnmarshalMessage) {
7987
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
8088
continue
8189
}
8290

83-
params := msg["params"]
84-
reqid := msg["id"].(float64)
85-
method := msg["method"].(string)
86-
87-
logger.Debugf("ws method %s called with params %v", method, params)
91+
logger.Tracef("websocket message received: %s", string(rawBytes))
92+
logger.Debugf("ws method %s called with params %v", wsMessage.Method, wsMessage.Params)
8893

89-
if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") {
90-
setupListener := c.getSetupListener(method)
94+
if !strings.Contains(wsMessage.Method, "_unsubscribe") && !strings.Contains(wsMessage.Method, "_unwatch") {
95+
setupListener := c.getSetupListener(wsMessage.Method)
9196

9297
if setupListener == nil {
93-
c.executeRPCCall(mbytes)
98+
c.executeRPCCall(rawBytes)
9499
continue
95100
}
96101

97-
listener, err := setupListener(reqid, params)
102+
listener, err := setupListener(wsMessage.ID, wsMessage.Params)
98103
if err != nil {
99-
logger.Warnf("failed to create listener (method=%s): %s", method, err)
104+
logger.Warnf("failed to create listener (method=%s): %s", wsMessage.Method, err)
100105
continue
101106
}
102107

103108
listener.Listen()
104109
continue
105110
}
106111

107-
listener, err := c.getUnsubListener(params)
108-
112+
listener, err := c.getUnsubListener(wsMessage.Params)
109113
if err != nil {
110-
logger.Warnf("failed to get unsubscriber (method=%s): %s", method, err)
114+
logger.Warnf("failed to get unsubscriber (method=%s): %s", wsMessage.Method, err)
111115

112116
if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) {
113-
c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
117+
c.safeSendError(wsMessage.ID, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
114118
continue
115119
}
116120

117121
if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) {
118-
c.safeSend(newBooleanResponseJSON(false, reqid))
122+
c.safeSend(newBooleanResponseJSON(false, wsMessage.ID))
119123
continue
120124
}
121125
}
122126

123127
err = listener.Stop()
124128
if err != nil {
125-
logger.Warnf("failed to stop listener goroutine (method=%s): %s", method, err)
126-
c.safeSend(newBooleanResponseJSON(false, reqid))
129+
logger.Warnf("failed to stop listener goroutine (method=%s): %s", wsMessage.Method, err)
130+
c.safeSend(newBooleanResponseJSON(false, wsMessage.ID))
127131
}
128132

129-
c.safeSend(newBooleanResponseJSON(true, reqid))
133+
c.safeSend(newBooleanResponseJSON(true, wsMessage.ID))
130134
continue
131135
}
132136
}

dot/rpc/subscription/websocket_test.go

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ import (
2121
"github.com/stretchr/testify/require"
2222
)
2323

24-
func TestWSConn_HandleComm(t *testing.T) {
24+
func TestWSConn_HandleConn(t *testing.T) {
2525
wsconn, c, cancel := setupWSConn(t)
2626
wsconn.Subscriptions = make(map[uint32]Listener)
2727
defer cancel()
2828

29-
go wsconn.HandleComm()
29+
go wsconn.HandleConn()
3030
time.Sleep(time.Second * 2)
3131

3232
// test storageChangeListener
@@ -294,7 +294,7 @@ func TestSubscribeAllHeads(t *testing.T) {
294294
wsconn.Subscriptions = make(map[uint32]Listener)
295295
defer cancel()
296296

297-
go wsconn.HandleComm()
297+
go wsconn.HandleConn()
298298
time.Sleep(time.Second * 2)
299299

300300
_, err := wsconn.initAllBlocksListerner(1, nil)
@@ -372,3 +372,50 @@ func TestSubscribeAllHeads(t *testing.T) {
372372
require.NoError(t, l.Stop())
373373
mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block"))
374374
}
375+
376+
func TestWSConn_CheckWebsocketInvalidData(t *testing.T) {
377+
wsconn, c, cancel := setupWSConn(t)
378+
wsconn.Subscriptions = make(map[uint32]Listener)
379+
defer cancel()
380+
381+
go wsconn.HandleConn()
382+
383+
tests := []struct {
384+
sentMessage []byte
385+
expected []byte
386+
}{
387+
{
388+
sentMessage: []byte(`{
389+
"jsonrpc": "2.0",
390+
"method": "",
391+
"id": 0,
392+
"params": []
393+
}`),
394+
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
395+
},
396+
{
397+
sentMessage: []byte(`{
398+
"jsonrpc": "2.0",
399+
"params": []
400+
}`),
401+
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
402+
},
403+
{
404+
sentMessage: []byte(`{
405+
"jsonrpc": "2.0",
406+
"id": "abcdef"
407+
"method": "some_method_name"
408+
"params": []
409+
}`),
410+
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
411+
},
412+
}
413+
414+
for _, tt := range tests {
415+
c.WriteMessage(websocket.TextMessage, tt.sentMessage)
416+
417+
_, msg, err := c.ReadMessage()
418+
require.NoError(t, err)
419+
require.Equal(t, tt.expected, msg)
420+
}
421+
}

0 commit comments

Comments
 (0)