diff --git a/docs/source/developers/extensions.rst b/docs/source/developers/extensions.rst index 878f0287e0..b056746443 100644 --- a/docs/source/developers/extensions.rst +++ b/docs/source/developers/extensions.rst @@ -156,14 +156,19 @@ The basic structure of an ExtensionApp is shown below: ... # Change the jinja templating environment + async def stop_extension(self): + ... + # Perform any required shut down steps + The ``ExtensionApp`` uses the following methods and properties to connect your extension to the Jupyter server. You do not need to define a ``_load_jupyter_server_extension`` function for these apps. Instead, overwrite the pieces below to add your custom settings, handlers and templates: Methods -* ``initialize_setting()``: adds custom settings to the Tornado Web Application. +* ``initialize_settings()``: adds custom settings to the Tornado Web Application. * ``initialize_handlers()``: appends handlers to the Tornado Web Application. * ``initialize_templates()``: initialize the templating engine (e.g. jinja2) for your frontend. +* ``stop_extension()``: called on server shut down. Properties diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index e5cb6bcafd..76ef17dd07 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -416,6 +416,9 @@ def start(self): # Start the server. self.serverapp.start() + async def stop_extension(self): + """Cleanup any resources managed by this extension.""" + def stop(self): """Stop the underlying Jupyter server. """ diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py index 3e0d8e1bd6..83f1af7943 100644 --- a/jupyter_server/extension/manager.py +++ b/jupyter_server/extension/manager.py @@ -2,6 +2,8 @@ import sys import traceback +from tornado.gen import multi + from traitlets.config import LoggingConfigurable from traitlets import ( @@ -230,15 +232,17 @@ def link_point(self, point_name, serverapp): def load_point(self, point_name, serverapp): point = self.extension_points[point_name] - point.load(serverapp) + return point.load(serverapp) def link_all_points(self, serverapp): for point_name in self.extension_points: self.link_point(point_name, serverapp) def load_all_points(self, serverapp): - for point_name in self.extension_points: + return [ self.load_point(point_name, serverapp) + for point_name in self.extension_points + ] class ExtensionManager(LoggingConfigurable): @@ -290,12 +294,26 @@ def sorted_extensions(self): """ ) + @property + def extension_apps(self): + """Return mapping of extension names and sets of ExtensionApp objects. + """ + return { + name: { + point.app + for point in extension.extension_points.values() + if point.app + } + for name, extension in self.extensions.items() + } + @property def extension_points(self): - extensions = self.extensions + """Return mapping of extension point names and ExtensionPoint objects. + """ return { name: point - for value in extensions.values() + for value in self.extensions.values() for name, point in value.extension_points.items() } @@ -341,13 +359,22 @@ def link_extension(self, name, serverapp): def load_extension(self, name, serverapp): extension = self.extensions.get(name) + if extension.enabled: try: - extension.load_all_points(serverapp) - self.log.info("{name} | extension was successfully loaded.".format(name=name)) + points = extension.load_all_points(serverapp) except Exception as e: self.log.debug("".join(traceback.format_exception(*sys.exc_info()))) self.log.warning("{name} | extension failed loading with message: {error}".format(name=name,error=str(e))) + else: + self.log.info("{name} | extension was successfully loaded.".format(name=name)) + + async def stop_extension(self, name, apps): + """Call the shutdown hooks in the specified apps.""" + for app in apps: + self.log.debug('{} | extension app "{}" stopping'.format(name, app.name)) + await app.stop_extension() + self.log.debug('{} | extension app "{}" stopped'.format(name, app.name)) def link_all_extensions(self, serverapp): """Link all enabled extensions @@ -366,3 +393,10 @@ def load_all_extensions(self, serverapp): # order. for name in self.sorted_extensions.keys(): self.load_extension(name, serverapp) + + async def stop_all_extensions(self, serverapp): + """Call the shutdown hooks in all extensions.""" + await multi([ + self.stop_extension(name, apps) + for name, apps in sorted(dict(self.extension_apps).items()) + ]) diff --git a/jupyter_server/pytest_plugin.py b/jupyter_server/pytest_plugin.py index 892828ba5f..00a7c8fc86 100644 --- a/jupyter_server/pytest_plugin.py +++ b/jupyter_server/pytest_plugin.py @@ -18,7 +18,7 @@ from jupyter_server.extension import serverextension from jupyter_server.serverapp import ServerApp -from jupyter_server.utils import url_path_join +from jupyter_server.utils import url_path_join, run_sync from jupyter_server.services.contents.filemanager import FileContentsManager from jupyter_server.services.contents.largefilemanager import LargeFileManager @@ -284,7 +284,7 @@ def jp_serverapp( """Starts a Jupyter Server instance based on the established configuration values.""" app = jp_configurable_serverapp(config=jp_server_config, argv=jp_argv) yield app - app._cleanup() + run_sync(app._cleanup()) @pytest.fixture diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index dd9939ca5d..03fd748902 100755 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -43,7 +43,7 @@ from jupyter_core.paths import secure_write from jupyter_server.transutils import trans, _i18n -from jupyter_server.utils import run_sync +from jupyter_server.utils import run_sync_in_loop # the minimum viable tornado version: needs to be kept in sync with setup.py MIN_TORNADO = (6, 1, 0) @@ -1750,7 +1750,7 @@ def _confirm_exit(self): self.log.critical(_i18n("Shutting down...")) # schedule stop on the main thread, # since this might be called from a signal handler - self.io_loop.add_callback_from_signal(self.io_loop.stop) + self.stop(from_signal=True) return print(self.running_server_info()) yes = _i18n('y') @@ -1764,7 +1764,7 @@ def _confirm_exit(self): self.log.critical(_i18n("Shutdown confirmed")) # schedule stop on the main thread, # since this might be called from a signal handler - self.io_loop.add_callback_from_signal(self.io_loop.stop) + self.stop(from_signal=True) return else: print(_i18n("No answer for 5s:"), end=' ') @@ -1777,7 +1777,7 @@ def _confirm_exit(self): def _signal_stop(self, sig, frame): self.log.critical(_i18n("received signal %s, stopping"), sig) - self.io_loop.add_callback_from_signal(self.io_loop.stop) + self.stop(from_signal=True) def _signal_info(self, sig, frame): print(self.running_server_info()) @@ -2059,7 +2059,7 @@ def initialize(self, argv=None, find_extensions=True, new_httpserver=True, start if new_httpserver: self.init_httpserver() - def cleanup_kernels(self): + async def cleanup_kernels(self): """Shutdown all kernels. The kernels will shutdown themselves when this process no longer exists, @@ -2068,9 +2068,9 @@ def cleanup_kernels(self): n_kernels = len(self.kernel_manager.list_kernel_ids()) kernel_msg = trans.ngettext('Shutting down %d kernel', 'Shutting down %d kernels', n_kernels) self.log.info(kernel_msg % n_kernels) - run_sync(self.kernel_manager.shutdown_all()) + await run_sync_in_loop(self.kernel_manager.shutdown_all()) - def cleanup_terminals(self): + async def cleanup_terminals(self): """Shutdown all terminals. The terminals will shutdown themselves when this process no longer exists, @@ -2083,7 +2083,20 @@ def cleanup_terminals(self): n_terminals = len(terminal_manager.list()) terminal_msg = trans.ngettext('Shutting down %d terminal', 'Shutting down %d terminals', n_terminals) self.log.info(terminal_msg % n_terminals) - run_sync(terminal_manager.terminate_all()) + await run_sync_in_loop(terminal_manager.terminate_all()) + + async def cleanup_extensions(self): + """Call shutdown hooks in all extensions.""" + n_extensions = len(self.extension_manager.extension_apps) + extension_msg = trans.ngettext( + 'Shutting down %d extension', + 'Shutting down %d extensions', + n_extensions + ) + self.log.info(extension_msg % n_extensions) + await run_sync_in_loop( + self.extension_manager.stop_all_extensions(self) + ) def running_server_info(self, kernel_count=True): "Return the current working directory and the server url information" @@ -2321,14 +2334,15 @@ def start_app(self): ' %s' % self.display_url, ])) - def _cleanup(self): - """General cleanup of files and kernels created + async def _cleanup(self): + """General cleanup of files, extensions and kernels created by this instance ServerApp. """ self.remove_server_info_file() self.remove_browser_open_files() - self.cleanup_kernels() - self.cleanup_terminals() + await self.cleanup_extensions() + await self.cleanup_kernels() + await self.cleanup_terminals() def start_ioloop(self): """Start the IO Loop.""" @@ -2341,8 +2355,6 @@ def start_ioloop(self): self.io_loop.start() except KeyboardInterrupt: self.log.info(_i18n("Interrupted...")) - finally: - self._cleanup() def init_ioloop(self): """init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks""" @@ -2356,13 +2368,23 @@ def start(self): self.start_app() self.start_ioloop() - def stop(self): - def _stop(): + async def _stop(self): + """Cleanup resources and stop the IO Loop.""" + await self._cleanup() + self.io_loop.stop() + + def stop(self, from_signal=False): + """Cleanup resources and stop the server.""" + if hasattr(self, '_http_server'): # Stop a server if its set. - if hasattr(self, '_http_server'): - self.http_server.stop() - self.io_loop.stop() - self.io_loop.add_callback(_stop) + self.http_server.stop() + if getattr(self, 'io_loop', None): + # use IOLoop.add_callback because signal.signal must be called + # from main thread + if from_signal: + self.io_loop.add_callback_from_signal(self._stop) + else: + self.io_loop.add_callback(self._stop) def list_running_servers(runtime_dir=None): diff --git a/jupyter_server/tests/extension/test_app.py b/jupyter_server/tests/extension/test_app.py index 3cc0e82fe6..fe83d24ba2 100644 --- a/jupyter_server/tests/extension/test_app.py +++ b/jupyter_server/tests/extension/test_app.py @@ -1,6 +1,7 @@ import pytest from traitlets.config import Config from jupyter_server.serverapp import ServerApp +from jupyter_server.utils import run_sync from .mockextensions.app import MockExtensionApp @@ -101,3 +102,42 @@ def test_load_parallel_extensions(monkeypatch, jp_environ): exts = serverapp.jpserver_extensions assert exts['jupyter_server.tests.extension.mockextensions.mock1'] assert exts['jupyter_server.tests.extension.mockextensions'] + + +def test_stop_extension(jp_serverapp, caplog): + """Test the stop_extension method. + + This should be fired by ServerApp.cleanup_extensions. + """ + calls = 0 + + # load extensions (make sure we only have the one extension loaded + jp_serverapp.extension_manager.load_all_extensions(jp_serverapp) + extension_name = 'jupyter_server.tests.extension.mockextensions' + assert list(jp_serverapp.extension_manager.extension_apps) == [ + extension_name + ] + + # add a stop_extension method for the extension app + async def _stop(*args): + nonlocal calls + calls += 1 + for apps in jp_serverapp.extension_manager.extension_apps.values(): + for app in apps: + if app: + app.stop_extension = _stop + + # call cleanup_extensions, check the logging is correct + caplog.clear() + run_sync(jp_serverapp.cleanup_extensions()) + assert [ + msg + for *_, msg in caplog.record_tuples + ] == [ + 'Shutting down 1 extension', + '{} | extension app "mockextension" stopping'.format(extension_name), + '{} | extension app "mockextension" stopped'.format(extension_name), + ] + + # check the shutdown method was called once + assert calls == 1 diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index e0a532bbdd..83fdd43611 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -230,6 +230,29 @@ def wrapped(): return wrapped() +async def run_sync_in_loop(maybe_async): + """Runs a function synchronously whether it is an async function or not. + + If async, runs maybe_async and blocks until it has executed. + + If not async, just returns maybe_async as it is the result of something + that has already executed. + + Parameters + ---------- + maybe_async : async or non-async object + The object to be executed, if it is async. + + Returns + ------- + result + Whatever the async object returns, or the object itself. + """ + if not inspect.isawaitable(maybe_async): + return maybe_async + return await maybe_async + + def urlencode_unix_socket_path(socket_path): """Encodes a UNIX socket path string from a socket path for the `http+unix` URI form.""" return socket_path.replace('/', '%2F')