From aabfd2f523ed9f2f7c8a43a81ee687c888180ada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jord=C3=A3o=20Bragantini?= Date: Tue, 13 Jun 2023 16:22:28 -0700 Subject: [PATCH] Universal reader output base classes (#132) * WIP: universal reader output base classes * improved basefov getitem typing and added more methods to basefovcollection * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> * fixing typing from contributions * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> * add helper functions to find and pad missing axes * added helper functions tests * wip: making basefovcollection a mapping * improved FOVCollection * split FOVCollection into BaseFOVCollection and FOVDict * renamed files to fov.py * talon review * fixing tests * fixing `__exit__` return * Update iohub/fov.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> * fixing length --------- Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- iohub/fov.py | 225 ++++++++++++++++++++++++++++++++++++++++++ tests/fov/test_fov.py | 93 +++++++++++++++++ 2 files changed, 318 insertions(+) create mode 100644 iohub/fov.py create mode 100644 tests/fov/test_fov.py diff --git a/iohub/fov.py b/iohub/fov.py new file mode 100644 index 00000000..d91a21f8 --- /dev/null +++ b/iohub/fov.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from pathlib import Path +from types import TracebackType +from typing import Any, Iterable, Optional, Type, Union + +import numpy as np +from numpy.typing import ArrayLike + +_AXES_PREFIX = ["T", "C", "Z", "Y", "X"] + + +class BaseFOV(ABC): + @property + @abstractmethod + def root(self) -> Path: + raise NotImplementedError + + @property + @abstractmethod + def axes_names(self) -> list[str]: + raise NotImplementedError + + @property + @abstractmethod + def channel_names(self) -> list[str]: + raise NotImplementedError + + def channel_index(self, key: str) -> int: + """Return index of given channel.""" + return self.channels.index(key) + + def _missing_axes(self) -> list[int]: + """Return sorted indices of missing axes.""" + if len(self.axes_names) == 5: + return [] + + elif len(self.axes_names) > 5: + raise ValueError( + f"{self.__name__} does not support more than 5 axes. " + f"Found {len(self.axes_names)}" + ) + + axes = set(ax[:1].upper() for ax in self.axes_names) + + missing = [] + for i, ax in enumerate(_AXES_PREFIX): + if ax not in axes: + missing.append(i) + + return missing + + def _pad_missing_axes( + self, + seq: Union[list[Any], tuple[Any]], + value: Any, + ) -> Union[list[Any], tuple[Any]]: + """Pads ``seq`` with ``value`` in the missing axes positions.""" + + if isinstance(seq, tuple): + is_tuple = True + seq = list(seq) + else: + is_tuple = False + + for i in self._missing_axes(): + seq.insert(i, value) + + if is_tuple: + seq = tuple(seq) + + if len(seq) != len(_AXES_PREFIX): + raise RuntimeError( + f"Failed to pad raw axes {self.axes_names} to {_AXES_PREFIX}" + ) + + return seq + + @property + @abstractmethod + def shape(self) -> tuple[int, int, int, int, int]: + raise NotImplementedError + + @abstractmethod + def __getitem__( + self, + key: Union[int, slice, tuple[Union[int, slice], ...]], + ) -> ArrayLike: + """ + Returned object must support the ``__array__`` interface, + so that ``np.asarray(...)`` will work. + """ + raise NotImplementedError + + @property + def ndim(self) -> int: + return 5 + + @property + @abstractmethod + def dtype(self) -> np.dtype: + raise NotImplementedError + + @property + @abstractmethod + def zyx_scale(self) -> tuple[float, float, float]: + """Helper function for FOV spatial scale (micrometer).""" + raise NotImplementedError + + @property + @abstractmethod + def t_scale(self) -> float: + """Helper function for FOV time scale (seconds).""" + raise NotImplementedError + + def __eq__(self, other: BaseFOV) -> bool: + if not isinstance(other, BaseFOV): + return False + return self.root.absolute() == other.root.absolute() + + +class BaseFOVMapping(Mapping): + @abstractmethod + def __enter__(self) -> BaseFOVMapping: + """Open the underlying file and return self.""" + raise NotImplementedError + + @abstractmethod + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """Close the files.""" + raise NotImplementedError + + @abstractmethod + def __contains__(self, position_key: str) -> bool: + """Check if a position is present in the collection.""" + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, position_key: str) -> BaseFOV: + """FOV key position to FOV object.""" + raise NotImplementedError + + @abstractmethod + def __iter__(self) -> Iterable[tuple[str, BaseFOV]]: + """Iterates over pairs of keys and FOVs.""" + raise NotImplementedError + + +class FOVDict(BaseFOVMapping): + """ + Basic implementation of a mapping of strings to BaseFOVs. + """ + + def __init__( + self, + data_dict: Optional[dict[str, BaseFOV]] = None, + **kwargs, + ) -> None: + super().__init__() + self._data = {} + + if data_dict is not None: + for key, fov in data_dict.items(): + self._safe_insert(key, fov) + + for key, fov in kwargs.items(): + self._safe_insert(key, fov) + + def _safe_insert(self, key: str, value: BaseFOV) -> None: + """Checks if types are correct and key is unique.""" + if not isinstance(key, str): + raise TypeError( + f"{self.__class__.__name__} key must be str. " + f"Found {key} with type {type(key)}" + ) + + if not isinstance(value, BaseFOV): + raise TypeError( + f"{self.__class__.__name__} value must subclass BaseFOV. " + f"Found {key} with value type {type(value)}" + ) + + if key in self: + raise KeyError(f"{key} already exists.") + + self._data[key] = value + + def __contains__(self, position_key: str) -> bool: + """Checks if position_key already exists.""" + return position_key in self._data + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, position_key: str) -> BaseFOV: + """FOV key position to FOV object.""" + return self._data[position_key] + + def __iter__(self) -> Iterable[tuple[str, BaseFOV]]: + """Iterates over pairs of keys and FOVs.""" + return self._data.items() + + def __enter__(self) -> FOVDict: + """Open the underlying file and return self.""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """Close the files.""" + return False diff --git a/tests/fov/test_fov.py b/tests/fov/test_fov.py new file mode 100644 index 00000000..57e19140 --- /dev/null +++ b/tests/fov/test_fov.py @@ -0,0 +1,93 @@ +from pathlib import Path +from typing import Any + +import pytest + +from iohub.fov import BaseFOV, FOVDict + + +class FOV(BaseFOV): + def __init__(self, axes: list[str]) -> None: + super().__init__() + self._axes = axes + + @property + def root(self) -> Path: + return Path() + + @property + def axes_names(self) -> list[str]: + return self._axes + + @property + def channel_names(self) -> list[str]: + return [] + + def __getitem__(self, key: Any) -> Any: + pass + + @property + def dtype(self) -> Any: + pass + + @property + def shape(self) -> Any: + pass + + @property + def zyx_scale(self) -> tuple[float, float, float]: + return (1.0,) * 3 + + @property + def t_scale(self) -> float: + raise 1.0 + + +@pytest.mark.parametrize( + "axes,missing", + [ + (["T", "C", "Z", "Y", "X"], []), + (["time", "z", "y", "x"], [1]), + (["channels", "Y", "X"], [0, 2]), + ], +) +def test_missing_axes(axes: list[str], missing: list[int]) -> None: + fov = FOV(axes) + assert fov._missing_axes() == missing + + shape = (10,) * len(axes) + padded_shape = fov._pad_missing_axes(shape, 1) + assert len(padded_shape) == 5 + + for i, s in enumerate(padded_shape): + if i in missing: + assert s == 1 + else: + assert s == 10 + + +def test_fov_dict() -> None: + + good_collection = FOVDict( + { + "488": FOV(["y", "x"]), + "561": FOV(["y", "x"]), + }, + mask=FOV(["c", "x", "y"]), + ) + + assert len(good_collection) == 3 + assert "488" in good_collection + assert good_collection["mask"] is not None + + with pytest.raises(TypeError): + del good_collection["561"] + + with pytest.raises(TypeError): + good_collection["segmentation"] = FOV(["x", "y"]) + + with pytest.raises(TypeError): + FOVDict({488: FOV(["y", "x"])}) + + with pytest.raises(TypeError): + FOVDict(mask=[1, 2, 3])