Skip to content

Commit

Permalink
UUID as traitlets for threading related widgets (#375)
Browse files Browse the repository at this point in the history
AiiDA is not thread-safe, for the node queried from DB, it will live at the end of DB session. 
In the AiiDAlab widget, we use threading for the widget since we don't want to block the whole notebook in some long running operation. 
However, the some traitlets in the widget implementation are AiiDA node, which lead to the issue that when the threading is closed the session also finalized thereafter the node is expired and invalid to be used.
For more details of discussion about theread-local problem: check aiidateam/aiida-core#5765

Here, we break the API by changing all the traitlets that use AiiDA node to the static UUID, only loading the node and consuming it in the local scope. 
The PR involves all the widgets that explicitly using threading for the AiiDA node. There are possibilities that the other thread trigger the traitlets change and having AiiDA node in the extra threading, which also cause the exactly same issue. One of the example is the `ComputationalResourceWidget`, which has AiiDA code node as traitlets but has no threading. The issue happened when it is called from QeApp by additional threading from code setup widget.
In the future, we will change all the traitlets to UUID.
  • Loading branch information
unkcpz authored Nov 16, 2022
1 parent 5d64006 commit 2ec0db9
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 105 deletions.
84 changes: 32 additions & 52 deletions aiidalab_widgets_base/computational_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
from copy import copy
from pathlib import Path
from uuid import UUID

import ipywidgets as ipw
import pexpect
Expand All @@ -27,12 +28,11 @@ class ComputationalResourcesWidget(ipw.VBox):
"""Code selection widget.
Attributes:
value(Unicode or Code): Trait that points to the selected Code instance.
It can be set either to an AiiDA Code instance or to a code label (will automatically
be replaced by the corresponding Code instance). It is linked to the 'value' trait of
the `self.code_select_dropdown` widget.
value(code UUID): Trait that points to the selected UUID of the code instance.
It can be set either to an AiiDA code UUID or to a code label.
It is linked to the `value` trait of the `self.code_select_dropdown` widget.
codes(Dict): Trait that contains a dictionary (label => Code instance) for all
codes(Dict): Trait that contains a dictionary (label => Code UUID) for all
codes found in the AiiDA database for the selected plugin. It is linked
to the 'options' trait of the `self.code_select_dropdown` widget.
Expand All @@ -42,9 +42,7 @@ class ComputationalResourcesWidget(ipw.VBox):
computers.
"""

value = traitlets.Union(
[traitlets.Unicode(), traitlets.Instance(orm.Code)], allow_none=True
)
value = traitlets.Unicode(allow_none=True)
codes = traitlets.Dict(allow_none=True)
allow_hidden_codes = traitlets.Bool(False)
allow_disabled_computers = traitlets.Bool(False)
Expand Down Expand Up @@ -164,7 +162,7 @@ def _get_codes(self):
user = orm.User.collection.get_default()

return {
self._full_code_label(c[0]): c[0]
self._full_code_label(c[0]): c[0].uuid
for c in orm.QueryBuilder()
.append(orm.Code, filters={"attributes.input_plugin": self.input_plugin})
.all()
Expand Down Expand Up @@ -197,29 +195,21 @@ def refresh(self, _=None):

@traitlets.validate("value")
def _validate_value(self, change):
"""If code is provided, set it as it is. If code's label is provided,
select the code and set it."""
code = change["value"]
"""Check if the code is valid in DB"""
code_uuid = change["value"]
self.output.value = ""

# If code None, set value to None.
if code is None:
if code_uuid is None:
return None

if isinstance(code, str): # Check code by label.
if code in self.codes:
return self.codes[code]
self.output.value = f"""No code named '<span style="color:red">{code}</span>'
found in the AiiDA database."""
elif isinstance(code, orm.Code): # Check code by value.
label = self._full_code_label(code)
if label in self.codes:
return code
self.output.value = f"""The code instance '<span style="color:red">{code}</span>'
supplied was not found in the AiiDA database."""

# This place will never be reached, because the trait's type is checked.
return None
try:
_ = UUID(code_uuid, version=4)
except ValueError:
self.output.value = f"""'<span style="color:red">{code_uuid}</span>'
is not a valid UUID."""
else:
return code_uuid

def _setup_new_code(self, _=None):
with self._setup_new_code_output:
Expand Down Expand Up @@ -1149,21 +1139,16 @@ class ComputerDropdownWidget(ipw.VBox):
"""Widget to select a configured computer.
Attributes:
selected_computer(Unicode or Computer): Trait that points to the selected Computer instance.
It can be set either to an AiiDA Computer instance or to a computer label (will
automatically be replaced by the corresponding Computer instance). It is linked to the
value(computer UUID): Trait that points to the selected Computer instance.
It can be set to an AiiDA Computer UUID. It is linked to the
'value' trait of `self._dropdown` widget.
computers(Dict): Trait that contains a dictionary (label => Computer instance) for all
computers(Dict): Trait that contains a dictionary (label => Computer UUID) for all
computers found in the AiiDA database. It is linked to the 'options' trait of
`self._dropdown` widget.
allow_select_disabled(Bool): Trait that defines whether to show disabled computers.
"""

value = traitlets.Union(
[traitlets.Unicode(), traitlets.Instance(orm.Computer)], allow_none=True
)
value = traitlets.Unicode(allow_none=True)
computers = traitlets.Dict(allow_none=True)
allow_select_disabled = traitlets.Bool(False)

Expand Down Expand Up @@ -1207,7 +1192,7 @@ def _get_computers(self):
user = orm.User.collection.get_default()

return {
c[0].label: c[0]
c[0].label: c[0].uuid
for c in orm.QueryBuilder().append(orm.Computer).all()
if c[0].is_user_configured(user)
and (self.allow_select_disabled or c[0].is_user_enabled(user))
Expand All @@ -1228,21 +1213,16 @@ def refresh(self, _=None):

@traitlets.validate("value")
def _validate_value(self, change):
"""Select computer either by label or by class instance."""
computer = change["value"]
"""Select computer by computer UUID."""
computer_uuid = change["value"]
self.output.value = ""
if not computer:
return None
if isinstance(computer, str):
if computer in self.computers:
return self.computers[computer]
self.output.value = f"""Computer instance '<span style="color:red">{computer}</span>'
is not configured in your AiiDA profile."""
if not computer_uuid:
return None

if isinstance(computer, orm.Computer):
if computer.label in self.computers:
return computer
self.output.value = f"""Computer instance '<span style="color:red">{computer.label}</span>'
is not configured in your AiiDA profile."""
return None
try:
_ = UUID(computer_uuid, version=4)
except ValueError:
self.output.value = f"""'<span style="color:red">{computer_uuid}</span>'
is not a valid UUID."""
else:
return computer_uuid
116 changes: 63 additions & 53 deletions aiidalab_widgets_base/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from inspect import isclass, signature
from threading import Event, Lock, Thread
from time import sleep
from uuid import UUID

# External imports
import ipywidgets as ipw
Expand All @@ -20,11 +21,7 @@
get_workchain_report,
)
from aiida.cmdline.utils.query.calculation import CalculationQueryBuilder
from aiida.common.exceptions import (
MultipleObjectsError,
NotExistent,
NotExistentAttributeError,
)
from aiida.common.exceptions import NotExistentAttributeError
from aiida.engine import Process, ProcessBuilder, submit
from aiida.orm import (
CalcFunctionNode,
Expand All @@ -36,7 +33,7 @@
load_node,
)
from IPython.display import HTML, Javascript, clear_output, display
from traitlets import Instance, Int, List, Unicode, Union, default, observe, validate
from traitlets import Instance, Int, List, Unicode, default, observe, validate

from .nodes import NodesTreeWidget

Expand Down Expand Up @@ -545,10 +542,10 @@ class ProcessListWidget(ipw.VBox):
past_days (int): Sumulations that were submitted in the last `past_days`.
incoming_node (int, str, Node): Trait that takes node id or uuid and returns the node that must
incoming_node (str): Trait that takes node uuid that must
be among the input nodes of the process of interest.
outgoing_node (int, str, Node): Trait that takes node id or uuid and returns the node that must
outgoing_node (str): Trait that takes node uuid that must
be among the output nodes of the process of interest.
process_states (list): List of allowed process states.
Expand All @@ -560,8 +557,8 @@ class ProcessListWidget(ipw.VBox):
"""

past_days = Int(7)
incoming_node = Union([Int(), Unicode(), Instance(Node)], allow_none=True)
outgoing_node = Union([Int(), Unicode(), Instance(Node)], allow_none=True)
incoming_node = Unicode(allow_none=True)
outgoing_node = Unicode(allow_none=True)
process_states = List()
process_label = Unicode(allow_none=True)
description_contains = Unicode(allow_none=True)
Expand Down Expand Up @@ -605,10 +602,16 @@ def update(self, _=None):
)
relationships = {}
if self.incoming_node:
relationships = {**relationships, **{"with_outgoing": self.incoming_node}}
relationships = {
**relationships,
**{"with_outgoing": load_node(self.incoming_node)},
}

if self.outgoing_node:
relationships = {**relationships, **{"with_incoming": self.outgoing_node}}
relationships = {
**relationships,
**{"with_incoming": load_node(self.outgoing_node)},
}

query_set = builder.get_query_set(
filters=filters,
Expand Down Expand Up @@ -645,23 +648,27 @@ def update(self, _=None):

@validate("incoming_node")
def _validate_incoming_node(self, provided):
"""Validate incoming node. The function load_node takes care of managing ids and uuids."""
if provided["value"]:
try:
return load_node(provided["value"])
except (MultipleObjectsError, NotExistent):
return None
return None
"""Validate incoming node."""
node_uuid = provided["value"]
try:
_ = UUID(node_uuid, version=4)
except ValueError:
self.output.value = f"""'<span style="color:red">{node_uuid}</span>'
is not a valid UUID."""
else:
return node_uuid

@validate("outgoing_node")
def _validate_outgoing_node(self, provided):
"""Validate outgoing node. The function load_node takes care of managing ids and uuids."""
if provided["value"]:
try:
return load_node(provided["value"])
except (MultipleObjectsError, NotExistent):
return None
return None
node_uuid = provided["value"]
try:
_ = UUID(node_uuid, version=4)
except ValueError:
self.output.value = f"""'<span style="color:red">{node_uuid}</span>'
is not a valid UUID."""
else:
return node_uuid

@default("process_label")
def _default_process_label(self):
Expand All @@ -688,7 +695,7 @@ def start_autoupdate(self, update_interval=10):
class ProcessMonitor(traitlets.HasTraits):
"""Monitor a process and execute callback functions at specified intervals."""

process = traitlets.Instance(ProcessNode, allow_none=True)
value = Unicode(allow_none=True)

def __init__(self, callbacks=None, on_sealed=None, timeout=None, **kwargs):
self.callbacks = [] if callbacks is None else list(callbacks)
Expand All @@ -701,29 +708,31 @@ def __init__(self, callbacks=None, on_sealed=None, timeout=None, **kwargs):

super().__init__(**kwargs)

@traitlets.observe("process")
@traitlets.observe("value")
def _observe_process(self, change):
process = change["new"]
if process is None or process.id != getattr(change["old"], "id", None):
with self.hold_trait_notifications():
with self._monitor_thread_lock:
# stop thread (if running)
if self._monitor_thread is not None:
self._monitor_thread_stop.set()
self._monitor_thread.join()

# start monitor thread
if process is not None:
self._monitor_thread_stop.clear()
process_id = getattr(process, "id", None)
self._monitor_thread = Thread(
target=self._monitor_process, args=(process_id,)
)
self._monitor_thread.start()

def _monitor_process(self, process_id):
assert process_id is not None
process = load_node(process_id)
"""When the value (process uuid) is changed, stop the previous
monitor if exist. Start a new one in thread."""
process_uuid = change["new"]

# stop thread (if running)
if self._monitor_thread is not None:
with self._monitor_thread_lock:
self._monitor_thread_stop.set()
self._monitor_thread.join()

if process_uuid is None:
return

with self._monitor_thread_lock:
self._monitor_thread_stop.clear()
self._monitor_thread = Thread(
target=self._monitor_process, args=(process_uuid,)
)
self._monitor_thread.start()

def _monitor_process(self, process_uuid):
assert process_uuid is not None
process = load_node(process_uuid)

disabled_funcs = set()

Expand All @@ -735,7 +744,7 @@ def _run(funcs):

try:
if len(signature(func).parameters) > 0:
func(process_id)
func(process_uuid)
else:
func()
except Exception as error:
Expand Down Expand Up @@ -765,7 +774,7 @@ def join(self):
class ProcessNodesTreeWidget(ipw.VBox):
"""A tree widget for the structured representation of a process graph."""

process = traitlets.Instance(ProcessNode, allow_none=True)
value = traitlets.Unicode(allow_none=True)
selected_nodes = traitlets.Tuple(read_only=True).tag(trait=traitlets.Instance(Node))

def __init__(self, title="Process Tree", **kwargs):
Expand All @@ -781,10 +790,11 @@ def _observe_tree_selected_nodes(self, change):
def update(self, _=None):
self._tree.update()

@traitlets.observe("process")
@traitlets.observe("value")
def _observe_process(self, change):
process = change["new"]
if process:
process_uuid = change["new"]
if process_uuid:
process = load_node(process_uuid)
self._tree.nodes = [process]
self._tree.find_node(process.pk).selected = True
else:
Expand Down

0 comments on commit 2ec0db9

Please sign in to comment.