Resolve conflict due to anyio updates
kevin-bates committed May 3, 2021
2 parents 827fc6a + a0135ca commit f4ba1ce
Showing 9 changed files with 135 additions and 89 deletions.
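The recurring change in this merge is the move from anyio v2's `run_sync_in_worker_thread` to the anyio v3 `anyio.to_thread.run_sync` API, wrapped in a try/except import so older anyio releases keep working. A minimal sketch of that pattern, assuming only the standard anyio API (the `stat_file` helper and the path are illustrative, not part of the commit):

```python
import os

try:
    # anyio >= 3 exposes the worker-thread helper as anyio.to_thread.run_sync
    from anyio.to_thread import run_sync
except ImportError:
    # anyio v2 fallback: same behavior under the old name
    from anyio import run_sync_in_worker_thread as run_sync


async def stat_file(path):
    # Offload the blocking os.stat call to a worker thread so the event
    # loop serving other requests is not blocked.
    return await run_sync(os.stat, path)


# Usage (illustrative): anyio.run(stat_file, "notebook.ipynb")
```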
89 changes: 51 additions & 38 deletions jupyter_server/extension/manager.py
@@ -1,16 +1,20 @@
import importlib

from traitlets.config import LoggingConfigurable, Config
from traitlets.config import LoggingConfigurable

from traitlets import (
HasTraits,
Dict,
Unicode,
Bool,
Any,
validate
Instance,
default,
observe,
validate,
)

from .config import ExtensionConfigManager
from .utils import (
ExtensionMetadataError,
ExtensionModuleNotFound,
@@ -129,6 +133,8 @@ def validate(self):
self._get_loader()
except Exception:
return False
else:
return True

def link(self, serverapp):
"""Link the extension to a Jupyter ServerApp object.
@@ -238,35 +244,44 @@ class ExtensionManager(LoggingConfigurable):
linking, loading, and managing Jupyter Server extensions.
Usage:
m = ExtensionManager(jpserver_extensions=extensions)
m = ExtensionManager(config_manager=...)
"""
def __init__(self, config_manager=None, *args, **kwargs):
super().__init__(*args, **kwargs)
# The `enabled_extensions` attribute provides a dictionary
# with extension (package) names mapped to their ExtensionPackage interface
# (see above). This manager simplifies the interaction between the
# ServerApp and the extensions being appended.
self._extensions = {}
# The `_linked_extensions` attribute tracks when each extension
# has been successfully linked to a ServerApp. This helps prevent
# extensions from being re-linked recursively unintentionally if another
# extension attempts to link extensions again.
self._linked_extensions = {}
self._config_manager = config_manager
if self._config_manager:
self.from_config_manager(self._config_manager)

@property
def config_manager(self):
return self._config_manager
config_manager = Instance(ExtensionConfigManager, allow_none=True)

@default("config_manager")
def _load_default_config_manager(self):
config_manager = ExtensionConfigManager()
self._load_config_manager(config_manager)
return config_manager

@observe("config_manager")
def _config_manager_changed(self, change):
if change.new:
self._load_config_manager(change.new)

# The `extensions` attribute provides a dictionary
# with extension (package) names mapped to their ExtensionPackage interface
# (see above). This manager simplifies the interaction between the
# ServerApp and the extensions being appended.
extensions = Dict(
help="""
Dictionary with extension package names as keys
and ExtensionPackage objects as values.
"""
)

@property
def extensions(self):
"""Dictionary with extension package names as keys
and an ExtensionPackage objects as values."""
# Sort enabled extensions before
return self._extensions
# The `_linked_extensions` attribute tracks when each extension
# has been successfully linked to a ServerApp. This helps prevent
# extensions from being re-linked recursively unintentionally if another
# extension attempts to link extensions again.
linked_extensions = Dict(
help="""
Dictionary with extension names as keys
values are True if the extension is linked, False if not.
"""
)

@property
def extension_points(self):
@@ -277,16 +292,14 @@ def extension_points(self):
for name, point in value.extension_points.items()
}

@property
def linked_extensions(self):
"""Dictionary with extension names as keys; values are
True if the extension is linked, False if not."""
return self._linked_extensions

def from_config_manager(self, config_manager):
"""Add extensions found by an ExtensionConfigManager"""
self._config_manager = config_manager
jpserver_extensions = self._config_manager.get_jpserver_extensions()
# load triggered via config_manager trait observer
self.config_manager = config_manager

def _load_config_manager(self, config_manager):
"""Actually load our config manager"""
jpserver_extensions = config_manager.get_jpserver_extensions()
self.from_jpserver_extensions(jpserver_extensions)

def from_jpserver_extensions(self, jpserver_extensions):
Expand All @@ -300,21 +313,21 @@ def add_extension(self, extension_name, enabled=False):
"""
try:
extpkg = ExtensionPackage(name=extension_name, enabled=enabled)
self._extensions[extension_name] = extpkg
self.extensions[extension_name] = extpkg
return True
# Raise a warning if the extension cannot be loaded.
except Exception as e:
self.log.warning(e)
return False

def link_extension(self, name, serverapp):
linked = self._linked_extensions.get(name, False)
linked = self.linked_extensions.get(name, False)
extension = self.extensions[name]
if not linked and extension.enabled:
try:
# Link extension and store links
extension.link_all_points(serverapp)
self._linked_extensions[name] = True
self.linked_extensions[name] = True
self.log.info("{name} | extension was successfully linked.".format(name=name))
except Exception as e:
self.log.warning(e)
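The `ExtensionManager` changes above replace the hand-written `__init__` and properties with traitlets: `config_manager` becomes an `Instance` trait, a `@default` hook builds one lazily, and an `@observe` hook reloads extensions whenever a new config manager is assigned. A minimal, self-contained sketch of that `@default`/`@observe` pattern, using invented `Settings`/`Manager` names rather than the actual jupyter_server classes:

```python
from traitlets import Dict, HasTraits, Instance, default, observe


class Settings(HasTraits):
    """Stand-in for a config manager that knows which extensions are enabled."""
    extensions = Dict()


class Manager(HasTraits):
    settings = Instance(Settings, allow_none=True)
    loaded = Dict()

    @default("settings")
    def _default_settings(self):
        # Built lazily the first time `self.settings` is read without
        # having been assigned.
        settings = Settings()
        self._load(settings)
        return settings

    @observe("settings")
    def _settings_changed(self, change):
        # Fires whenever `settings` is assigned, including via the constructor.
        if change.new:
            self._load(change.new)

    def _load(self, settings):
        self.loaded.update(settings.extensions)


m = Manager(settings=Settings(extensions={"myextension": True}))
print(m.loaded)  # {'myextension': True}
```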
3 changes: 2 additions & 1 deletion jupyter_server/extension/serverextension.py
@@ -306,7 +306,8 @@ def list_server_extensions(self):
GREEN_ENABLED if enabled else RED_DISABLED))
try:
self.log.info(" - Validating {}...".format(name))
extension.validate()
if not extension.validate():
raise ValueError("validation failed")
version = extension.version
self.log.info(
" {} {} {}".format(name, version, GREEN_OK)
13 changes: 9 additions & 4 deletions jupyter_server/services/contents/filecheckpoints.py
@@ -14,7 +14,12 @@
)
from .fileio import AsyncFileManagerMixin, FileManagerMixin

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from jupyter_core.utils import ensure_dir_exists
from traitlets import Unicode

@@ -156,7 +161,7 @@ async def restore_checkpoint(self, contents_mgr, checkpoint_id, path):

async def checkpoint_model(self, checkpoint_id, os_path):
"""construct the info dict for a given checkpoint"""
stats = await run_sync_in_worker_thread(os.stat, os_path)
stats = await run_sync(os.stat, os_path)
last_modified = tz.utcfromtimestamp(stats.st_mtime)
info = dict(
id=checkpoint_id,
@@ -176,7 +181,7 @@ async def rename_checkpoint(self, checkpoint_id, old_path, new_path):
new_cp_path,
)
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.move, old_cp_path, new_cp_path)
await run_sync(shutil.move, old_cp_path, new_cp_path)

async def delete_checkpoint(self, checkpoint_id, path):
"""delete a file's checkpoint"""
@@ -187,7 +192,7 @@ async def delete_checkpoint(self, checkpoint_id, path):

self.log.debug("unlinking %s", cp_path)
with self.perm_to_403():
await run_sync_in_worker_thread(os.unlink, cp_path)
await run_sync(os.unlink, cp_path)

async def list_checkpoints(self, path):
"""list the checkpoints for a given file
21 changes: 13 additions & 8 deletions jupyter_server/services/contents/fileio.py
@@ -12,7 +12,12 @@
import os
import shutil

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from tornado.web import HTTPError

from jupyter_server.utils import (
@@ -36,7 +41,7 @@ def replace_file(src, dst):
async def async_replace_file(src, dst):
""" replace dst with src asynchronously
"""
await run_sync_in_worker_thread(os.replace, src, dst)
await run_sync(os.replace, src, dst)

def copy2_safe(src, dst, log=None):
"""copy src to dst
@@ -55,9 +60,9 @@ async def async_copy2_safe(src, dst, log=None):
like shutil.copy2, but log errors in copystat instead of raising
"""
await run_sync_in_worker_thread(shutil.copyfile, src, dst)
await run_sync(shutil.copyfile, src, dst)
try:
await run_sync_in_worker_thread(shutil.copystat, src, dst)
await run_sync(shutil.copystat, src, dst)
except OSError:
if log:
log.debug("copystat on %s failed", dst, exc_info=True)
@@ -355,7 +360,7 @@ async def _read_notebook(self, os_path, as_version=4):
"""Read a notebook from an os path."""
with self.open(os_path, 'r', encoding='utf-8') as f:
try:
return await run_sync_in_worker_thread(partial(nbformat.read, as_version=as_version), f)
return await run_sync(partial(nbformat.read, as_version=as_version), f)
except Exception as e:
e_orig = e

@@ -379,7 +384,7 @@ async def _read_notebook(self, os_path, as_version=4):
async def _save_notebook(self, os_path, nb):
"""Save a notebook to an os_path."""
with self.atomic_writing(os_path, encoding='utf-8') as f:
await run_sync_in_worker_thread(partial(nbformat.write, version=nbformat.NO_CONVERT), nb, f)
await run_sync(partial(nbformat.write, version=nbformat.NO_CONVERT), nb, f)

async def _read_file(self, os_path, format):
"""Read a non-notebook file.
@@ -394,7 +399,7 @@ async def _read_file(self, os_path, format):
raise HTTPError(400, "Cannot read non-file %s" % os_path)

with self.open(os_path, 'rb') as f:
bcontent = await run_sync_in_worker_thread(f.read)
bcontent = await run_sync(f.read)

if format is None or format == 'text':
# Try to interpret as unicode if format is unknown or if unicode
@@ -429,4 +434,4 @@ async def _save_file(self, os_path, content, format):
) from e

with self.atomic_writing(os_path, text=False) as f:
await run_sync_in_worker_thread(f.write, bcontent)
await run_sync(f.write, bcontent)
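In `fileio.py` above, keyword arguments are bound with `functools.partial` before the call is handed to `run_sync`, because both the anyio v2 and v3 helpers forward only positional arguments to the target function. A small sketch of the same idea, assuming nothing beyond `nbformat` and `anyio` (the `read_notebook` helper is illustrative):

```python
from functools import partial

import nbformat

try:
    from anyio.to_thread import run_sync
except ImportError:
    from anyio import run_sync_in_worker_thread as run_sync


async def read_notebook(path):
    # run_sync(func, *args) passes positional arguments only, so keyword
    # arguments such as as_version are frozen into the callable first.
    with open(path, "r", encoding="utf-8") as f:
        return await run_sync(partial(nbformat.read, as_version=4), f)
```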
25 changes: 15 additions & 10 deletions jupyter_server/services/contents/filemanager.py
@@ -12,7 +12,12 @@
import mimetypes
import nbformat

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from send2trash import send2trash
from tornado import web

@@ -578,7 +583,7 @@ async def _dir_model(self, path, content=True):
if content:
model['content'] = contents = []
os_dir = self._get_os_path(path)
dir_contents = await run_sync_in_worker_thread(os.listdir, os_dir)
dir_contents = await run_sync(os.listdir, os_dir)
for name in dir_contents:
try:
os_path = os.path.join(os_dir, name)
@@ -588,7 +593,7 @@
continue

try:
st = await run_sync_in_worker_thread(os.lstat, os_path)
st = await run_sync(os.lstat, os_path)
except OSError as e:
# skip over broken symlinks in listing
if e.errno == errno.ENOENT:
@@ -721,7 +726,7 @@ async def _save_directory(self, os_path, model, path=''):
raise web.HTTPError(400, u'Cannot create hidden directory %r' % os_path)
if not os.path.exists(os_path):
with self.perm_to_403():
await run_sync_in_worker_thread(os.mkdir, os_path)
await run_sync(os.mkdir, os_path)
elif not os.path.isdir(os_path):
raise web.HTTPError(400, u'Not a directory: %s' % (os_path))
else:
@@ -791,16 +796,16 @@ async def _check_trash(os_path):
# It's a bit more nuanced than this, but until we can better
# distinguish errors from send2trash, assume that we can only trash
# files on the same partition as the home directory.
file_dev = (await run_sync_in_worker_thread(os.stat, os_path)).st_dev
home_dev = (await run_sync_in_worker_thread(os.stat, os.path.expanduser('~'))).st_dev
file_dev = (await run_sync(os.stat, os_path)).st_dev
home_dev = (await run_sync(os.stat, os.path.expanduser('~'))).st_dev
return file_dev == home_dev

async def is_non_empty_dir(os_path):
if os.path.isdir(os_path):
# A directory containing only leftover checkpoints is
# considered empty.
cp_dir = getattr(self.checkpoints, 'checkpoint_dir', None)
dir_contents = set(await run_sync_in_worker_thread(os.listdir, os_path))
dir_contents = set(await run_sync(os.listdir, os_path))
if dir_contents - {cp_dir}:
return True

@@ -828,11 +833,11 @@ async def is_non_empty_dir(os_path):
raise web.HTTPError(400, u'Directory %s not empty' % os_path)
self.log.debug("Removing directory %s", os_path)
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.rmtree, os_path)
await run_sync(shutil.rmtree, os_path)
else:
self.log.debug("Unlinking file %s", os_path)
with self.perm_to_403():
await run_sync_in_worker_thread(rm, os_path)
await run_sync(rm, os_path)

async def rename_file(self, old_path, new_path):
"""Rename a file."""
Expand All @@ -851,7 +856,7 @@ async def rename_file(self, old_path, new_path):
# Move the file
try:
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.move, old_os_path, new_os_path)
await run_sync(shutil.move, old_os_path, new_os_path)
except web.HTTPError:
raise
except Exception as e:
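The `_check_trash` helper above encodes the heuristic from its comment: `send2trash` is only trusted when the target sits on the same partition (`st_dev`) as the home directory. A hedged sketch of how a delete routine might act on that check; this is not the manager's actual `delete_file` implementation, just the decision the comment describes:

```python
import os
import shutil

from send2trash import send2trash

try:
    from anyio.to_thread import run_sync
except ImportError:
    from anyio import run_sync_in_worker_thread as run_sync


async def delete_path(os_path, allow_trash=True):
    # Hypothetical helper: prefer the OS trash only when the file shares a
    # device with the home directory; otherwise fall back to a hard delete,
    # keeping every blocking call off the event loop.
    file_dev = (await run_sync(os.stat, os_path)).st_dev
    home_dev = (await run_sync(os.stat, os.path.expanduser("~"))).st_dev
    if allow_trash and file_dev == home_dev:
        await run_sync(send2trash, os_path)
    elif os.path.isdir(os_path):
        await run_sync(shutil.rmtree, os_path)
    else:
        await run_sync(os.unlink, os_path)
```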
9 changes: 7 additions & 2 deletions jupyter_server/services/contents/largefilemanager.py
@@ -1,4 +1,9 @@
from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from tornado import web
import base64
import os, io
@@ -135,6 +140,6 @@ async def _save_large_file(self, os_path, content, format):
if os.path.islink(os_path):
os_path = os.path.join(os.path.dirname(os_path), os.readlink(os_path))
with io.open(os_path, 'ab') as f:
await run_sync_in_worker_thread(f.write, bcontent)
await run_sync(f.write, bcontent)

