diff --git a/docs/source/other/full-config.rst b/docs/source/other/full-config.rst index 52d0bcf3ff..18220cc03e 100644 --- a/docs/source/other/full-config.rst +++ b/docs/source/other/full-config.rst @@ -1178,19 +1178,19 @@ NotebookNotary.store_factory : Callable A callable returning the storage backend for notebook signatures. The default uses an SQLite database. -GatewayKernelManager.allow_tracebacks : Bool +GatewayMappingKernelManager.allow_tracebacks : Bool Default: ``True`` Whether to send tracebacks to clients on exceptions. -GatewayKernelManager.allowed_message_types : List +GatewayMappingKernelManager.allowed_message_types : List Default: ``[]`` White list of allowed kernel message types. When the list is empty, all message types are allowed. -GatewayKernelManager.buffer_offline_messages : Bool +GatewayMappingKernelManager.buffer_offline_messages : Bool Default: ``True`` Whether messages from kernels whose frontends have disconnected should be buffered in-memory. @@ -1202,36 +1202,36 @@ GatewayKernelManager.buffer_offline_messages : Bool no frontends are connected. -GatewayKernelManager.cull_busy : Bool +GatewayMappingKernelManager.cull_busy : Bool Default: ``False`` Whether to consider culling kernels which are busy. Only effective if cull_idle_timeout > 0. -GatewayKernelManager.cull_connected : Bool +GatewayMappingKernelManager.cull_connected : Bool Default: ``False`` Whether to consider culling kernels which have one or more connections. Only effective if cull_idle_timeout > 0. -GatewayKernelManager.cull_idle_timeout : Int +GatewayMappingKernelManager.cull_idle_timeout : Int Default: ``0`` Timeout (in seconds) after which a kernel is considered idle and ready to be culled. Values of 0 or lower disable culling. Very short timeouts may result in kernels being culled for users with poor network connections. -GatewayKernelManager.cull_interval : Int +GatewayMappingKernelManager.cull_interval : Int Default: ``300`` The interval (in seconds) on which to check for idle kernels exceeding the cull timeout value. -GatewayKernelManager.default_kernel_name : Unicode +GatewayMappingKernelManager.default_kernel_name : Unicode Default: ``'python3'`` The name of the default kernel to start -GatewayKernelManager.kernel_info_timeout : Float +GatewayMappingKernelManager.kernel_info_timeout : Float Default: ``60`` Timeout for giving up on a kernel (in seconds). @@ -1244,24 +1244,24 @@ GatewayKernelManager.kernel_info_timeout : Float and the ZMQChannelsHandler (which handles the startup). -GatewayKernelManager.kernel_manager_class : DottedObjectName +GatewayMappingKernelManager.kernel_manager_class : DottedObjectName Default: ``'jupyter_client.ioloop.IOLoopKernelManager'`` The kernel manager class. This is configurable to allow subclassing of the KernelManager for customized behavior. -GatewayKernelManager.root_dir : Unicode +GatewayMappingKernelManager.root_dir : Unicode Default: ``''`` No description -GatewayKernelManager.shared_context : Bool +GatewayMappingKernelManager.shared_context : Bool Default: ``True`` Share a single zmq.Context to talk to all my kernels -GatewayKernelManager.traceback_replacement_message : Unicode +GatewayMappingKernelManager.traceback_replacement_message : Unicode Default: ``'An exception occurred at runtime, which is not shown due to ...`` Message to print when allow_tracebacks is False, and an exception occurs diff --git a/jupyter_server/gateway/gateway_client.py b/jupyter_server/gateway/gateway_client.py new file mode 100644 index 0000000000..83b1369e6d --- /dev/null +++ b/jupyter_server/gateway/gateway_client.py @@ -0,0 +1,322 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import json +import os + +from socket import gaierror +from tornado import web +from tornado.httpclient import AsyncHTTPClient, HTTPError +from traitlets import Unicode, Int, Float, Bool, default, validate, TraitError +from traitlets.config import SingletonConfigurable + + +class GatewayClient(SingletonConfigurable): + """This class manages the configuration. It's its own singleton class so that we + can share these values across all objects. It also contains some helper methods + to build request arguments out of the various config options. + + """ + + url = Unicode(default_value=None, allow_none=True, config=True, + help="""The url of the Kernel or Enterprise Gateway server where + kernel specifications are defined and kernel management takes place. + If defined, this Notebook server acts as a proxy for all kernel + management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) + """ + ) + + url_env = 'JUPYTER_GATEWAY_URL' + + @default('url') + def _url_default(self): + return os.environ.get(self.url_env) + + @validate('url') + def _url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'http' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('http'): + raise TraitError("GatewayClient url must start with 'http': '%r'" % value) + return value + + ws_url = Unicode(default_value=None, allow_none=True, config=True, + help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value + will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) + """ + ) + + ws_url_env = 'JUPYTER_GATEWAY_WS_URL' + + @default('ws_url') + def _ws_url_default(self): + default_value = os.environ.get(self.ws_url_env) + if default_value is None: + if self.gateway_enabled: + default_value = self.url.lower().replace('http', 'ws') + return default_value + + @validate('ws_url') + def _ws_url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'ws' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('ws'): + raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value) + return value + + kernels_endpoint_default_value = '/api/kernels' + kernels_endpoint_env = 'JUPYTER_GATEWAY_KERNELS_ENDPOINT' + kernels_endpoint = Unicode(default_value=kernels_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""") + + @default('kernels_endpoint') + def _kernels_endpoint_default(self): + return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value) + + kernelspecs_endpoint_default_value = '/api/kernelspecs' + kernelspecs_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT' + kernelspecs_endpoint = Unicode(default_value=kernelspecs_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""") + + @default('kernelspecs_endpoint') + def _kernelspecs_endpoint_default(self): + return os.environ.get(self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value) + + kernelspecs_resource_endpoint_default_value = '/kernelspecs' + kernelspecs_resource_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT' + kernelspecs_resource_endpoint = Unicode(default_value=kernelspecs_resource_endpoint_default_value, config=True, + help="""The gateway endpoint for accessing kernelspecs resources + (JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""") + + @default('kernelspecs_resource_endpoint') + def _kernelspecs_resource_endpoint_default(self): + return os.environ.get(self.kernelspecs_resource_endpoint_env, self.kernelspecs_resource_endpoint_default_value) + + connect_timeout_default_value = 40.0 + connect_timeout_env = 'JUPYTER_GATEWAY_CONNECT_TIMEOUT' + connect_timeout = Float(default_value=connect_timeout_default_value, config=True, + help="""The time allowed for HTTP connection establishment with the Gateway server. + (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""") + + @default('connect_timeout') + def connect_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_CONNECT_TIMEOUT', self.connect_timeout_default_value)) + + request_timeout_default_value = 40.0 + request_timeout_env = 'JUPYTER_GATEWAY_REQUEST_TIMEOUT' + request_timeout = Float(default_value=request_timeout_default_value, config=True, + help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""") + + @default('request_timeout') + def request_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_REQUEST_TIMEOUT', self.request_timeout_default_value)) + + client_key = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var) + """ + ) + client_key_env = 'JUPYTER_GATEWAY_CLIENT_KEY' + + @default('client_key') + def _client_key_default(self): + return os.environ.get(self.client_key_env) + + client_cert = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var) + """ + ) + client_cert_env = 'JUPYTER_GATEWAY_CLIENT_CERT' + + @default('client_cert') + def _client_cert_default(self): + return os.environ.get(self.client_cert_env) + + ca_certs = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var) + """ + ) + ca_certs_env = 'JUPYTER_GATEWAY_CA_CERTS' + + @default('ca_certs') + def _ca_certs_default(self): + return os.environ.get(self.ca_certs_env) + + http_user = Unicode(default_value=None, allow_none=True, config=True, + help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var) + """ + ) + http_user_env = 'JUPYTER_GATEWAY_HTTP_USER' + + @default('http_user') + def _http_user_default(self): + return os.environ.get(self.http_user_env) + + http_pwd = Unicode(default_value=None, allow_none=True, config=True, + help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var) + """ + ) + http_pwd_env = 'JUPYTER_GATEWAY_HTTP_PWD' + + @default('http_pwd') + def _http_pwd_default(self): + return os.environ.get(self.http_pwd_env) + + headers_default_value = '{}' + headers_env = 'JUPYTER_GATEWAY_HEADERS' + headers = Unicode(default_value=headers_default_value, allow_none=True, config=True, + help="""Additional HTTP headers to pass on the request. This value will be converted to a dict. + (JUPYTER_GATEWAY_HEADERS env var) + """ + ) + + @default('headers') + def _headers_default(self): + return os.environ.get(self.headers_env, self.headers_default_value) + + auth_token = Unicode(default_value=None, allow_none=True, config=True, + help="""The authorization token used in the HTTP headers. (JUPYTER_GATEWAY_AUTH_TOKEN env var) + """ + ) + auth_token_env = 'JUPYTER_GATEWAY_AUTH_TOKEN' + + @default('auth_token') + def _auth_token_default(self): + return os.environ.get(self.auth_token_env, '') + + validate_cert_default_value = True + validate_cert_env = 'JUPYTER_GATEWAY_VALIDATE_CERT' + validate_cert = Bool(default_value=validate_cert_default_value, config=True, + help="""For HTTPS requests, determines if server's certificate should be validated or not. + (JUPYTER_GATEWAY_VALIDATE_CERT env var)""" + ) + + @default('validate_cert') + def validate_cert_default(self): + return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false']) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._static_args = {} # initialized on first use + + env_whitelist_default_value = '' + env_whitelist_env = 'JUPYTER_GATEWAY_ENV_WHITELIST' + env_whitelist = Unicode(default_value=env_whitelist_default_value, config=True, + help="""A comma-separated list of environment variable names that will be included, along with + their values, in the kernel startup request. The corresponding `env_whitelist` configuration + value must also be set on the Gateway server - since that configuration value indicates which + environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""") + + @default('env_whitelist') + def _env_whitelist_default(self): + return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value) + + gateway_retry_interval_default_value = 1.0 + gateway_retry_interval_env = 'JUPYTER_GATEWAY_RETRY_INTERVAL' + gateway_retry_interval = Float(default_value=gateway_retry_interval_default_value, config=True, + help="""The time allowed for HTTP reconnection with the Gateway server for the first time. + Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries + but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX. + (JUPYTER_GATEWAY_RETRY_INTERVAL env var)""") + + @default('gateway_retry_interval') + def gateway_retry_interval_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_RETRY_INTERVAL', self.gateway_retry_interval_default_value)) + + gateway_retry_interval_max_default_value = 30.0 + gateway_retry_interval_max_env = 'JUPYTER_GATEWAY_RETRY_INTERVAL_MAX' + gateway_retry_interval_max = Float(default_value=gateway_retry_interval_max_default_value, config=True, + help="""The maximum time allowed for HTTP reconnection retry with the Gateway server. + (JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""") + + @default('gateway_retry_interval_max') + def gateway_retry_interval_max_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_RETRY_INTERVAL_MAX', self.gateway_retry_interval_max_default_value)) + + gateway_retry_max_default_value = 5 + gateway_retry_max_env = 'JUPYTER_GATEWAY_RETRY_MAX' + gateway_retry_max = Int(default_value=gateway_retry_max_default_value, config=True, + help="""The maximum retries allowed for HTTP reconnection with the Gateway server. + (JUPYTER_GATEWAY_RETRY_MAX env var)""") + + @default('gateway_retry_max') + def gateway_retry_max_default(self): + return int(os.environ.get('JUPYTER_GATEWAY_RETRY_MAX', self.gateway_retry_max_default_value)) + + @property + def gateway_enabled(self): + return bool(self.url is not None and len(self.url) > 0) + + # Ensure KERNEL_LAUNCH_TIMEOUT has a default value. + KERNEL_LAUNCH_TIMEOUT = int(os.environ.get('KERNEL_LAUNCH_TIMEOUT', 40)) + + def init_static_args(self): + """Initialize arguments used on every request. Since these are static values, we'll + perform this operation once. + + """ + # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are the same, taking the + # greater value of the two. + if self.request_timeout < float(GatewayClient.KERNEL_LAUNCH_TIMEOUT): + self.request_timeout = float(GatewayClient.KERNEL_LAUNCH_TIMEOUT) + elif self.request_timeout > float(GatewayClient.KERNEL_LAUNCH_TIMEOUT): + GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(self.request_timeout) + # Ensure any adjustments are reflected in env. + os.environ['KERNEL_LAUNCH_TIMEOUT'] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT) + + self._static_args['headers'] = json.loads(self.headers) + if 'Authorization' not in self._static_args['headers'].keys(): + self._static_args['headers'].update({ + 'Authorization': 'token {}'.format(self.auth_token) + }) + self._static_args['connect_timeout'] = self.connect_timeout + self._static_args['request_timeout'] = self.request_timeout + self._static_args['validate_cert'] = self.validate_cert + if self.client_cert: + self._static_args['client_cert'] = self.client_cert + self._static_args['client_key'] = self.client_key + if self.ca_certs: + self._static_args['ca_certs'] = self.ca_certs + if self.http_user: + self._static_args['auth_username'] = self.http_user + if self.http_pwd: + self._static_args['auth_password'] = self.http_pwd + + def load_connection_args(self, **kwargs): + """Merges the static args relative to the connection, with the given keyword arguments. If statics + have yet to be initialized, we'll do that here. + + """ + if len(self._static_args) == 0: + self.init_static_args() + + kwargs.update(self._static_args) + return kwargs + + +async def gateway_request(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint, returns a response """ + client = AsyncHTTPClient() + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + try: + response = await client.fetch(endpoint, **kwargs) + # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect + # or the server is not running. + # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes + # of the tree view. + except ConnectionRefusedError as e: + raise web.HTTPError(503, "Connection refused from Gateway server url '{}'. " + "Check to be sure the Gateway instance is running.".format(GatewayClient.instance().url)) from e + except HTTPError as e: + # This can occur if the host is valid (e.g., foo.com) but there's nothing there. + raise web.HTTPError(e.code, "Error attempting to connect to Gateway server url '{}'. " + "Ensure gateway url is valid and the Gateway instance is running.". + format(GatewayClient.instance().url)) from e + except gaierror as e: + raise web.HTTPError(404, "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " + "Ensure gateway url is valid and the Gateway instance is running.". + format(GatewayClient.instance().url)) from e + + return response + diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 61a9a4bb4d..0fb1cb2234 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -1,352 +1,49 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import os +import datetime import json +import os +import websocket -from socket import gaierror +from jupyter_client.asynchronous.client import AsyncKernelClient +from jupyter_client.clientabc import KernelClientABC +from jupyter_client.kernelspec import KernelSpecManager +from jupyter_client.manager import AsyncKernelManager +from jupyter_client.managerabc import KernelManagerABC + +from logging import Logger +from queue import Queue +from threading import Thread from tornado import web -from tornado.escape import json_encode, json_decode, url_escape -from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError +from tornado.escape import json_encode, json_decode, url_escape, utf8 +from traitlets import Instance, DottedObjectName, Type, default +from typing import Dict +from .gateway_client import GatewayClient, gateway_request from ..services.kernels.kernelmanager import AsyncMappingKernelManager from ..services.sessions.sessionmanager import SessionManager +from ..utils import url_path_join, ensure_async +from .._tz import UTC -from jupyter_client.kernelspec import KernelSpecManager -from ..utils import url_path_join - -from traitlets import Instance, Unicode, Int, Float, Bool, default, validate, TraitError -from traitlets.config import SingletonConfigurable - - -class GatewayClient(SingletonConfigurable): - """This class manages the configuration. It's its own singleton class so that we - can share these values across all objects. It also contains some helper methods - to build request arguments out of the various config options. - - """ - - url = Unicode(default_value=None, allow_none=True, config=True, - help="""The url of the Kernel or Enterprise Gateway server where - kernel specifications are defined and kernel management takes place. - If defined, this Notebook server acts as a proxy for all kernel - management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) - """ - ) - - url_env = 'JUPYTER_GATEWAY_URL' - - @default('url') - def _url_default(self): - return os.environ.get(self.url_env) - - @validate('url') - def _url_validate(self, proposal): - value = proposal['value'] - # Ensure value, if present, starts with 'http' - if value is not None and len(value) > 0: - if not str(value).lower().startswith('http'): - raise TraitError("GatewayClient url must start with 'http': '%r'" % value) - return value - - ws_url = Unicode(default_value=None, allow_none=True, config=True, - help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value - will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) - """ - ) - - ws_url_env = 'JUPYTER_GATEWAY_WS_URL' - - @default('ws_url') - def _ws_url_default(self): - default_value = os.environ.get(self.ws_url_env) - if default_value is None: - if self.gateway_enabled: - default_value = self.url.lower().replace('http', 'ws') - return default_value - - @validate('ws_url') - def _ws_url_validate(self, proposal): - value = proposal['value'] - # Ensure value, if present, starts with 'ws' - if value is not None and len(value) > 0: - if not str(value).lower().startswith('ws'): - raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value) - return value - - kernels_endpoint_default_value = '/api/kernels' - kernels_endpoint_env = 'JUPYTER_GATEWAY_KERNELS_ENDPOINT' - kernels_endpoint = Unicode(default_value=kernels_endpoint_default_value, config=True, - help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""") - - @default('kernels_endpoint') - def _kernels_endpoint_default(self): - return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value) - - kernelspecs_endpoint_default_value = '/api/kernelspecs' - kernelspecs_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT' - kernelspecs_endpoint = Unicode(default_value=kernelspecs_endpoint_default_value, config=True, - help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""") - - @default('kernelspecs_endpoint') - def _kernelspecs_endpoint_default(self): - return os.environ.get(self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value) - - kernelspecs_resource_endpoint_default_value = '/kernelspecs' - kernelspecs_resource_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT' - kernelspecs_resource_endpoint = Unicode(default_value=kernelspecs_resource_endpoint_default_value, config=True, - help="""The gateway endpoint for accessing kernelspecs resources - (JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""") - - @default('kernelspecs_resource_endpoint') - def _kernelspecs_resource_endpoint_default(self): - return os.environ.get(self.kernelspecs_resource_endpoint_env, self.kernelspecs_resource_endpoint_default_value) - - connect_timeout_default_value = 40.0 - connect_timeout_env = 'JUPYTER_GATEWAY_CONNECT_TIMEOUT' - connect_timeout = Float(default_value=connect_timeout_default_value, config=True, - help="""The time allowed for HTTP connection establishment with the Gateway server. - (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""") - - @default('connect_timeout') - def connect_timeout_default(self): - return float(os.environ.get('JUPYTER_GATEWAY_CONNECT_TIMEOUT', self.connect_timeout_default_value)) - - request_timeout_default_value = 40.0 - request_timeout_env = 'JUPYTER_GATEWAY_REQUEST_TIMEOUT' - request_timeout = Float(default_value=request_timeout_default_value, config=True, - help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""") - - @default('request_timeout') - def request_timeout_default(self): - return float(os.environ.get('JUPYTER_GATEWAY_REQUEST_TIMEOUT', self.request_timeout_default_value)) - - client_key = Unicode(default_value=None, allow_none=True, config=True, - help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var) - """ - ) - client_key_env = 'JUPYTER_GATEWAY_CLIENT_KEY' - - @default('client_key') - def _client_key_default(self): - return os.environ.get(self.client_key_env) - - client_cert = Unicode(default_value=None, allow_none=True, config=True, - help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var) - """ - ) - client_cert_env = 'JUPYTER_GATEWAY_CLIENT_CERT' - - @default('client_cert') - def _client_cert_default(self): - return os.environ.get(self.client_cert_env) - - ca_certs = Unicode(default_value=None, allow_none=True, config=True, - help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var) - """ - ) - ca_certs_env = 'JUPYTER_GATEWAY_CA_CERTS' - - @default('ca_certs') - def _ca_certs_default(self): - return os.environ.get(self.ca_certs_env) - - http_user = Unicode(default_value=None, allow_none=True, config=True, - help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var) - """ - ) - http_user_env = 'JUPYTER_GATEWAY_HTTP_USER' - - @default('http_user') - def _http_user_default(self): - return os.environ.get(self.http_user_env) - - http_pwd = Unicode(default_value=None, allow_none=True, config=True, - help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var) - """ - ) - http_pwd_env = 'JUPYTER_GATEWAY_HTTP_PWD' - - @default('http_pwd') - def _http_pwd_default(self): - return os.environ.get(self.http_pwd_env) - - headers_default_value = '{}' - headers_env = 'JUPYTER_GATEWAY_HEADERS' - headers = Unicode(default_value=headers_default_value, allow_none=True, config=True, - help="""Additional HTTP headers to pass on the request. This value will be converted to a dict. - (JUPYTER_GATEWAY_HEADERS env var) - """ - ) - - @default('headers') - def _headers_default(self): - return os.environ.get(self.headers_env, self.headers_default_value) - - auth_token = Unicode(default_value=None, allow_none=True, config=True, - help="""The authorization token used in the HTTP headers. (JUPYTER_GATEWAY_AUTH_TOKEN env var) - """ - ) - auth_token_env = 'JUPYTER_GATEWAY_AUTH_TOKEN' - - @default('auth_token') - def _auth_token_default(self): - return os.environ.get(self.auth_token_env, '') - - validate_cert_default_value = True - validate_cert_env = 'JUPYTER_GATEWAY_VALIDATE_CERT' - validate_cert = Bool(default_value=validate_cert_default_value, config=True, - help="""For HTTPS requests, determines if server's certificate should be validated or not. - (JUPYTER_GATEWAY_VALIDATE_CERT env var)""" - ) - - @default('validate_cert') - def validate_cert_default(self): - return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false']) - - def __init__(self, **kwargs): - super(GatewayClient, self).__init__(**kwargs) - self._static_args = {} # initialized on first use - - env_whitelist_default_value = '' - env_whitelist_env = 'JUPYTER_GATEWAY_ENV_WHITELIST' - env_whitelist = Unicode(default_value=env_whitelist_default_value, config=True, - help="""A comma-separated list of environment variable names that will be included, along with - their values, in the kernel startup request. The corresponding `env_whitelist` configuration - value must also be set on the Gateway server - since that configuration value indicates which - environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""") - - @default('env_whitelist') - def _env_whitelist_default(self): - return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value) - - gateway_retry_interval_default_value = 1.0 - gateway_retry_interval_env = 'JUPYTER_GATEWAY_RETRY_INTERVAL' - gateway_retry_interval = Float(default_value=gateway_retry_interval_default_value, config=True, - help="""The time allowed for HTTP reconnection with the Gateway server for the first time. - Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries - but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX. - (JUPYTER_GATEWAY_RETRY_INTERVAL env var)""") - - @default('gateway_retry_interval') - def gateway_retry_interval_default(self): - return float(os.environ.get('JUPYTER_GATEWAY_RETRY_INTERVAL', self.gateway_retry_interval_default_value)) - - gateway_retry_interval_max_default_value = 30.0 - gateway_retry_interval_max_env = 'JUPYTER_GATEWAY_RETRY_INTERVAL_MAX' - gateway_retry_interval_max = Float(default_value=gateway_retry_interval_max_default_value, config=True, - help="""The maximum time allowed for HTTP reconnection retry with the Gateway server. - (JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""") - - @default('gateway_retry_interval_max') - def gateway_retry_interval_max_default(self): - return float(os.environ.get('JUPYTER_GATEWAY_RETRY_INTERVAL_MAX', self.gateway_retry_interval_max_default_value)) - - gateway_retry_max_default_value = 5 - gateway_retry_max_env = 'JUPYTER_GATEWAY_RETRY_MAX' - gateway_retry_max = Int(default_value=gateway_retry_max_default_value, config=True, - help="""The maximum retries allowed for HTTP reconnection with the Gateway server. - (JUPYTER_GATEWAY_RETRY_MAX env var)""") - - @default('gateway_retry_max') - def gateway_retry_max_default(self): - return int(os.environ.get('JUPYTER_GATEWAY_RETRY_MAX', self.gateway_retry_max_default_value)) - - @property - def gateway_enabled(self): - return bool(self.url is not None and len(self.url) > 0) - - # Ensure KERNEL_LAUNCH_TIMEOUT has a default value. - KERNEL_LAUNCH_TIMEOUT = int(os.environ.get('KERNEL_LAUNCH_TIMEOUT', 40)) - - def init_static_args(self): - """Initialize arguments used on every request. Since these are static values, we'll - perform this operation once. - - """ - # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are the same, taking the - # greater value of the two. - if self.request_timeout < float(GatewayClient.KERNEL_LAUNCH_TIMEOUT): - self.request_timeout = float(GatewayClient.KERNEL_LAUNCH_TIMEOUT) - elif self.request_timeout > float(GatewayClient.KERNEL_LAUNCH_TIMEOUT): - GatewayClient.KERNEL_LAUNCH_TIMEOUT = int(self.request_timeout) - # Ensure any adjustments are reflected in env. - os.environ['KERNEL_LAUNCH_TIMEOUT'] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT) - - self._static_args['headers'] = json.loads(self.headers) - if 'Authorization' not in self._static_args['headers'].keys(): - self._static_args['headers'].update({ - 'Authorization': 'token {}'.format(self.auth_token) - }) - self._static_args['connect_timeout'] = self.connect_timeout - self._static_args['request_timeout'] = self.request_timeout - self._static_args['validate_cert'] = self.validate_cert - if self.client_cert: - self._static_args['client_cert'] = self.client_cert - self._static_args['client_key'] = self.client_key - if self.ca_certs: - self._static_args['ca_certs'] = self.ca_certs - if self.http_user: - self._static_args['auth_username'] = self.http_user - if self.http_pwd: - self._static_args['auth_password'] = self.http_pwd - - def load_connection_args(self, **kwargs): - """Merges the static args relative to the connection, with the given keyword arguments. If statics - have yet to be initialized, we'll do that here. - """ - if len(self._static_args) == 0: - self.init_static_args() - - for arg, static_value in self._static_args.items(): - if arg == 'headers': - given_value = kwargs.setdefault(arg, {}) - if isinstance(given_value, dict): - given_value.update(static_value) - else: - kwargs[arg] = static_value - return kwargs - - -async def gateway_request(endpoint, **kwargs): - """Make an async request to kernel gateway endpoint, returns a response """ - client = AsyncHTTPClient() - kwargs = GatewayClient.instance().load_connection_args(**kwargs) - try: - response = await client.fetch(endpoint, **kwargs) - # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect - # or the server is not running. - # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes - # of the tree view. - except ConnectionRefusedError as e: - raise web.HTTPError(503, "Connection refused from Gateway server url '{}'. " - "Check to be sure the Gateway instance is running.".format(GatewayClient.instance().url)) from e - except HTTPError as e: - # This can occur if the host is valid (e.g., foo.com) but there's nothing there. - raise web.HTTPError(e.code, "Error attempting to connect to Gateway server url '{}'. " - "Ensure gateway url is valid and the Gateway instance is running.". - format(GatewayClient.instance().url)) from e - except gaierror as e: - raise web.HTTPError(404, "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " - "Ensure gateway url is valid and the Gateway instance is running.". - format(GatewayClient.instance().url)) from e - - return response - - -class GatewayKernelManager(AsyncMappingKernelManager): +class GatewayMappingKernelManager(AsyncMappingKernelManager): """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" # We'll maintain our own set of kernel ids - _kernels = {} + _kernels: Dict[str, 'GatewayKernelManager'] = {} - def __init__(self, **kwargs): - super(GatewayKernelManager, self).__init__(**kwargs) - self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint) + @default('kernel_manager_class') + def _default_kernel_manager_class(self): + return "jupyter_server.gateway.managers.GatewayKernelManager" + + @default('shared_context') + def _default_shared_context(self): + return False # no need to share zmq contexts - def __contains__(self, kernel_id): - return kernel_id in self._kernels + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.kernels_url = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint) def remove_kernel(self, kernel_id): """Complete override since we want to be more tolerant of missing keys """ @@ -355,18 +52,6 @@ def remove_kernel(self, kernel_id): except KeyError: pass - def _get_kernel_endpoint_url(self, kernel_id=None): - """Builds a url for the kernels endpoint - - Parameters - ---------- - kernel_id : kernel UUID (optional) - """ - if kernel_id: - return url_path_join(self.base_endpoint, url_escape(str(kernel_id))) - - return self.base_endpoint - async def start_kernel(self, kernel_id=None, path=None, **kwargs): """Start a kernel for a session and return its kernel_id. @@ -385,68 +70,18 @@ async def start_kernel(self, kernel_id=None, path=None, **kwargs): if kernel_id is None: if path is not None: kwargs['cwd'] = self.cwd_for_path(path) - kernel_name = kwargs.get('kernel_name', 'python3') - kernel_url = self._get_kernel_endpoint_url() - self.log.debug(f"Request new kernel at: {kernel_url}") - - # Let KERNEL_USERNAME take precedent over http_user config option. - if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user: - os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user - - kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') - or k in GatewayClient.instance().env_whitelist.split(",")} - # Convey the full path to where this notebook file is located. - if path is not None and kernel_env.get('KERNEL_WORKING_DIR') is None: - kernel_env['KERNEL_WORKING_DIR'] = kwargs['cwd'] + km = self.kernel_manager_factory(parent=self, log=self.log) + await km.start_kernel(**kwargs) + kernel_id = km.kernel_id + self._kernels[kernel_id] = km - json_body = json_encode({'name': kernel_name, 'env': kernel_env}) - - response = await gateway_request( - kernel_url, method='POST', headers={'Content-Type': 'application/json'}, body=json_body - ) - kernel = json_decode(response.body) - kernel_id = kernel['id'] - self.log.info(f"Kernel started: {kernel_id}") - self.log.debug(f"Kernel args: {kwargs}") - else: - kernel = await self.get_kernel(kernel_id) - kernel_id = kernel['id'] - self.log.info(f"Using existing kernel: {kernel_id}") + # Initialize culling if not already + if not self._initialized_culler: + self.initialize_culler() - self._kernels[kernel_id] = kernel return kernel_id - async def get_kernel(self, kernel_id=None, **kwargs): - """Get kernel for kernel_id. - - Parameters - ---------- - kernel_id : uuid - The uuid of the kernel. - """ - kernel_url = self._get_kernel_endpoint_url(kernel_id) - self.log.debug(f"Request kernel at: {kernel_url}") - try: - response = await gateway_request(kernel_url, method='GET') - except web.HTTPError as error: - if error.status_code == 404: - self.log.warn(f"Kernel not found at: {kernel_url}") - self.remove_kernel(kernel_id) - kernel = None - else: - raise - else: - kernel = json_decode(response.body) - # Only update our models if we already know about this kernel - if kernel_id in self._kernels: - self._kernels[kernel_id] = kernel - self.log.debug(f"Kernel retrieved: {kernel}") - else: - self.log.warning(f"Kernel '{kernel_id}' is not managed by this instance.") - kernel = None - return kernel - async def kernel_model(self, kernel_id): """Return a dictionary of kernel information described in the JSON standard model. @@ -456,18 +91,38 @@ async def kernel_model(self, kernel_id): kernel_id : uuid The uuid of the kernel. """ - model = await self.get_kernel(kernel_id) + model = None + km = self.get_kernel(kernel_id) + if km: + model = km.kernel return model async def list_kernels(self, **kwargs): - """Get a list of kernels.""" - kernel_url = self._get_kernel_endpoint_url() - self.log.debug(f"Request list kernels: {kernel_url}") - response = await gateway_request(kernel_url, method='GET') + """Get a list of running kernels from the Gateway server. + + We'll use this opportunity to refresh the models in each of + the kernels we're managing. + """ + self.log.debug(f"Request list kernels: {self.kernels_url}") + response = await gateway_request(self.kernels_url, method='GET') kernels = json_decode(response.body) - # Only update our models if we already know about the kernels - self._kernels = {x['id']: x for x in kernels if x['id'] in self._kernels} - return list(self._kernels.values()) + # Refresh our models to those we know about, and filter + # the return value with only our kernels. + kernel_models = {} + for model in kernels: + kid = model['id'] + if kid in self._kernels: + await self._kernels[kid].refresh_model(model) + kernel_models[kid] = model + # Remove any of our kernels that may have been culled on the gateway server + our_kernels = self._kernels.copy() + culled_ids = [] + for kid, km in our_kernels.items(): + if kid not in kernel_models: + self.log.warn(f"Kernel {kid} no longer active - probably culled on Gateway server.") + self._kernels.pop(kid, None) + culled_ids.append(kid) # TODO: Figure out what do with these. + return list(kernel_models.values()) async def shutdown_kernel(self, kernel_id, now=False, restart=False): """Shutdown a kernel by its kernel uuid. @@ -481,10 +136,8 @@ async def shutdown_kernel(self, kernel_id, now=False, restart=False): restart : bool The purpose of this shutdown is to restart the kernel (True) """ - kernel_url = self._get_kernel_endpoint_url(kernel_id) - self.log.debug(f"Request shutdown kernel at: {kernel_url}") - response = await gateway_request(kernel_url, method='DELETE') - self.log.debug(f"Shutdown kernel response: {response.code} {response.reason}") + km = self.get_kernel(kernel_id) + await km.shutdown_kernel(now=now, restart=restart) self.remove_kernel(kernel_id) async def restart_kernel(self, kernel_id, now=False, **kwargs): @@ -495,12 +148,8 @@ async def restart_kernel(self, kernel_id, now=False, **kwargs): kernel_id : uuid The id of the kernel to restart. """ - kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart' - self.log.debug(f"Request restart kernel at: {kernel_url}") - response = await gateway_request( - kernel_url, method='POST', headers={'Content-Type': 'application/json'}, body=json_encode({}) - ) - self.log.debug(f"Restart kernel response: {response.code} {response.reason}") + km = self.get_kernel(kernel_id) + await km.restart_kernel(now=now, **kwargs) async def interrupt_kernel(self, kernel_id, **kwargs): """Interrupt a kernel by its kernel uuid. @@ -510,39 +159,26 @@ async def interrupt_kernel(self, kernel_id, **kwargs): kernel_id : uuid The id of the kernel to interrupt. """ - kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt' - self.log.debug(f"Request interrupt kernel at: {kernel_url}") - response = await gateway_request( - kernel_url, method='POST', headers={'Content-Type': 'application/json'}, body=json_encode({}) - ) - self.log.debug(f"Interrupt kernel response: {response.code} {response.reason}") - - def shutdown_all(self, now=False): + km = self.get_kernel(kernel_id) + await km.interrupt_kernel() + + async def shutdown_all(self, now=False): """Shutdown all kernels.""" - # Note: We have to make this sync because the NotebookApp does not wait for async. - shutdown_kernels = [] - kwargs = {'method': 'DELETE'} - kwargs = GatewayClient.instance().load_connection_args(**kwargs) - client = HTTPClient() for kernel_id in self._kernels: - kernel_url = self._get_kernel_endpoint_url(kernel_id) - self.log.debug(f"Request delete kernel at: {kernel_url}") - try: - response = client.fetch(kernel_url, **kwargs) - except HTTPError: - pass - else: - self.log.debug(f"Delete kernel response: {response.code} {response.reason}") - shutdown_kernels.append(kernel_id) # avoid changing dict size during iteration - client.close() - for kernel_id in shutdown_kernels: + km = self.get_kernel(kernel_id) + await km.shutdown_kernel(now=now) self.remove_kernel(kernel_id) + async def cull_kernels(self): + """Override cull_kernels so we can be sure their state is current. """ + await self.list_kernels() + await super().cull_kernels() + class GatewayKernelSpecManager(KernelSpecManager): def __init__(self, **kwargs): - super(GatewayKernelSpecManager, self).__init__(**kwargs) + super().__init__(**kwargs) base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint) @@ -646,9 +282,396 @@ async def get_kernel_spec_resource(self, kernel_name, path): class GatewaySessionManager(SessionManager): - kernel_manager = Instance('jupyter_server.gateway.managers.GatewayKernelManager') + kernel_manager = Instance('jupyter_server.gateway.managers.GatewayMappingKernelManager') async def kernel_culled(self, kernel_id): """Checks if the kernel is still considered alive and returns true if its not found. """ - kernel = await self.kernel_manager.get_kernel(kernel_id) + kernel = None + try: + km = self.kernel_manager.get_kernel(kernel_id) + kernel = await km.refresh_model() + except Exception: # Let exceptions here reflect culled kernel + pass return kernel is None + + +"""KernelManager class to manage a kernel running on a Gateway Server via the REST API""" + + +class GatewayKernelManager(AsyncKernelManager): + """Manages a single kernel remotely via a Gateway Server. """ + + kernel_id = None + kernel = None + + @default('cache_ports') + def _default_cache_ports(self): + return False # no need to cache ports here + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.kernels_url = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint) + self.kernel_url = self.kernel = self.kernel_id = None + # simulate busy/activity markers: + self.execution_state = self.last_activity = None + + @property + def has_kernel(self): + """Has a kernel been started that we are managing.""" + return self.kernel is not None + + client_class = DottedObjectName('jupyter_server.gateway.managers.GatewayKernelClient') + client_factory = Type(klass='jupyter_server.gateway.managers.GatewayKernelClient') + + # -------------------------------------------------------------------------- + # create a Client connected to our Kernel + # -------------------------------------------------------------------------- + + def client(self, **kwargs): + """Create a client configured to connect to our kernel""" + kw = {} + kw.update(self.get_connection_info(session=True)) + kw.update(dict( + connection_file=self.connection_file, + parent=self, + )) + kw['kernel_id'] = self.kernel_id + + # add kwargs last, for manual overrides + kw.update(kwargs) + return self.client_factory(**kw) + + async def refresh_model(self, model=None): + """Refresh the kernel model. + + Parameters + ---------- + model : dict + The model from which to refresh the kernel. If None, the kernel + model is fetched from the Gateway server. + """ + if model is None: + self.log.debug("Request kernel at: %s" % self.kernel_url) + try: + response = await gateway_request(self.kernel_url, method='GET') + except web.HTTPError as error: + if error.status_code == 404: + self.log.warning("Kernel not found at: %s" % self.kernel_url) + model = None + else: + raise + else: + model = json_decode(response.body) + self.log.debug("Kernel retrieved: %s" % model) + + if model: # Update activity markers + self.last_activity = datetime.datetime.strptime( + model['last_activity'], '%Y-%m-%dT%H:%M:%S.%fZ').replace(tzinfo=UTC) + self.execution_state = model['execution_state'] + if isinstance(self.parent, AsyncMappingKernelManager): + # Update connections only if there's a mapping kernel manager parent for + # this kernel manager. The current kernel manager instance may not have + # an parent instance if, say, a server extension is using another application + # (e.g., papermill) that uses a KernelManager instance directly. + self.parent._kernel_connections[self.kernel_id] = int(model['connections']) + + self.kernel = model + return model + + # -------------------------------------------------------------------------- + # Kernel management + # -------------------------------------------------------------------------- + + async def start_kernel(self, **kwargs): + """Starts a kernel via HTTP in an asynchronous manner. + + Parameters + ---------- + `**kwargs` : optional + keyword arguments that are passed down to build the kernel_cmd + and launching the kernel (e.g. Popen kwargs). + """ + kernel_id = kwargs.get('kernel_id') + + if kernel_id is None: + kernel_name = kwargs.get('kernel_name', 'python3') + self.log.debug("Request new kernel at: %s" % self.kernels_url) + + # Let KERNEL_USERNAME take precedent over http_user config option. + if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user: + os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user + + kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') or + k in GatewayClient.instance().env_whitelist.split(",")} + + # Add any env entries in this request + kernel_env.update(kwargs.get('env', {})) + + # Convey the full path to where this notebook file is located. + if kwargs.get('cwd') is not None and kernel_env.get('KERNEL_WORKING_DIR') is None: + kernel_env['KERNEL_WORKING_DIR'] = kwargs['cwd'] + + json_body = json_encode({'name': kernel_name, 'env': kernel_env}) + + response = await gateway_request(self.kernels_url, method='POST', body=json_body) + self.kernel = json_decode(response.body) + self.kernel_id = self.kernel['id'] + self.log.info("GatewayKernelManager started kernel: {}, args: {}".format(self.kernel_id, kwargs)) + else: + self.kernel_id = kernel_id + self.kernel = await self.refresh_model() + self.log.info("GatewayKernelManager using existing kernel: {}".format(self.kernel_id)) + + self.kernel_url = url_path_join(self.kernels_url, url_escape(str(self.kernel_id))) + + async def shutdown_kernel(self, now=False, restart=False): + """Attempts to stop the kernel process cleanly via HTTP. """ + + if self.has_kernel: + self.log.debug("Request shutdown kernel at: %s", self.kernel_url) + response = await gateway_request(self.kernel_url, method='DELETE') + self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason) + + async def restart_kernel(self, **kw): + """Restarts a kernel via HTTP. """ + if self.has_kernel: + kernel_url = self.kernel_url + '/restart' + self.log.debug("Request restart kernel at: %s", kernel_url) + response = await gateway_request(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Restart kernel response: %d %s", response.code, response.reason) + + async def interrupt_kernel(self): + """Interrupts the kernel via an HTTP request. """ + if self.has_kernel: + kernel_url = self.kernel_url + '/interrupt' + self.log.debug("Request interrupt kernel at: %s", kernel_url) + response = await gateway_request(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason) + + async def is_alive(self): + """Is the kernel process still running?""" + if self.has_kernel: + # Go ahead and issue a request to get the kernel + self.kernel = await self.refresh_model() + return True + else: # we don't have a kernel + return False + + def cleanup_resources(self, restart=False): + """Clean up resources when the kernel is shut down""" + pass + + +KernelManagerABC.register(GatewayKernelManager) + + +class ChannelQueue(Queue): + + channel_name: str = None + + def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger): + super().__init__() + self.channel_name = channel_name + self.channel_socket = channel_socket + self.log = log + + async def get_msg(self, *args, **kwargs) -> dict: + timeout = kwargs.get('timeout', 1) + msg = self.get(timeout=timeout) + self.log.debug("Received message on channel: {}, msg_id: {}, msg_type: {}". + format(self.channel_name, msg['msg_id'], msg['msg_type'] if msg else 'null')) + self.task_done() + return msg + + def send(self, msg: dict) -> None: + message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace(" None: + pass + + def stop(self) -> None: + if not self.empty(): + # If unprocessed messages are detected, drain the queue collecting non-status + # messages. If any remain that are not 'shutdown_reply' and this is not iopub + # go ahead and issue a warning. + msgs = [] + while self.qsize(): + msg = self.get_nowait() + if msg['msg_type'] != 'status': + msgs.append(msg['msg_type']) + if self.channel_name == 'iopub' and 'shutdown_reply' in msgs: + return + if len(msgs): + self.log.warning("Stopping channel '{}' with {} unprocessed non-status messages: {}.". + format(self.channel_name, len(msgs), msgs)) + + def is_alive(self) -> bool: + return self.channel_socket is not None + + +class HBChannelQueue(ChannelQueue): + + def is_beating(self) -> bool: + # Just use the is_alive status for now + return self.is_alive() + + +class GatewayKernelClient(AsyncKernelClient): + """Communicates with a single kernel indirectly via a websocket to a gateway server. + + There are five channels associated with each kernel: + + * shell: for request/reply calls to the kernel. + * iopub: for the kernel to publish results to frontends. + * hb: for monitoring the kernel's heartbeat. + * stdin: for frontends to reply to raw_input calls in the kernel. + * control: for kernel management calls to the kernel. + + The messages that can be sent on these channels are exposed as methods of the + client (KernelClient.execute, complete, history, etc.). These methods only + send the message, they don't wait for a reply. To get results, use e.g. + :meth:`get_shell_msg` to fetch messages from the shell channel. + """ + + # flag for whether execute requests should be allowed to call raw_input: + allow_stdin = False + _channels_stopped = False + _channel_queues = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.kernel_id = kwargs['kernel_id'] + self.channel_socket = None + self.response_router = None + + # -------------------------------------------------------------------------- + # Channel management methods + # -------------------------------------------------------------------------- + + async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): + """Starts the channels for this kernel. + + For this class, we establish a websocket connection to the destination + and setup the channel-based queues on which applicable messages will + be posted. + """ + + ws_url = url_path_join( + GatewayClient.instance().ws_url, + GatewayClient.instance().kernels_endpoint, url_escape(self.kernel_id), 'channels') + # Gather cert info in case where ssl is desired... + ssl_options = dict() + ssl_options['ca_certs'] = GatewayClient.instance().ca_certs + ssl_options['certfile'] = GatewayClient.instance().client_cert + ssl_options['keyfile'] = GatewayClient.instance().client_key + + self.channel_socket = websocket.create_connection(ws_url, + timeout=GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT, + enable_multithread=True, + sslopt=ssl_options) + self.response_router = Thread(target=self._route_responses) + self.response_router.start() + + await ensure_async(super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)) + + def stop_channels(self): + """Stops all the running channels for this kernel. + + For this class, we close the websocket connection and destroy the + channel-based queues. + """ + super().stop_channels() + self._channels_stopped = True + self.log.debug("Closing websocket connection") + + self.channel_socket.close() + self.response_router.join() + + if self._channel_queues: + self._channel_queues.clear() + self._channel_queues = None + + # Channels are implemented via a ChannelQueue that is used to send and receive messages + + @property + def shell_channel(self): + """Get the shell channel object for this kernel.""" + if self._shell_channel is None: + self.log.debug("creating shell channel queue") + self._shell_channel = ChannelQueue('shell', self.channel_socket, self.log) + self._channel_queues['shell'] = self._shell_channel + return self._shell_channel + + @property + def iopub_channel(self): + """Get the iopub channel object for this kernel.""" + if self._iopub_channel is None: + self.log.debug("creating iopub channel queue") + self._iopub_channel = ChannelQueue('iopub', self.channel_socket, self.log) + self._channel_queues['iopub'] = self._iopub_channel + return self._iopub_channel + + @property + def stdin_channel(self): + """Get the stdin channel object for this kernel.""" + if self._stdin_channel is None: + self.log.debug("creating stdin channel queue") + self._stdin_channel = ChannelQueue('stdin', self.channel_socket, self.log) + self._channel_queues['stdin'] = self._stdin_channel + return self._stdin_channel + + @property + def hb_channel(self): + """Get the hb channel object for this kernel.""" + if self._hb_channel is None: + self.log.debug("creating hb channel queue") + self._hb_channel = HBChannelQueue('hb', self.channel_socket, self.log) + self._channel_queues['hb'] = self._hb_channel + return self._hb_channel + + @property + def control_channel(self): + """Get the control channel object for this kernel.""" + if self._control_channel is None: + self.log.debug("creating control channel queue") + self._control_channel = ChannelQueue('control', self.channel_socket, self.log) + self._channel_queues['control'] = self._control_channel + return self._control_channel + + def _route_responses(self): + """ + Reads responses from the websocket and routes each to the appropriate channel queue based + on the message's channel. It does this for the duration of the class's lifetime until the + channels are stopped, at which time the socket is closed (unblocking the router) and + the thread terminates. If shutdown happens to occur while processing a response (unlikely), + termination takes place via the loop control boolean. + """ + try: + while not self._channels_stopped: + raw_message = self.channel_socket.recv() + if not raw_message: + break + response_message = json_decode(utf8(raw_message)) + channel = response_message['channel'] + self._channel_queues[channel].put_nowait(response_message) + + except websocket.WebSocketConnectionClosedException: + pass # websocket closure most likely due to shutdown + + except BaseException as be: + if not self._channels_stopped: + self.log.warning('Unexpected exception encountered ({})'.format(be)) + + self.log.debug('Response router thread exiting...') + + +KernelClientABC.register(GatewayKernelClient) diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 24c1227e6e..d8b6df5899 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -75,7 +75,7 @@ from jupyter_server.services.contents.filemanager import AsyncFileContentsManager, FileContentsManager from jupyter_server.services.contents.largefilemanager import LargeFileManager from jupyter_server.services.sessions.sessionmanager import SessionManager -from jupyter_server.gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient +from jupyter_server.gateway.managers import GatewayMappingKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient from jupyter_server.auth.login import LoginHandler from jupyter_server.auth.logout import LogoutHandler @@ -592,7 +592,7 @@ class ServerApp(JupyterApp): classes = [ KernelManager, Session, MappingKernelManager, KernelSpecManager, AsyncMappingKernelManager, ContentsManager, FileContentsManager, AsyncContentsManager, AsyncFileContentsManager, NotebookNotary, - GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient + GatewayMappingKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient ] if terminado_available: # Only necessary when terminado is available classes.append(TerminalManager) @@ -1405,7 +1405,7 @@ def init_configurables(self): self.gateway_config = GatewayClient.instance(parent=self) if self.gateway_config.gateway_enabled: - self.kernel_manager_class = 'jupyter_server.gateway.managers.GatewayKernelManager' + self.kernel_manager_class = 'jupyter_server.gateway.managers.GatewayMappingKernelManager' self.session_manager_class = 'jupyter_server.gateway.managers.GatewaySessionManager' self.kernel_spec_manager_class = 'jupyter_server.gateway.managers.GatewayKernelSpecManager' diff --git a/jupyter_server/tests/test_gateway.py b/jupyter_server/tests/test_gateway.py index e6311e6f7a..6c18bba28f 100644 --- a/jupyter_server/tests/test_gateway.py +++ b/jupyter_server/tests/test_gateway.py @@ -186,7 +186,7 @@ async def test_gateway_cli_options(jp_configurable_serverapp): async def test_gateway_class_mappings(init_gateway, jp_serverapp): # Ensure appropriate class mappings are in place. - assert jp_serverapp.kernel_manager_class.__name__ == 'GatewayKernelManager' + assert jp_serverapp.kernel_manager_class.__name__ == 'GatewayMappingKernelManager' assert jp_serverapp.session_manager_class.__name__ == 'GatewaySessionManager' assert jp_serverapp.kernel_spec_manager_class.__name__ == 'GatewayKernelSpecManager' diff --git a/setup.cfg b/setup.cfg index 4cd951aca2..a4db978986 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ install_requires = pywin32>=1.0 ; sys_platform == 'win32' anyio>=2.0.2,<3 ; python_version < '3.7' anyio>=3.0.1,<4 ; python_version >= '3.7' + websocket-client [options.extras_require] test = coverage; pytest; pytest-cov; pytest-mock; requests; pytest-tornasync; pytest-console-scripts; ipykernel