Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions voila/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from .exporter import VoilaExporter
from .shutdown_kernel_handler import VoilaShutdownKernelHandler
from .voila_kernel_manager import voila_kernel_manager_factory
from .query_parameters_handler import QueryStringSocketHandler
from .request_info_handler import RequestInfoSocketHandler
from .utils import create_include_assets_functions

_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
Expand Down Expand Up @@ -500,7 +500,7 @@ def start(self):
handlers.append(
(
url_path_join(self.server_url, r'/voila/query/%s' % _kernel_id_regex),
QueryStringSocketHandler
RequestInfoSocketHandler
)
)
# Serving notebook extensions
Expand Down
23 changes: 12 additions & 11 deletions voila/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ._version import __version__
from .notebook_renderer import NotebookRenderer
from .query_parameters_handler import QueryStringSocketHandler
from .request_info_handler import RequestInfoSocketHandler
from .utils import ENV_VARIABLE


Expand All @@ -47,25 +47,25 @@ async def get_generator(self, path=None):
cwd = os.path.dirname(notebook_path)

# Adding request uri to kernel env
kernel_env = os.environ.copy()
kernel_env[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
kernel_env[
request_info = dict()
request_info[ENV_VARIABLE.SCRIPT_NAME] = self.request.path
request_info[
ENV_VARIABLE.PATH_INFO
] = '' # would be /foo/bar if voila.ipynb/foo/bar was supported
kernel_env[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
kernel_env[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
kernel_env[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
request_info[ENV_VARIABLE.QUERY_STRING] = str(self.request.query)
request_info[ENV_VARIABLE.SERVER_SOFTWARE] = 'voila/{}'.format(__version__)
request_info[ENV_VARIABLE.SERVER_PROTOCOL] = str(self.request.version)
host, port = split_host_and_port(self.request.host.lower())
kernel_env[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
kernel_env[ENV_VARIABLE.SERVER_NAME] = host
request_info[ENV_VARIABLE.SERVER_PORT] = str(port) if port else ''
request_info[ENV_VARIABLE.SERVER_NAME] = host
# Add HTTP Headers as env vars following rfc3875#section-4.1.18
if len(self.voila_configuration.http_header_envs) > 0:
for header_name in self.request.headers:
config_headers_lower = [header.lower() for header in self.voila_configuration.http_header_envs]
# Use case insensitive comparison of header names as per rfc2616#section-4.2
if header_name.lower() in config_headers_lower:
env_name = f'HTTP_{header_name.upper().replace("-", "_")}'
kernel_env[env_name] = self.request.headers.get(header_name)
request_info[env_name] = self.request.headers.get(header_name)

template_arg = self.get_argument("voila-template", None)
theme_arg = self.get_argument("voila-theme", None)
Expand Down Expand Up @@ -99,7 +99,7 @@ async def get_generator(self, path=None):
notebook_name=notebook_path,
)

QueryStringSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': self.request.query})
RequestInfoSocketHandler.send_updates({'kernel_id': kernel_id, 'payload': request_info})
# Send rendered cell to frontend
if len(rendered_cache) > 0:
yield ''.join(rendered_cache)
Expand Down Expand Up @@ -150,6 +150,7 @@ def time_out():

return '<script>voila_heartbeat()</script>\n'

kernel_env = {**os.environ.copy(), **request_info}
kernel_env[ENV_VARIABLE.VOILA_PREHEAT] = 'False'
kernel_env[ENV_VARIABLE.VOILA_BASE_URL] = self.base_url
kernel_id = await ensure_async(
Expand Down
22 changes: 11 additions & 11 deletions voila/query_parameters_handler.py → voila/request_info_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from typing import Dict


class QueryStringSocketHandler(WebSocketHandler):
"""A websocket handler used to provide the query string
class RequestInfoSocketHandler(WebSocketHandler):
"""A websocket handler used to provide the request info
assocciated with kernel ids in preheat kernel mode.

Class variables
---------------
- _waiters : A dictionary which holds the `websocket` connection
assocciated with the kernel id.

- cache : A dictionary which holds the query string assocciated
- cache : A dictionary which holds the request info assocciated
with the kernel id.
"""
_waiters = dict()
Expand All @@ -26,28 +26,28 @@ def open(self, kernel_id: str) -> None:
kernel_id (str): Kernel id used by the notebook when it opens
the websocket connection.
"""
QueryStringSocketHandler._waiters[kernel_id] = self
RequestInfoSocketHandler._waiters[kernel_id] = self
if kernel_id in self._cache:
self.write_message(self._cache[kernel_id])

def on_close(self) -> None:
for k_id, waiter in QueryStringSocketHandler._waiters.items():
for k_id, waiter in RequestInfoSocketHandler._waiters.items():
if waiter == self:
break
del QueryStringSocketHandler._waiters[k_id]
del RequestInfoSocketHandler._waiters[k_id]

@classmethod
def send_updates(cls: 'QueryStringSocketHandler', msg: Dict) -> None:
"""Class method used to dispath the query string to the waiting
notebook. This method is called in `VoilaHandler` when the query
string becomes available.
def send_updates(cls: 'RequestInfoSocketHandler', msg: Dict) -> None:
"""Class method used to dispath the request info to the waiting
notebook. This method is called in `VoilaHandler` when the request
info becomes available.
If this method is called before the opening of websocket connection,
`msg` is stored in `_cache0` and the message will be dispatched when
a notebook with coresponding kernel id is connected.

Args:
- msg (Dict): this dictionary contains the `kernel_id` to identify
the waiting notebook and `payload` is the query string.
the waiting notebook and `payload` is the request info.
"""
kernel_id = msg['kernel_id']
payload = msg['payload']
Expand Down
42 changes: 30 additions & 12 deletions voila/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
from enum import Enum
from typing import Awaitable
import json

import websockets

Expand Down Expand Up @@ -58,29 +59,29 @@ def get_server_root_dir(settings):
return root_dir


async def _get_query_string(ws_url: str) -> Awaitable:
async def _get_request_info(ws_url: str) -> Awaitable:
async with websockets.connect(ws_url) as websocket:
qs = await websocket.recv()
return qs
ri = await websocket.recv()
return ri


def get_query_string(url: str = None) -> str:
def wait_for_request(url: str = None) -> str:
"""Helper function to pause the execution of notebook and wait for
the query string.
the pre-heated kernel to be used and all request info is added to
the environment.

Args:
url (str, optional): Address to get user query string, if it is not
url (str, optional): Address to get request info, if it is not
provided, `voila` will figure out from the environment variables.
Defaults to None.

Returns: The query string provided by `QueryStringSocketHandler`.
"""

preheat_mode = os.getenv(ENV_VARIABLE.VOILA_PREHEAT, 'False')
if preheat_mode == 'False':
return os.getenv(ENV_VARIABLE.QUERY_STRING)
return

query_string = None
request_info = None
if url is None:
protocol = os.getenv(ENV_VARIABLE.VOILA_APP_PROTOCOL, 'ws')
server_ip = os.getenv(ENV_VARIABLE.VOILA_APP_IP, '127.0.0.1')
Expand All @@ -92,9 +93,9 @@ def get_query_string(url: str = None) -> str:
ws_url = f'{url}/{kernel_id}'

def inner():
nonlocal query_string
nonlocal request_info
loop = asyncio.new_event_loop()
query_string = loop.run_until_complete(_get_query_string(ws_url))
request_info = loop.run_until_complete(_get_request_info(ws_url))

thread = threading.Thread(target=inner)
try:
Expand All @@ -103,7 +104,24 @@ def inner():
except (KeyboardInterrupt, SystemExit):
asyncio.get_event_loop().stop()

return query_string
for k, v in json.loads(request_info).items():
os.environ[k] = v


def get_query_string(url: str = None) -> str:
"""Helper function to pause the execution of notebook and wait for
the query string.

Args:
url (str, optional): Address to get user query string, if it is not
provided, `voila` will figure out from the environment variables.
Defaults to None.

Returns: The query string provided by `QueryStringSocketHandler`.
"""

wait_for_request(url)
return os.getenv(ENV_VARIABLE.QUERY_STRING)


def make_url(template_name, base_url, path):
Expand Down