diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 930db684a..78de81565 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -2009,8 +2009,10 @@ def try_and_skip_if_error(func, *args, **kwargs): # Load new skeleton filename = params["filename"] new_skeleton = OpenSkeleton.load_skeleton(filename) - if new_skeleton.description == None: - new_skeleton.description = f"Custom Skeleton loaded from {filename}" + + # Description and preview image only used for template skeletons + new_skeleton.description = None + new_skeleton.preview_image = None context.state["skeleton_description"] = new_skeleton.description context.state["skeleton_preview_image"] = new_skeleton.preview_image diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 11e618ebb..eca393b8e 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -95,13 +95,15 @@ class Skeleton: _skeleton_idx: An index variable used to give skeletons a default name that should be unique across all skeletons. preview_image: A byte string containing an encoded preview image for the - skeleton. - description: A text description of the skeleton. Used mostly for presets. + skeleton. Used only for templates. + description: A text description of the skeleton. Used only for templates. + _is_template: Whether this skeleton is a template. Used only for templates. """ _skeleton_idx = count(0) preview_image: Optional[bytes] = None description: Optional[str] = None + _is_template: bool = False def __init__(self, name: str = None): """Initialize an empty skeleton object. @@ -176,6 +178,32 @@ def dict_match(dict1, dict2): return True + @property + def is_template(self) -> bool: + """Return whether this skeleton is a template. + + If is_template is True, then the preview image and description are saved. + If is_template is False, then the preview image and description are not saved. + + Only provided template skeletons are considered templates. To save a new + template skeleton, change this to True before saving. + """ + return self._is_template + + @is_template.setter + def is_template(self, value: bool): + """Set whether this skeleton is a template.""" + + self._is_template = False + if value and ((self.preview_image is None) or (self.description is None)): + raise ValueError( + "For a skeleton to be a template, it must have both a preview image " + "and description. Checkout `generate_skeleton_preview_image` to " + "generate a preview image." + ) + + self._is_template = value + @property def is_arborescence(self) -> bool: """Return whether this skeleton graph forms an arborescence.""" @@ -956,8 +984,7 @@ def from_names_and_edge_inds( return skeleton def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: - """ - Convert the :class:`Skeleton` to a JSON representation. + """Convert the :class:`Skeleton` to a JSON representation. Args: node_to_idx: optional dict which maps :class:`Node`sto index @@ -981,12 +1008,22 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: indexed_node_graph = self._graph # Encode to JSON - dicts = { - "nx_graph": json_graph.node_link_data(indexed_node_graph), - "description": self.description, - "preview_image": self.preview_image, - } - json_str = jsonpickle.encode(dicts) + graph = json_graph.node_link_data(indexed_node_graph) + + # SLEAP v1.3.0 added `description` and `preview_image` to `Skeleton`, but saving + # these fields breaks data format compatibility. Currently, these are only + # added in our custom template skeletons. To ensure backwards data format + # compatibilty of user data, we only save these fields if they are not None. + if self.is_template: + data = { + "nx_graph": graph, + "description": self.description, + "preview_image": self.preview_image, + } + else: + data = graph + + json_str = jsonpickle.encode(data) return json_str @@ -1020,8 +1057,7 @@ def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None def from_json( cls, json_str: str, idx_to_node: Dict[int, Node] = None ) -> "Skeleton": - """ - Instantiate :class:`Skeleton` from JSON string. + """Instantiate :class:`Skeleton` from JSON string. Args: json_str: The JSON encoded Skeleton. @@ -1036,9 +1072,8 @@ def from_json( An instance of the `Skeleton` object decoded from the JSON. """ dicts = jsonpickle.decode(json_str) - if "nx_graph" not in dicts: - dicts = {"nx_graph": dicts, "description": None, "preview_image": None} - graph = json_graph.node_link_graph(dicts["nx_graph"]) + nx_graph = dicts.get("nx_graph", dicts) + graph = json_graph.node_link_graph(nx_graph) # Replace graph node indices with corresponding nodes from node_map if idx_to_node is not None: @@ -1046,8 +1081,8 @@ def from_json( skeleton = Skeleton() skeleton._graph = graph - skeleton.description = dicts["description"] - skeleton.preview_image = dicts["preview_image"] + skeleton.description = dicts.get("description", None) + skeleton.preview_image = dicts.get("preview_image", None) return skeleton @@ -1055,8 +1090,7 @@ def from_json( def load_json( cls, filename: str, idx_to_node: Dict[int, Node] = None ) -> "Skeleton": - """ - Load a skeleton from a JSON file. + """Load a skeleton from a JSON file. This method will load the Skeleton from JSON file saved with; :meth:`~Skeleton.save_json` diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index e409f3bbe..1f7c3a853 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,6 +1,7 @@ import os import copy +import jsonpickle import pytest from sleap.skeleton import Skeleton @@ -182,12 +183,37 @@ def test_symmetry(): ] -def test_json(skeleton, tmpdir): - """ - Test saving and loading a Skeleton object in JSON. - """ +def test_json(skeleton: Skeleton, tmpdir): + """Test saving and loading a Skeleton object in JSON.""" JSON_TEST_FILENAME = os.path.join(tmpdir, "skeleton.json") + # Test that `to_json` does not save unused `None` fields (to ensure backwards data + # format compatibility) + skeleton.description = ( + "Test that description is not saved when given (if is_template is False)." + ) + assert skeleton.is_template == False + json_str = skeleton.to_json() + json_dict = jsonpickle.decode(json_str) + json_dict_keys = list(json_dict.keys()) + assert "nx_graph" not in json_dict_keys + assert "preview_image" not in json_dict_keys + assert "description" not in json_dict_keys + + # Test that `is_template` can only be set to True + # when has both `description` and `preview_image` + with pytest.raises(ValueError): + skeleton.is_template = True + assert skeleton.is_template == False + + skeleton._is_template = True + json_str = skeleton.to_json() + json_dict = jsonpickle.decode(json_str) + json_dict_keys = list(json_dict.keys()) + assert "nx_graph" in json_dict_keys + assert "preview_image" in json_dict_keys + assert "description" in json_dict_keys + # Save it to a JSON filename skeleton.save_json(JSON_TEST_FILENAME)