Skip to content

Commit 9d4063e

Browse files
feat: make request info available in env for preheated kernels (#1109)
* feat: make request info available in env for preheated kernels The mechanism that was used for 'get_query_string()' is changed so that all request info is sent. For backward compatibility 'get_query_string()' is reimplemented by using the changed code. * refactor: no need to make a copy of environ in this case Co-authored-by: Maarten Breddels <[email protected]> * docs: add documentation for wait_for_request * refactor: remove backward compatibility for get_query_string Co-authored-by: Maarten Breddels <[email protected]>
1 parent 35fb38e commit 9d4063e

File tree

5 files changed

+45
-40
lines changed

5 files changed

+45
-40
lines changed

docs/source/customize.rst

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,16 @@ In normal mode, Voilà users can get the `query string` at run time through the
399399
import os
400400
query_string = os.getenv('QUERY_STRING')
401401
402-
In preheating kernel mode, users can just replace the ``os.getenv`` call with the helper ``get_query_string`` from ``voila.utils``
402+
In preheating kernel mode, users can prepend with ``wait_for_request`` from ``voila.utils``
403403

404404
.. code-block:: python
405405
406-
from voila.utils import get_query_string
407-
query_string = get_query_string()
406+
import os
407+
from voila.utils import wait_for_request
408+
wait_for_request()
409+
query_string = os.getenv('QUERY_STRING')
408410
409-
``get_query_string`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, then ``get_query_string`` will return the URL `query string` and continue the execution of the remaining cells.
411+
``wait_for_request`` will pause the execution of the notebook in the preheated kernel at this cell and wait for an actual user to connect to Voilà, set the request info environment variables and then continue the execution of the remaining cells.
410412

411413
If the Voilà websocket handler is not started with the default protocol (`ws`), the default IP address (`127.0.0.1`) or the default port (`8866`), users need to provide these values through the environment variables ``VOILA_APP_PROTOCOL``, ``VOILA_APP_IP`` and ``VOILA_APP_PORT``. The easiest way is to set these variables in the `voila.json` configuration file, for example:
412414

voila/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from .exporter import VoilaExporter
6363
from .shutdown_kernel_handler import VoilaShutdownKernelHandler
6464
from .voila_kernel_manager import voila_kernel_manager_factory
65-
from .query_parameters_handler import QueryStringSocketHandler
65+
from .request_info_handler import RequestInfoSocketHandler
6666
from .utils import create_include_assets_functions
6767

6868
_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
@@ -500,7 +500,7 @@ def start(self):
500500
handlers.append(
501501
(
502502
url_path_join(self.server_url, r'/voila/query/%s' % _kernel_id_regex),
503-
QueryStringSocketHandler
503+
RequestInfoSocketHandler
504504
)
505505
)
506506
# Serving notebook extensions

voila/handler.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ._version import __version__
2424
from .notebook_renderer import NotebookRenderer
25-
from .query_parameters_handler import QueryStringSocketHandler
25+
from .request_info_handler import RequestInfoSocketHandler
2626
from .utils import ENV_VARIABLE, create_include_assets_functions
2727

2828

@@ -80,25 +80,25 @@ async def get_generator(self, path=None):
8080
cwd = os.path.dirname(notebook_path)
8181

8282
# Adding request uri to kernel env
83-
kernel_env = os.environ.copy()
84-
kernel_env[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
85-
kernel_env[
83+
request_info = dict()
84+
request_info[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
85+
request_info[
8686
ENV_VARIABLE.PATH_INFO
8787
] = '' # would be /foo/bar if voila.ipynb/foo/bar was supported
88-
kernel_env[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
89-
kernel_env[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
90-
kernel_env[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
88+
request_info[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
89+
request_info[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
90+
request_info[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
9191
host, port = split_host_and_port(self.request.host.lower())
92-
kernel_env[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
93-
kernel_env[ENV_VARIABLE.SERVER_NAME] = host
92+
request_info[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
93+
request_info[ENV_VARIABLE.SERVER_NAME] = host
9494
# Add HTTP Headers as env vars following rfc3875#section-4.1.18
9595
if len(self.voila_configuration.http_header_envs) > 0:
9696
for header_name in self.request.headers:
9797
config_headers_lower = [header.lower() for header in self.voila_configuration.http_header_envs]
9898
# Use case insensitive comparison of header names as per rfc2616#section-4.2
9999
if header_name.lower() in config_headers_lower:
100100
env_name = f'HTTP_{header_name.upper().replace("-", "_")}'
101-
kernel_env[env_name] = self.request.headers.get(header_name)
101+
request_info[env_name] = self.request.headers.get(header_name)
102102

103103
template_arg = self.get_argument("voila-template", None)
104104
theme_arg = self.get_argument("voila-theme", None)
@@ -132,7 +132,7 @@ async def get_generator(self, path=None):
132132
notebook_name=notebook_path,
133133
)
134134

135-
QueryStringSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': self.request.query})
135+
RequestInfoSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': request_info})
136136
# Send rendered cell to frontend
137137
if len(rendered_cache) > 0:
138138
yield ''.join(rendered_cache)
@@ -183,6 +183,7 @@ def time_out():
183183

184184
return '<script>voila_heartbeat()</script>\n'
185185

186+
kernel_env = {**os.environ, **request_info}
186187
kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'False'
187188
kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.base_url
188189
kernel_id = await ensure_async(

voila/query_parameters_handler.py renamed to voila/request_info_handler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from typing import Dict
44

55

6-
class QueryStringSocketHandler(WebSocketHandler):
7-
"""A websocket handler used to provide the query string
6+
class RequestInfoSocketHandler(WebSocketHandler):
7+
"""A websocket handler used to provide the request info
88
assocciated with kernel ids in preheat kernel mode.
99
1010
Class variables
1111
---------------
1212
- _waiters : A dictionary which holds the `websocket` connection
1313
assocciated with the kernel id.
1414
15-
- cache : A dictionary which holds the query string assocciated
15+
- cache : A dictionary which holds the request info assocciated
1616
with the kernel id.
1717
"""
1818
_waiters = dict()
@@ -26,28 +26,28 @@ def open(self, kernel_id: str) -> None:
2626
kernel_id (str): Kernel id used by the notebook when it opens
2727
the websocket connection.
2828
"""
29-
QueryStringSocketHandler._waiters[kernel_id] = self
29+
RequestInfoSocketHandler._waiters[kernel_id] = self
3030
if kernel_id in self._cache:
3131
self.write_message(self._cache[kernel_id])
3232

3333
def on_close(self) -> None:
34-
for k_id, waiter in QueryStringSocketHandler._waiters.items():
34+
for k_id, waiter in RequestInfoSocketHandler._waiters.items():
3535
if waiter == self:
3636
break
37-
del QueryStringSocketHandler._waiters[k_id]
37+
del RequestInfoSocketHandler._waiters[k_id]
3838

3939
@classmethod
40-
def send_updates(cls: 'QueryStringSocketHandler', msg: Dict) -> None:
41-
"""Class method used to dispath the query string to the waiting
42-
notebook. This method is called in `VoilaHandler` when the query
43-
string becomes available.
40+
def send_updates(cls: 'RequestInfoSocketHandler', msg: Dict) -> None:
41+
"""Class method used to dispath the request info to the waiting
42+
notebook. This method is called in `VoilaHandler` when the request
43+
info becomes available.
4444
If this method is called before the opening of websocket connection,
4545
`msg` is stored in `_cache0` and the message will be dispatched when
4646
a notebook with coresponding kernel id is connected.
4747
4848
Args:
4949
- msg (Dict): this dictionary contains the `kernel_id` to identify
50-
the waiting notebook and `payload` is the query string.
50+
the waiting notebook and `payload` is the request info.
5151
"""
5252
kernel_id = msg['kernel_id']
5353
payload = msg['payload']

voila/utils.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
from enum import Enum
1515
from typing import Awaitable
16+
import json
1617

1718
import websockets
1819

@@ -58,29 +59,29 @@ def get_server_root_dir(settings):
5859
return root_dir
5960

6061

61-
async def _get_query_string(ws_url: str) -> Awaitable:
62+
async def _get_request_info(ws_url: str) -> Awaitable:
6263
async with websockets.connect(ws_url) as websocket:
63-
qs = await websocket.recv()
64-
return qs
64+
ri = await websocket.recv()
65+
return ri
6566

6667

67-
def get_query_string(url: str = None) -> str:
68+
def wait_for_request(url: str = None) -> str:
6869
"""Helper function to pause the execution of notebook and wait for
69-
the query string.
70+
the pre-heated kernel to be used and all request info is added to
71+
the environment.
7072
7173
Args:
72-
url (str, optional): Address to get user query string, if it is not
74+
url (str, optional): Address to get request info, if it is not
7375
provided, `voila` will figure out from the environment variables.
7476
Defaults to None.
7577
76-
Returns: The query string provided by `QueryStringSocketHandler`.
7778
"""
7879

7980
preheat_mode = os.getenv(ENV_VARIABLE.VOILA_PREHEAT, 'False')
8081
if preheat_mode == 'False':
81-
return os.getenv(ENV_VARIABLE.QUERY_STRING)
82+
return
8283

83-
query_string = None
84+
request_info = None
8485
if url is None:
8586
protocol = os.getenv(ENV_VARIABLE.VOILA_APP_PROTOCOL, 'ws')
8687
server_ip = os.getenv(ENV_VARIABLE.VOILA_APP_IP, '127.0.0.1')
@@ -92,9 +93,9 @@ def get_query_string(url: str = None) -> str:
9293
ws_url = f'{url}/{kernel_id}'
9394

9495
def inner():
95-
nonlocal query_string
96+
nonlocal request_info
9697
loop = asyncio.new_event_loop()
97-
query_string = loop.run_until_complete(_get_query_string(ws_url))
98+
request_info = loop.run_until_complete(_get_request_info(ws_url))
9899

99100
thread = threading.Thread(target=inner)
100101
try:
@@ -103,7 +104,8 @@ def inner():
103104
except (KeyboardInterrupt, SystemExit):
104105
asyncio.get_event_loop().stop()
105106

106-
return query_string
107+
for k, v in json.loads(request_info).items():
108+
os.environ[k] = v
107109

108110

109111
def make_url(template_name, base_url, path):

0 commit comments

Comments
 (0)