-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Universal reader output base classes (#132)
* WIP: universal reader output base classes * improved basefov getitem typing and added more methods to basefovcollection * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <[email protected]> * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <[email protected]> * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <[email protected]> * fixing typing from contributions * Update iohub/dataset_base.py Co-authored-by: Ziwen Liu <[email protected]> * add helper functions to find and pad missing axes * added helper functions tests * wip: making basefovcollection a mapping * improved FOVCollection * split FOVCollection into BaseFOVCollection and FOVDict * renamed files to fov.py * talon review * fixing tests * fixing `__exit__` return * Update iohub/fov.py Co-authored-by: Ziwen Liu <[email protected]> * fixing length --------- Co-authored-by: Ziwen Liu <[email protected]>
- Loading branch information
Showing
2 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from collections.abc import Mapping | ||
from pathlib import Path | ||
from types import TracebackType | ||
from typing import Any, Iterable, Optional, Type, Union | ||
|
||
import numpy as np | ||
from numpy.typing import ArrayLike | ||
|
||
_AXES_PREFIX = ["T", "C", "Z", "Y", "X"] | ||
|
||
|
||
class BaseFOV(ABC): | ||
@property | ||
@abstractmethod | ||
def root(self) -> Path: | ||
raise NotImplementedError | ||
|
||
@property | ||
@abstractmethod | ||
def axes_names(self) -> list[str]: | ||
raise NotImplementedError | ||
|
||
@property | ||
@abstractmethod | ||
def channel_names(self) -> list[str]: | ||
raise NotImplementedError | ||
|
||
def channel_index(self, key: str) -> int: | ||
"""Return index of given channel.""" | ||
return self.channels.index(key) | ||
|
||
def _missing_axes(self) -> list[int]: | ||
"""Return sorted indices of missing axes.""" | ||
if len(self.axes_names) == 5: | ||
return [] | ||
|
||
elif len(self.axes_names) > 5: | ||
raise ValueError( | ||
f"{self.__name__} does not support more than 5 axes. " | ||
f"Found {len(self.axes_names)}" | ||
) | ||
|
||
axes = set(ax[:1].upper() for ax in self.axes_names) | ||
|
||
missing = [] | ||
for i, ax in enumerate(_AXES_PREFIX): | ||
if ax not in axes: | ||
missing.append(i) | ||
|
||
return missing | ||
|
||
def _pad_missing_axes( | ||
self, | ||
seq: Union[list[Any], tuple[Any]], | ||
value: Any, | ||
) -> Union[list[Any], tuple[Any]]: | ||
"""Pads ``seq`` with ``value`` in the missing axes positions.""" | ||
|
||
if isinstance(seq, tuple): | ||
is_tuple = True | ||
seq = list(seq) | ||
else: | ||
is_tuple = False | ||
|
||
for i in self._missing_axes(): | ||
seq.insert(i, value) | ||
|
||
if is_tuple: | ||
seq = tuple(seq) | ||
|
||
if len(seq) != len(_AXES_PREFIX): | ||
raise RuntimeError( | ||
f"Failed to pad raw axes {self.axes_names} to {_AXES_PREFIX}" | ||
) | ||
|
||
return seq | ||
|
||
@property | ||
@abstractmethod | ||
def shape(self) -> tuple[int, int, int, int, int]: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __getitem__( | ||
self, | ||
key: Union[int, slice, tuple[Union[int, slice], ...]], | ||
) -> ArrayLike: | ||
""" | ||
Returned object must support the ``__array__`` interface, | ||
so that ``np.asarray(...)`` will work. | ||
""" | ||
raise NotImplementedError | ||
|
||
@property | ||
def ndim(self) -> int: | ||
return 5 | ||
|
||
@property | ||
@abstractmethod | ||
def dtype(self) -> np.dtype: | ||
raise NotImplementedError | ||
|
||
@property | ||
@abstractmethod | ||
def zyx_scale(self) -> tuple[float, float, float]: | ||
"""Helper function for FOV spatial scale (micrometer).""" | ||
raise NotImplementedError | ||
|
||
@property | ||
@abstractmethod | ||
def t_scale(self) -> float: | ||
"""Helper function for FOV time scale (seconds).""" | ||
raise NotImplementedError | ||
|
||
def __eq__(self, other: BaseFOV) -> bool: | ||
if not isinstance(other, BaseFOV): | ||
return False | ||
return self.root.absolute() == other.root.absolute() | ||
|
||
|
||
class BaseFOVMapping(Mapping): | ||
@abstractmethod | ||
def __enter__(self) -> BaseFOVMapping: | ||
"""Open the underlying file and return self.""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __exit__( | ||
self, | ||
exc_type: Optional[Type[BaseException]], | ||
exc_val: Optional[BaseException], | ||
exc_tb: Optional[TracebackType], | ||
) -> bool: | ||
"""Close the files.""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __contains__(self, position_key: str) -> bool: | ||
"""Check if a position is present in the collection.""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __len__(self) -> int: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __getitem__(self, position_key: str) -> BaseFOV: | ||
"""FOV key position to FOV object.""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def __iter__(self) -> Iterable[tuple[str, BaseFOV]]: | ||
"""Iterates over pairs of keys and FOVs.""" | ||
raise NotImplementedError | ||
|
||
|
||
class FOVDict(BaseFOVMapping): | ||
""" | ||
Basic implementation of a mapping of strings to BaseFOVs. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
data_dict: Optional[dict[str, BaseFOV]] = None, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
self._data = {} | ||
|
||
if data_dict is not None: | ||
for key, fov in data_dict.items(): | ||
self._safe_insert(key, fov) | ||
|
||
for key, fov in kwargs.items(): | ||
self._safe_insert(key, fov) | ||
|
||
def _safe_insert(self, key: str, value: BaseFOV) -> None: | ||
"""Checks if types are correct and key is unique.""" | ||
if not isinstance(key, str): | ||
raise TypeError( | ||
f"{self.__class__.__name__} key must be str. " | ||
f"Found {key} with type {type(key)}" | ||
) | ||
|
||
if not isinstance(value, BaseFOV): | ||
raise TypeError( | ||
f"{self.__class__.__name__} value must subclass BaseFOV. " | ||
f"Found {key} with value type {type(value)}" | ||
) | ||
|
||
if key in self: | ||
raise KeyError(f"{key} already exists.") | ||
|
||
self._data[key] = value | ||
|
||
def __contains__(self, position_key: str) -> bool: | ||
"""Checks if position_key already exists.""" | ||
return position_key in self._data | ||
|
||
def __len__(self) -> int: | ||
return len(self._data) | ||
|
||
def __getitem__(self, position_key: str) -> BaseFOV: | ||
"""FOV key position to FOV object.""" | ||
return self._data[position_key] | ||
|
||
def __iter__(self) -> Iterable[tuple[str, BaseFOV]]: | ||
"""Iterates over pairs of keys and FOVs.""" | ||
return self._data.items() | ||
|
||
def __enter__(self) -> FOVDict: | ||
"""Open the underlying file and return self.""" | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[Type[BaseException]], | ||
exc_val: Optional[BaseException], | ||
exc_tb: Optional[TracebackType], | ||
) -> bool: | ||
"""Close the files.""" | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import pytest | ||
|
||
from iohub.fov import BaseFOV, FOVDict | ||
|
||
|
||
class FOV(BaseFOV): | ||
def __init__(self, axes: list[str]) -> None: | ||
super().__init__() | ||
self._axes = axes | ||
|
||
@property | ||
def root(self) -> Path: | ||
return Path() | ||
|
||
@property | ||
def axes_names(self) -> list[str]: | ||
return self._axes | ||
|
||
@property | ||
def channel_names(self) -> list[str]: | ||
return [] | ||
|
||
def __getitem__(self, key: Any) -> Any: | ||
pass | ||
|
||
@property | ||
def dtype(self) -> Any: | ||
pass | ||
|
||
@property | ||
def shape(self) -> Any: | ||
pass | ||
|
||
@property | ||
def zyx_scale(self) -> tuple[float, float, float]: | ||
return (1.0,) * 3 | ||
|
||
@property | ||
def t_scale(self) -> float: | ||
raise 1.0 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"axes,missing", | ||
[ | ||
(["T", "C", "Z", "Y", "X"], []), | ||
(["time", "z", "y", "x"], [1]), | ||
(["channels", "Y", "X"], [0, 2]), | ||
], | ||
) | ||
def test_missing_axes(axes: list[str], missing: list[int]) -> None: | ||
fov = FOV(axes) | ||
assert fov._missing_axes() == missing | ||
|
||
shape = (10,) * len(axes) | ||
padded_shape = fov._pad_missing_axes(shape, 1) | ||
assert len(padded_shape) == 5 | ||
|
||
for i, s in enumerate(padded_shape): | ||
if i in missing: | ||
assert s == 1 | ||
else: | ||
assert s == 10 | ||
|
||
|
||
def test_fov_dict() -> None: | ||
|
||
good_collection = FOVDict( | ||
{ | ||
"488": FOV(["y", "x"]), | ||
"561": FOV(["y", "x"]), | ||
}, | ||
mask=FOV(["c", "x", "y"]), | ||
) | ||
|
||
assert len(good_collection) == 3 | ||
assert "488" in good_collection | ||
assert good_collection["mask"] is not None | ||
|
||
with pytest.raises(TypeError): | ||
del good_collection["561"] | ||
|
||
with pytest.raises(TypeError): | ||
good_collection["segmentation"] = FOV(["x", "y"]) | ||
|
||
with pytest.raises(TypeError): | ||
FOVDict({488: FOV(["y", "x"])}) | ||
|
||
with pytest.raises(TypeError): | ||
FOVDict(mask=[1, 2, 3]) |