Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Universal reader output base classes #132

Merged
merged 18 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions iohub/fov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from types import TracebackType
from typing import Any, Iterable, Optional, Type, Union

import numpy as np
from numpy.typing import ArrayLike

_AXES_PREFIX = ["T", "C", "Z", "Y", "X"]


class BaseFOV(ABC):
@property
@abstractmethod
def root(self) -> Path:
raise NotImplementedError

@property
@abstractmethod
def axes_names(self) -> list[str]:
raise NotImplementedError

@property
@abstractmethod
def channel_names(self) -> list[str]:
raise NotImplementedError

def channel_index(self, key: str) -> int:
"""Return index of given channel."""
return self.channels.index(key)

def _missing_axes(self) -> list[int]:
"""Return sorted indices of missing axes."""
if len(self.axes_names) == 5:
return []

elif len(self.axes_names) > 5:
raise ValueError(
f"{self.__name__} does not support more than 5 axes. "
f"Found {len(self.axes_names)}"
)

axes = set(ax[:1].upper() for ax in self.axes_names)

missing = []
for i, ax in enumerate(_AXES_PREFIX):
if ax not in axes:
missing.append(i)

return missing

def _pad_missing_axes(
self,
seq: Union[list[Any], tuple[Any]],
value: Any,
) -> Union[list[Any], tuple[Any]]:
"""Pads ``seq`` with ``value`` in the missing axes positions."""

if isinstance(seq, tuple):
is_tuple = True
seq = list(seq)
else:
is_tuple = False

for i in self._missing_axes():
seq.insert(i, value)

if is_tuple:
seq = tuple(seq)

assert len(seq) == len(_AXES_PREFIX)
JoOkuma marked this conversation as resolved.
Show resolved Hide resolved

return seq

@property
@abstractmethod
def shape(self) -> tuple[int, int, int, int, int]:
raise NotImplementedError

@abstractmethod
def __getitem__(
self,
key: Union[int, slice, tuple[Union[int, slice], ...]],
) -> ArrayLike:
"""
Returned object must support the ``__array__`` interface,
so that ``np.asarray(...)`` will work.
"""
raise NotImplementedError

@property
def ndim(self) -> int:
return 5

@property
@abstractmethod
def dtype(self) -> np.dtype:
raise NotImplementedError

@property
@abstractmethod
def zyx_scale(self) -> tuple[float, float, float]:
"""Helper function for FOV spatial scale (micrometer)."""
raise NotImplementedError

@property
@abstractmethod
def t_scale(self) -> float:
"""Helper function for FOV time scale (seconds)."""
raise NotImplementedError

def __eq__(self, other: BaseFOV) -> bool:
if not isinstance(other, BaseFOV):
return False
return self.root.absolute() == other.root.absolute()


class BaseFOVMapping(Mapping):
@abstractmethod
def __enter__(self) -> BaseFOVMapping:
"""Open the underlying file and return self."""
raise NotImplementedError

@abstractmethod
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Close the files."""
raise NotImplementedError

@abstractmethod
def __contains__(self, position_key: str) -> bool:
"""Check if a position is present in the collection."""
raise NotImplementedError

@abstractmethod
def __len__(self) -> int:
raise NotImplementedError

@abstractmethod
def __getitem__(self, position_key: str) -> BaseFOV:
"""FOV key position to FOV object."""
raise NotImplementedError

@abstractmethod
def __iter__(self) -> Iterable[tuple[str, BaseFOV]]:
"""Iterates over pairs of keys and FOVs."""
raise NotImplementedError


class FOVDict(BaseFOVMapping):
"""
Basic implementation of a mapping of strings to BaseFOVs.
"""

def __init__(
self,
data_dict: Optional[dict[str, BaseFOV]] = None,
**kwargs,
) -> None:
super().__init__()
self._data = {}

if data_dict is not None:
for key, fov in data_dict.items():
self._safe_insert(key, fov)

for key, fov in kwargs.items():
self._safe_insert(key, fov)

def _safe_insert(self, key: str, value: BaseFOV) -> None:
"""Checks if types are correct and key is unique."""
if not isinstance(key, str):
raise TypeError(
f"{self.__class__.__name__} key must be str. "
f"Found {key} with type {type(key)}"
)

if not isinstance(value, BaseFOV):
raise TypeError(
f"{self.__class__.__name__} value must subclass BaseFOV. "
f"Found {key} with value type {type(value)}"
)

if key in self:
ziw-liu marked this conversation as resolved.
Show resolved Hide resolved
raise KeyError(f"{key} already exists.")

self._data[key] = value

def __contains__(self, position_key: str) -> bool:
"""Checks if position_key already exists."""
return position_key in self._data

def __len__(self) -> int:
return len(self._data)

def __getitem__(self, position_key: str) -> BaseFOV:
"""FOV key position to FOV object."""
return self._data[position_key]

def __iter__(self) -> Iterable[tuple[str, BaseFOV]]:
"""Iterates over pairs of keys and FOVs."""
return self._data.items()

def __enter__(self) -> FOVDict:
"""Open the underlying file and return self."""
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Close the files."""
return True
ziw-liu marked this conversation as resolved.
Show resolved Hide resolved
93 changes: 93 additions & 0 deletions tests/fov/test_fov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from pathlib import Path
from typing import Any

import pytest

from iohub.fov import BaseFOV, FOVDict


class FOV(BaseFOV):
def __init__(self, axes: list[str]) -> None:
super().__init__()
self._axes = axes

@property
def root(self) -> Path:
return Path()

@property
def axes_names(self) -> list[str]:
return self._axes

@property
def channel_names(self) -> list[str]:
return []

def __getitem__(self, key: Any) -> Any:
pass

@property
def dtype(self) -> Any:
pass

@property
def shape(self) -> Any:
pass

@property
def zyx_scale(self) -> tuple[float, float, float]:
return (1.0,) * 3

@property
def t_scale(self) -> float:
raise 1.0


@pytest.mark.parametrize(
"axes,missing",
[
(["T", "C", "Z", "Y", "X"], []),
(["time", "z", "y", "x"], [1]),
(["channels", "Y", "X"], [0, 2]),
],
)
def test_missing_axes(axes: list[str], missing: list[int]) -> None:
fov = FOV(axes)
assert fov._missing_axes() == missing

shape = (10,) * len(axes)
padded_shape = fov._pad_missing_axes(shape, 1)
assert len(padded_shape) == 5

for i, s in enumerate(padded_shape):
if i in missing:
assert s == 1
else:
assert s == 10


def test_fov_dict() -> None:

good_collection = FOVDict(
{
"488": FOV(["y", "x"]),
"561": FOV(["y", "x"]),
},
mask=FOV(["c", "x", "y"]),
)

assert len(good_collection) == 3
assert "488" in good_collection
assert good_collection["mask"] is not None

with pytest.raises(TypeError):
del good_collection["561"]

with pytest.raises(TypeError):
good_collection["segmentation"] = FOV(["x", "y"])

with pytest.raises(TypeError):
FOVDict({488: FOV(["y", "x"])})

with pytest.raises(TypeError):
FOVDict(mask=[1, 2, 3])