5
5
from jupyter_server .services .kernels .websocket import (
6
6
KernelWebsocketHandler as WebsocketHandler ,
7
7
)
8
+ from jupyter_server .services .kernels .connection .base import (
9
+ deserialize_binary_message ,
10
+ deserialize_msg_from_ws_v1 ,
11
+ )
8
12
except ImportError :
9
13
from jupyter_server .services .kernels .handlers import (
10
14
ZMQChannelsHandler as WebsocketHandler ,
@@ -27,9 +31,77 @@ def read_header_from_binary_message(ws_msg: bytes) -> Optional[Dict]:
27
31
28
32
29
33
class VoilaKernelWebsocketHandler (WebsocketHandler ):
34
+
35
+ _execution_data = {}
36
+
37
+ def on_message (self , ws_msg ):
38
+ connection = self .connection
39
+ subprotocol = connection .subprotocol
40
+ if not connection .channels :
41
+ # already closed, ignore the message
42
+ connection .log .debug ("Received message on closed websocket %r" , ws_msg )
43
+ return
44
+
45
+ if subprotocol == "v1.kernel.websocket.jupyter.org" :
46
+ channel , msg_list = deserialize_msg_from_ws_v1 (ws_msg )
47
+ msg = {"header" : None , "content" : None }
48
+ else :
49
+ if isinstance (ws_msg , bytes ): # type:ignore[unreachable]
50
+ msg = deserialize_binary_message (ws_msg ) # type:ignore[unreachable]
51
+ else :
52
+ msg = json .loads (ws_msg )
53
+ msg_list = []
54
+ channel = msg .pop ("channel" , None )
55
+
56
+ if channel is None :
57
+ connection .log .warning ("No channel specified, assuming shell: %s" , msg )
58
+ channel = "shell"
59
+ if channel not in connection .channels :
60
+ connection .log .warning ("No such channel: %r" , channel )
61
+ return
62
+ am = connection .multi_kernel_manager .allowed_message_types
63
+ ignore_msg = False
64
+ msg_header = connection .get_part ("header" , msg ["header" ], msg_list )
65
+ msg_content = connection .get_part ("content" , msg ["content" ], msg_list )
66
+ if msg_header ["msg_type" ] == "execute_request" :
67
+ execution_data = self ._execution_data .get (self .kernel_id , None )
68
+ cells = execution_data ["cells" ]
69
+ code = msg_content .get ("code" )
70
+ try :
71
+ cell_idx = int (code )
72
+ cell = cells [cell_idx ]
73
+ if cell ["cell_type" ] != "code" :
74
+ cell ["source" ] = ""
75
+
76
+ if subprotocol == "v1.kernel.websocket.jupyter.org" :
77
+ msg_content ["code" ] = cell ["source" ]
78
+ msg_list [3 ] = connection .session .pack (msg_content )
79
+ else :
80
+ msg ["content" ]["code" ] = cell ["source" ]
81
+
82
+ except Exception :
83
+ connection .log .warning ("Unsupported code cell %s" % code )
84
+
85
+ if am :
86
+ msg ["header" ] = connection .get_part ("header" , msg ["header" ], msg_list )
87
+ assert msg ["header" ] is not None
88
+ if msg ["header" ]["msg_type" ] not in am : # type:ignore[unreachable]
89
+ connection .log .warning (
90
+ 'Received message of type "%s", which is not allowed. Ignoring.'
91
+ % msg ["header" ]["msg_type" ]
92
+ )
93
+ ignore_msg = True
94
+ if not ignore_msg :
95
+ stream = connection .channels [channel ]
96
+ if subprotocol == "v1.kernel.websocket.jupyter.org" :
97
+ connection .session .send_raw (stream , msg_list )
98
+ else :
99
+ connection .session .send (stream , msg )
100
+
30
101
def write_message (
31
102
self , message : Union [bytes , Dict [str , Any ]], binary : bool = False
32
103
):
104
+
33
105
if isinstance (message , bytes ):
34
106
header = read_header_from_binary_message (message )
35
107
elif isinstance (message , dict ):
0 commit comments