Skip to content

Commit

Permalink
Merge pull request #726 from bnmajor/standalone-viewer-fixes
Browse files Browse the repository at this point in the history
Do not use CellWatcher outside of notebook env
  • Loading branch information
thewtex authored Jan 3, 2024
2 parents a130dfe + 81958d0 commit 5349eeb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion itkwidgets/standalone_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def standalone_viewer(url):

def input_dict(viewer_options):
user_input = read_files(viewer_options)
data = build_init_data(user_input)
data = build_init_data(user_input, {})
ui = user_input.get("ui", "reference")
data["config"] = build_config(ui)

Expand Down
35 changes: 21 additions & 14 deletions itkwidgets/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@

_viewer_count = 1
_codecs_registered = False
CellWatcher() # Instantiate the singleton class right away
_cell_watcher = None
if ENVIRONMENT is not Env.HYPHA:
_cell_watcher = CellWatcher() # Instantiate the singleton class right away


class ViewerRPC:
Expand All @@ -49,7 +51,7 @@ class ViewerRPC:
def __init__(
self, ui_collapsed=True, rotate=False, ui="pydata-sphinx", init_data=None, parent=None, **add_data_kwargs
):
global _codecs_registered
global _codecs_registered, _cell_watcher
"""Create a viewer."""
# Register codecs if they haven't been already
if not _codecs_registered and ENVIRONMENT is not Env.HYPHA:
Expand All @@ -64,7 +66,7 @@ def __init__(
self.wid = None
self.parent = parent
if ENVIRONMENT is not Env.JUPYTERLITE:
CellWatcher().add_viewer(self.parent)
_cell_watcher and _cell_watcher.add_viewer(self.parent)
if ENVIRONMENT is not Env.HYPHA:
self.viewer_event = threading.Event()
self.data_event = threading.Event()
Expand All @@ -74,7 +76,7 @@ async def setup(self):

async def run(self, ctx):
"""ImJoy plugin setup function."""
global _viewer_count
global _viewer_count, _cell_watcher
ui = self._init_viewer_kwargs.get("ui", None)
config = build_config(ui)

Expand Down Expand Up @@ -102,7 +104,7 @@ async def run(self, ctx):
)
if not defer_for_data_render(self.init_data):
# Once the viewer has been created any queued requests can be run
CellWatcher().update_viewer_status(self.parent, True)
_cell_watcher.update_viewer_status(self.parent, True)
asyncio.get_running_loop().call_soon_threadsafe(self.viewer_event.set)

# Wait and then update the screenshot in case rendered level changed
Expand Down Expand Up @@ -139,14 +141,16 @@ def update_screenshot(self, base64_image):
self.img.display(html)

def update_viewer_status(self):
if not CellWatcher().viewer_ready(self.parent):
CellWatcher().update_viewer_status(self.parent, True)
global _cell_watcher
if not _cell_watcher.viewer_ready(self.parent):
_cell_watcher.update_viewer_status(self.parent, True)

def set_event(self, event_data):
if not self.data_event.is_set():
# Once the data has been set the deferred queue requests can be run
asyncio.get_running_loop().call_soon_threadsafe(self.data_event.set)
self.update_viewer_status()
if ENVIRONMENT is not Env.HYPHA:
self.update_viewer_status()


class Viewer:
Expand All @@ -166,7 +170,6 @@ def __init__(
self.viewer_rpc = ViewerRPC(
ui_collapsed=ui_collapsed, rotate=rotate, ui=ui, init_data=data, parent=self.name, **add_data_kwargs
)
self.cw = CellWatcher()
if ENVIRONMENT is not Env.JUPYTERLITE:
self._setup_queueing()
api.export(self.viewer_rpc)
Expand Down Expand Up @@ -219,9 +222,10 @@ def queue_worker(self):
loop.run_until_complete(task)

def call_getter(self, future):
global _cell_watcher
name = uuid.uuid4()
CellWatcher().results[name] = future
future.add_done_callback(functools.partial(CellWatcher()._callback, name))
_cell_watcher.results[name] = future
future.add_done_callback(functools.partial(_cell_watcher._callback, name))

def queue_request(self, method, *args, **kwargs):
if (
Expand Down Expand Up @@ -275,6 +279,7 @@ async def get_cropping_planes(self):

@fetch_value
def set_image(self, image: Image, name: str = 'Image'):
global _cell_watcher
render_type = _detect_render_type(image, 'image')
if render_type is RenderType.IMAGE:
image = _get_viewer_image(image, label=False)
Expand All @@ -286,7 +291,7 @@ def set_image(self, image: Image, name: str = 'Image'):
svc.set_label_or_image('image')
else:
self.queue_request('setImage', image, name)
CellWatcher().update_viewer_status(self.name, False)
_cell_watcher and _cell_watcher.update_viewer_status(self.name, False)
elif render_type is RenderType.POINT_SET:
image = _get_viewer_point_set(image)
self.queue_request('setPointSets', image)
Expand Down Expand Up @@ -531,6 +536,7 @@ async def get_roi_slice(self, scale: int = -1):

@fetch_value
def compare_images(self, fixed_image: Union[str, Image], moving_image: Union[str, Image], method: str = None, image_mix: float = None, checkerboard: bool = None, pattern: Union[Tuple[int, int], Tuple[int, int, int]] = None, swap_image_order: bool = None):
global _cell_watcher
# image args may be image name or image object
fixed_name = 'Fixed'
if isinstance(fixed_image, str):
Expand All @@ -555,10 +561,11 @@ def compare_images(self, fixed_image: Union[str, Image], moving_image: Union[str
if swap_image_order is not None:
options['swapImageOrder'] = swap_image_order
self.queue_request('compareImages', fixed_name, moving_name, options)
CellWatcher().update_viewer_status(self.name, False)
_cell_watcher and _cell_watcher.update_viewer_status(self.name, False)

@fetch_value
def set_label_image(self, label_image: Image):
global _cell_watcher
render_type = _detect_render_type(label_image, 'image')
if render_type is RenderType.IMAGE:
label_image = _get_viewer_image(label_image, label=True)
Expand All @@ -570,7 +577,7 @@ def set_label_image(self, label_image: Image):
svc.set_label_or_image('label_image')
else:
self.queue_request('setLabelImage', label_image)
CellWatcher().update_viewer_status(self.name, False)
_cell_watcher and _cell_watcher.update_viewer_status(self.name, False)
elif render_type is RenderType.POINT_SET:
label_image = _get_viewer_point_set(label_image)
self.queue_request('setPointSets', label_image)
Expand Down

0 comments on commit 5349eeb

Please sign in to comment.