diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/README.md b/src/careamics/lightning/callbacks/prediction_writer_callback/README.md new file mode 100644 index 00000000..7f39794a --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/README.md @@ -0,0 +1,119 @@ +# `PredictionWriterCallback` + +## Class diagram for `PredictionWriteCallback` related classes +```mermaid +classDiagram + PredictionWriterCallback*--WriteStrategy : composition + WriteStrategy<--WriteTiles : implements + WriteStrategy<--WriteImage : implements + WriteStrategy<--WriteZarrTiles : implements + WriteTiles*--TileCache : composition + WriteTiles*--SampleCache : composition + WriteImage*--SampleCache : composition + + class PredictionWriterCallback + PredictionWriterCallback : +bool writing_predictions + PredictionWriterCallback : +WriteStrategy write_strategy + PredictionWriterCallback : +write_on_batch_end(...) + PredictionWriterCallback : +on_predict_epoch_end(...) + + class WriteStrategy + <> WriteStrategy + WriteStrategy : +write_batch(...)* + WriteStrategy : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file)* + WriteStrategy : +reset()* + + class WriteTiles + WriteTiles : +WriteFunc write_func + + WriteTiles : +TileCache tile_cache + WriteTiles : +SampleCache sample_cache + WriteTiles : +write_batch(...) + WriteTiles : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) + WriteTiles : +reset() + + class WriteImage + WriteImage : +WriteFunc write_func + WriteImage : +SampleCache sample_cache + WriteImage : +write_batch(...) + WriteImage : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) + WriteImage : +reset() + + class WriteZarrTiles + WriteZarrTiles : +write_batch(...) NotImplemented + WriteZarrTiles : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) NotImplemented + WriteZarrTiles : +reset() NotImplemented + + class TileCache + TileCache : +list~NDArray~ array_cache + TileCache : +list~TileInformation~ tile_info_cache + TileCache : +add(NDArray, list~TileInformation~ item) + TileCache : +has_last_tile() bool + TileCache : +pop_image_tiles() NDArray, list~TileInformation~ + TileCache : +reset() + + class SampleCache + SampleCache : +list~int~ n_samples_per_file + SampleCache : +Iterator n_samples_iter + SampleCache : +int n_samples + SampleCache : +sample_cache list~NDArray~ + SampleCache : +add(NDArray item) + SampleCache : +has_all_file_samples() bool + SampleCache : +pop_file_samples() list~NDArray~ + SampleCache : +reset() +``` + +## Sequence diagram for writing tiles + +```mermaid +sequenceDiagram + participant Trainer + participant PredictionWriterCallback + participant WriteTiles + participant TileCache + + Trainer->>PredictionWriterCallback: write_on_batch_end(batch, ...) + activate PredictionWriterCallback + activate PredictionWriterCallback + PredictionWriterCallback->>WriteTiles: write_batch(batch, ...) 
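+    %% Illustrative note: the flow below assumes tiles arrive in prediction order,
+    %% so WriteTiles only writes to disk once a sample's last tile has been cached.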
+ activate WriteTiles + activate WriteTiles + activate WriteTiles + WriteTiles->>TileCache: add(batch) + activate TileCache + WriteTiles ->> TileCache: has_last_tile() + TileCache -->> WriteTiles: True/False + deactivate TileCache + alt If does not have last tile + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + PredictionWriterCallback -->> Trainer: return + deactivate PredictionWriterCallback + else If has last tile + WriteTiles ->> TileCache: pop_image_tiles() + activate TileCache + TileCache -->> WriteTiles: tiles, tile_infos + deactivate TileCache + Note right of WriteTiles: Tiles are stitched to create prediction_image. + WriteTiles ->> SampleCache: add(prediction_image) + activate SampleCache + WriteTiles ->> SampleCache: has_all_file_samples() + SampleCache -->> WriteTiles: True/False + deactivate SampleCache + alt If does not have all samples + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + else If has all samples + WriteTiles ->> SampleCache: pop_file_samples() + activate SampleCache + SampleCache -->> WriteTiles: samples + deactivate SampleCache + Note right of WriteTiles: Concatenated samples are written to disk. + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + end + PredictionWriterCallback -->> Trainer: return + deactivate PredictionWriterCallback + end + +``` \ No newline at end of file diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py b/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py index ffbc1209..331b869b 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py @@ -5,14 +5,14 @@ "create_write_strategy", "WriteStrategy", "WriteImage", - "CacheTiles", + "WriteTiles", "WriteTilesZarr", "select_write_extension", "select_write_func", ] from .prediction_writer_callback import PredictionWriterCallback -from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr +from .write_strategy import WriteImage, WriteStrategy, WriteTiles, WriteTilesZarr from .write_strategy_factory import ( create_write_strategy, select_write_extension, diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py deleted file mode 100644 index 9da2ba6b..00000000 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Module containing file path utilities for `WriteStrategy` to use.""" - -from pathlib import Path -from typing import Union - -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset - - -# TODO: move to datasets package ? -def get_sample_file_path( - dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int -) -> Path: - """ - Get the file path for a particular sample. - - Parameters - ---------- - dataset : IterableTiledPredDataset or IterablePredDataset - Dataset. - sample_id : int - Sample ID, the index of the file in the dataset `dataset`. - - Returns - ------- - Path - The file path corresponding to the sample with the ID `sample_id`. - """ - return dataset.data_files[sample_id] - - -def create_write_file_path( - dirpath: Path, file_path: Path, write_extension: str -) -> Path: - """ - Create the file name for the output file. 
- - Takes the original file path, changes the directory to `dirpath` and changes - the extension to `write_extension`. - - Parameters - ---------- - dirpath : pathlib.Path - The output directory to write file to. - file_path : pathlib.Path - The original file path. - write_extension : str - The extension that output files should have. - - Returns - ------- - Path - The output file path. - """ - file_name = Path(file_path.stem).with_suffix(write_extension) - file_path = dirpath / file_name - return file_path diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py index 8a73c12e..030d6f61 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py @@ -7,7 +7,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import BasePredictionWriter -from torch.utils.data import DataLoader from careamics.dataset import ( IterablePredDataset, @@ -127,7 +126,7 @@ def from_write_func_params( ) return cls(write_strategy=write_strategy, dirpath=dirpath) - def _init_dirpath(self, dirpath): + def _init_dirpath(self, dirpath: Union[Path, str]): """ Initialize directory path. Should only be called from `__init__`. @@ -161,6 +160,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non Stage of training e.g. 'predict', 'fit', 'validate'. """ super().setup(trainer, pl_module, stage) + # TODO: move directory creation to hook on_predict_epoch_start if stage == "predict": # make prediction output directory logger.info("Making prediction output directory.") @@ -202,25 +202,6 @@ def write_on_batch_end( if not self.writing_predictions: return - dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders - dataloader: DataLoader = ( - dataloaders[dataloader_idx] - if isinstance(dataloaders, list) - else dataloaders - ) - dataset: ValidPredDatasets = dataloader.dataset - if not ( - isinstance(dataset, IterablePredDataset) - or isinstance(dataset, IterableTiledPredDataset) - ): - # Note: Error will be raised before here from the source type - # This is for extra redundancy of errors. - raise TypeError( - "Prediction dataset has to be `IterableTiledPredDataset` or " - "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because " - "filenames are taken from the original file." - ) - self.write_strategy.write_batch( trainer=trainer, pl_module=pl_module, @@ -231,3 +212,21 @@ def write_on_batch_end( dataloader_idx=dataloader_idx, dirpath=self.dirpath, ) + + def on_predict_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """ + Lightning hook called at the end of prediction. + + Resets write_strategy. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer. + pl_module : LightningModule + PyTorch Lightning module. + """ + # reset write_strategy to prevent bugs if trainer.predict is called twice. 
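+        # (note: each strategy's reset() clears its cached filenames/iterators and any
+        # tile or sample caches, so no stale state is carried over to the next run)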
+ self.write_strategy.reset() diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py deleted file mode 100644 index 9b298da1..00000000 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ /dev/null @@ -1,398 +0,0 @@ -"""Module containing different strategies for writing predictions.""" - -from pathlib import Path -from typing import Any, Optional, Protocol, Sequence, Union - -import numpy as np -from numpy.typing import NDArray -from pytorch_lightning import LightningModule, Trainer -from torch.utils.data import DataLoader - -from careamics.config.tile_information import TileInformation -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset -from careamics.file_io import WriteFunc -from careamics.prediction_utils import stitch_prediction_single - -from .file_path_utils import create_write_file_path, get_sample_file_path - - -class WriteStrategy(Protocol): - """Protocol for write strategy classes.""" - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: Any, # TODO: change to expected type - batch_indices: Optional[Sequence[int]], - batch: Any, # TODO: change to expected type - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - WriteStrategy subclasses must contain this function to write a batch. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - """ - - -class CacheTiles(WriteStrategy): - """ - A write strategy that will cache tiles. - - Tiles are cached until a whole image is predicted on. Then the stitched - prediction is saved. - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - - Attributes - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - tile_cache : list of numpy.ndarray - Tiles cached for stitching prediction. - tile_info_cache : list of TileInformation - Cached tile information for stitching prediction. - """ - - def __init__( - self, - write_func: WriteFunc, - write_extension: str, - write_func_kwargs: dict[str, Any], - ) -> None: - """ - A write strategy that will cache tiles. - - Tiles are cached until a whole image is predicted on. Then the stitched - prediction is saved. - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. 
- """ - super().__init__() - - self.write_func: WriteFunc = write_func - self.write_extension: str = write_extension - self.write_func_kwargs: dict[str, Any] = write_func_kwargs - - # where tiles will be cached until a whole image has been predicted - self.tile_cache: list[NDArray] = [] - self.tile_info_cache: list[TileInformation] = [] - - @property - def last_tiles(self) -> list[bool]: - """ - List of bool to determine whether each tile in the cache is the last tile. - - Returns - ------- - list of bool - Whether each tile in the tile cache is the last tile. - """ - return [tile_info.last_tile for tile_info in self.tile_info_cache] - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: tuple[NDArray, list[TileInformation]], - batch_indices: Optional[Sequence[int]], - batch: tuple[NDArray, list[TileInformation]], - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - Cache tiles until the last tile is predicted; save the stitched prediction. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - """ - dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders - dataloader: DataLoader = ( - dataloaders[dataloader_idx] - if isinstance(dataloaders, list) - else dataloaders - ) - dataset: IterableTiledPredDataset = dataloader.dataset - if not isinstance(dataset, IterableTiledPredDataset): - raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.") - - # cache tiles (batches are split into single samples) - self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0])) - self.tile_info_cache.extend(prediction[1]) - - # save stitched prediction - if self._has_last_tile(): - - # get image tiles and remove them from the cache - tiles, tile_infos = self._get_image_tiles() - self._clear_cache() - - # stitch prediction - prediction_image = stitch_prediction_single( - tiles=tiles, tile_infos=tile_infos - ) - - # write prediction - sample_id = tile_infos[0].sample_id # need this to select correct file name - input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id) - file_path = create_write_file_path( - dirpath=dirpath, - file_path=input_file_path, - write_extension=self.write_extension, - ) - self.write_func( - file_path=file_path, img=prediction_image[0], **self.write_func_kwargs - ) - - def _has_last_tile(self) -> bool: - """ - Whether a last tile is contained in the cached tiles. - - Returns - ------- - bool - Whether a last tile is contained in the cached tiles. - """ - return any(self.last_tiles) - - def _clear_cache(self) -> None: - """Remove the tiles in the cache up to the first last tile.""" - index = self._last_tile_index() - self.tile_cache = self.tile_cache[index + 1 :] - self.tile_info_cache = self.tile_info_cache[index + 1 :] - - def _last_tile_index(self) -> int: - """ - Find the index of the last tile in the tile cache. - - Returns - ------- - int - Index of last tile. - - Raises - ------ - ValueError - If there is no last tile in the tile cache. 
- """ - last_tiles = self.last_tiles - if not any(last_tiles): - raise ValueError("No last tile in the tile cache.") - index = np.where(last_tiles)[0][0] - return index - - def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: - """ - Get the tiles corresponding to a single image. - - Returns - ------- - tuple of (list of numpy.ndarray, list of TileInformation) - Tiles and tile information to stitch together a full image. - """ - index = self._last_tile_index() - tiles = self.tile_cache[: index + 1] - tile_infos = self.tile_info_cache[: index + 1] - return tiles, tile_infos - - -class WriteTilesZarr(WriteStrategy): - """Strategy to write tiles to Zarr file.""" - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: Any, - batch_indices: Optional[Sequence[int]], - batch: Any, - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - Write tiles to zarr file. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - - Raises - ------ - NotImplementedError - """ - raise NotImplementedError - - -class WriteImage(WriteStrategy): - """ - A strategy for writing image predictions (i.e. un-tiled predictions). - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - - Attributes - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - """ - - def __init__( - self, - write_func: WriteFunc, - write_extension: str, - write_func_kwargs: dict[str, Any], - ) -> None: - """ - A strategy for writing image predictions (i.e. un-tiled predictions). - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - """ - super().__init__() - - self.write_func: WriteFunc = write_func - self.write_extension: str = write_extension - self.write_func_kwargs: dict[str, Any] = write_func_kwargs - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: NDArray, - batch_indices: Optional[Sequence[int]], - batch: NDArray, - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - Save full images. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - - Raises - ------ - TypeError - If trainer prediction dataset is not `IterablePredDataset`. 
- """ - dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders - dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls - ds: IterablePredDataset = dl.dataset - if not isinstance(ds, IterablePredDataset): - raise TypeError("Prediction dataset is not `IterablePredDataset`.") - - for i in range(prediction.shape[0]): - prediction_image = prediction[0] - sample_id = batch_idx * dl.batch_size + i - input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id) - file_path = create_write_file_path( - dirpath=dirpath, - file_path=input_file_path, - write_extension=self.write_extension, - ) - self.write_func( - file_path=file_path, img=prediction_image, **self.write_func_kwargs - ) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py new file mode 100644 index 00000000..43feb7a9 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py @@ -0,0 +1,13 @@ +"""Write strategies for the prediciton writer callback.""" + +__all__ = [ + "WriteStrategy", + "WriteTiles", + "WriteImage", + "WriteTilesZarr", +] + +from .protocol import WriteStrategy +from .write_image import WriteImage +from .write_tiles import WriteTiles +from .write_tiles_zarr import WriteTilesZarr diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py new file mode 100644 index 00000000..43b973f4 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py @@ -0,0 +1,175 @@ +"""Utility classes, for caching data, used in the write strategies.""" + +from typing import Optional + +import numpy as np +from numpy.typing import NDArray + +from careamics.config.tile_information import TileInformation + + +class TileCache: + """ + Logic to cache tiles, then pop tiles when tiles from a full image have been stored. + + Attributes + ---------- + array_cache : list[numpy.ndarray] + The tile arrays with the dimensions SC(Z)YX. + tile_info_cache : list[TileInformation] + The corresponding tile information for each tile. + """ + + def __init__(self): + """Logic to cache tiles, and pop tiles when a full set of have been stored.""" + self.array_cache: list[NDArray] = [] + self.tile_info_cache: list[TileInformation] = [] + + def add(self, item: tuple[NDArray, list[TileInformation]]): + """ + Add another batch to the cache. + + Parameters + ---------- + item : tuple of (numpy.ndarray, list[TileInformation]) + Tuple where the first element is a concatenated set of tiles, and the + second element is a list of each corresponding `TileInformation`. + """ + self.array_cache.extend(np.split(item[0], item[0].shape[0])) + self.tile_info_cache.extend(item[1]) + + def has_last_tile(self) -> bool: + """ + Determine whether the current cache contains the last tile of a sample. + + Returns + ------- + bool + Whether the last tile is contained in the cache. + """ + return any(tile_info.last_tile for tile_info in self.tile_info_cache) + + def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: + """ + Pop the tiles that will create a full image from the cache. + + I.e. The tiles belonging to a full image will be removed from the cache but + returned by this function call. 
+ + Returns + ------- + list of numpy.ndarray + A list of tiles with the dimensions SC(Z)YX. + list of TileInformation + A list of corresponding tile information. + + Raises + ------ + ValueError + If the tiles belonging to a full image are not contained in the cache, i.e. + if the cache does not contain the last tile of an image. + """ + is_last_tile = [tile_info.last_tile for tile_info in self.tile_info_cache] + if not any(is_last_tile): + raise ValueError("No last tile in cache.") + + index = np.where(is_last_tile)[0][0] + # get image tiles + tiles = self.array_cache[: index + 1] + tile_infos = self.tile_info_cache[: index + 1] + # remove image tiles from list + self.array_cache = self.array_cache[index + 1 :] + self.tile_info_cache = self.tile_info_cache[index + 1 :] + + return tiles, tile_infos + + def reset(self): + """Reset the cache. Remove all tiles and tile information from the cache.""" + self.array_cache = [] + self.tile_info_cache = [] + + +class SampleCache: + """ + Logic to cache samples until they can be concatenated together to create a file. + + Parameters + ---------- + n_samples_per_file : list[int] + A list that contains the number of samples that will be contained in each + file. There should be `n` elements in the list for `n` files intended to + be created. + """ + + def __init__(self, n_samples_per_file: list[int]): + """ + Logic to cache samples until they can be concatenated together to create a file. + + Parameters + ---------- + n_samples_per_file : list[int] + A list that contains the number of samples that will be contained in each + file. There should be `n` elements in the list for `n` files intended to + be created. + """ + self.n_samples_per_file: list[int] = n_samples_per_file + self.n_samples_iter = iter(self.n_samples_per_file) + # n_samples will be set to None once iterated through each element + self.n_samples: Optional[int] = next(self.n_samples_iter) + self.sample_cache: list[NDArray] = [] + + def add(self, item: NDArray): + """ + Add a sample to the cache. + + Parameters + ---------- + item : numpy.ndarray + A set of predicted samples. + """ + self.sample_cache.extend(np.split(item, item.shape[0])) + + def has_all_file_samples(self) -> bool: + """ + Determine if all the samples for the current file are contained in the cache. + + Returns + ------- + bool + Whether all the samples are contained in the cache. + """ + if self.n_samples is None: + raise ValueError( + "Number of samples for current file is unknown. Reached the end of the " + "given list of samples per file, or a list has not been given." + ) + return len(self.sample_cache) >= self.n_samples + + def pop_file_samples(self) -> list[NDArray]: + """ + Pop from the cache the samples required for the current file to be created. + + Returns + ------- + list of numpy.ndarray + A list of samples to concatenate together into a file. + """ + if not self.has_all_file_samples(): + raise ValueError( + "Do not have all the samples belonging to the current file." + ) + + samples = self.sample_cache[: self.n_samples] + self.sample_cache = self.sample_cache[self.n_samples :] + + try: + self.n_samples = next(self.n_samples_iter) + except StopIteration: + self.n_samples = None + + return samples + + def reset(self): + """Reset the cache. 
Remove all the samples from the cache.""" + self.n_samples_iter = iter(self.n_samples_per_file) + self.sample_cache: list[NDArray] = [] diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py new file mode 100644 index 00000000..ded4376e --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py @@ -0,0 +1,74 @@ +"""Module containing the protocol that defines the WriteStrategy interface.""" + +from pathlib import Path +from typing import Any, Optional, Protocol, Sequence + +from pytorch_lightning import LightningModule, Trainer + + +class WriteStrategy(Protocol): + """ + Protocol for write strategy classes. + + A `WriteStrategy` is an object that will be an attribute in the + `PredictionWriterCallback`; it will determine how predictions will be saved. + `WriteStrategy`s must be interchangeable so they must follow the interface set out + in this `Protocol` class. + """ + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: Any, # TODO: change to expected type + batch_indices: Optional[Sequence[int]], + batch: Any, # TODO: change to expected type + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + WriteStrategy subclasses must contain this function to write a batch. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + """ + + def set_file_data( + self, write_filenames: list[str], n_samples_per_file: list[int] + ) -> None: + """ + Set file information after the `WriteStrategy` has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ + + def reset(self) -> None: + """ + Reset internal state (attributes) of a `WriteStrategy` instance. + + This is to unexpected behaviour if a `WriteStrategy` instance is used twice. + """ diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py new file mode 100644 index 00000000..bdd44834 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -0,0 +1,192 @@ +"""Module containing write strategy for when batches contain full images.""" + +from pathlib import Path +from typing import Any, Iterator, Optional, Sequence, Union + +import numpy as np +from numpy.typing import NDArray +from pytorch_lightning import LightningModule, Trainer +from torch.utils.data import DataLoader + +from careamics.dataset import IterablePredDataset +from careamics.file_io import WriteFunc + +from .caches import SampleCache + + +class WriteImage: + """ + A strategy for writing image predictions (i.e. not tiled predictions). 
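+
+    Predicted samples are accumulated in a ``SampleCache``; once all samples belonging
+    to the current output file have been cached they are concatenated and written to
+    disk with ``write_func``.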
+ + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list of int + The number of samples in each file, (controls which samples will be + grouped together in each file). + + Attributes + ---------- + write_func : WriteFunc + Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + + """ + + def __init__( + self, + write_func: WriteFunc, + write_extension: str, + write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], + n_samples_per_file: Optional[list[int]], + ) -> None: + """ + A strategy for writing image predictions (i.e. un-tiled predictions). + + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list of int + The number of samples in each file, (controls which samples will be + grouped together in each file). + """ + super().__init__() + + self.write_func: WriteFunc = write_func + self.write_extension: str = write_extension + self.write_func_kwargs: dict[str, Any] = write_func_kwargs + + # where samples are stored until a whole file has been predicted + self.sample_cache: Optional[SampleCache] + + self._write_filenames: Optional[list[str]] = write_filenames + self._filename_iter: Optional[Iterator[str]] = ( + iter(write_filenames) if write_filenames is not None else None + ) + + if not ((write_filenames is None) or (n_samples_per_file is None)): + # also creates sample cache + self.set_file_data(write_filenames, n_samples_per_file) + else: + self.sample_cache = None + + def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + """ + Set file information after the `WriteImage` strategy has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ + if len(write_filenames) != len(n_samples_per_file): + raise ValueError( + "List of filename and list of number of samples per file are not of " + "equal length." + ) + self._write_filenames = write_filenames + self._filename_iter = iter(write_filenames) + self.sample_cache = SampleCache(n_samples_per_file=n_samples_per_file) + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: NDArray, + batch_indices: Optional[Sequence[int]], + batch: NDArray, + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + Save full images. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. 
+ prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + + Raises + ------ + TypeError + If trainer prediction dataset is not `IterablePredDataset`. + ValueError + If `write_filenames` attribute is `None`. + """ + if self.sample_cache is None: + raise ValueError( + "`SampleCache` has not been created. Call `set_file_data` before " + "calling `write_batch`." + ) + # assert for mypy + assert ( + self._filename_iter is not None + ), "`filename_iter` is `None` should be set by `set_file_data`." + + dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders + dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls + ds: IterablePredDataset = dl.dataset + if not isinstance(ds, IterablePredDataset): + # TODO: change to warning + raise TypeError("Prediction dataset is not `IterablePredDataset`.") + + self.sample_cache.add(prediction) + # early return + if not self.sample_cache.has_all_file_samples(): + return + + # if has all samples in file + samples = self.sample_cache.pop_file_samples() + + # combine + data = np.concatenate(samples) + + # write prediction + file_name = next(self._filename_iter) + file_path = (dirpath / file_name).with_suffix(self.write_extension) + self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) + + def reset(self) -> None: + """Reset internal attributes.""" + self._write_filenames = None + self._filename_iter = None + self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py new file mode 100644 index 00000000..ade75cd0 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -0,0 +1,223 @@ +"""Module containing the "cache tiles" write strategy.""" + +from pathlib import Path +from typing import Any, Iterator, Optional, Sequence, Union + +import numpy as np +from numpy.typing import NDArray +from pytorch_lightning import LightningModule, Trainer +from torch.utils.data import DataLoader + +from careamics.config.tile_information import TileInformation +from careamics.dataset import IterableTiledPredDataset +from careamics.file_io import WriteFunc +from careamics.prediction_utils import stitch_prediction_single + +from .caches import SampleCache, TileCache + + +class WriteTiles: + """ + A write strategy that will cache tiles. + + Tiles are cached until a whole image is predicted on. Then the stitched + prediction is saved. + + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + + Attributes + ---------- + write_func : WriteFunc + Function used to save predictions. 
+ write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + tile_cache : list of numpy.ndarray + Tiles cached for stitching prediction. + tile_info_cache : list of TileInformation + Cached tile information for stitching prediction. + """ + + def __init__( + self, + write_func: WriteFunc, + write_extension: str, + write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], + n_samples_per_file: Optional[list[int]], + ) -> None: + """ + A write strategy that will cache tiles. + + Tiles are cached until a whole image is predicted on. Then the stitched + prediction is saved. + + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ + super().__init__() + + self.write_func: WriteFunc = write_func + self.write_extension: str = write_extension + self.write_func_kwargs: dict[str, Any] = write_func_kwargs + + # where tiles will be cached until a whole image has been predicted + self.tile_cache = TileCache() + # where samples are stored until a whole file has been predicted + self.sample_cache: Optional[SampleCache] + + self._write_filenames: Optional[list[str]] = write_filenames + self._filename_iter: Optional[Iterator[str]] = ( + iter(write_filenames) if write_filenames is not None else None + ) + if write_filenames is not None and n_samples_per_file is not None: + self.set_file_data(write_filenames, n_samples_per_file) + else: + self.sample_cache = None + + def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + """ + Set file information after the `WriteTiles` has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ + if len(write_filenames) != len(n_samples_per_file): + raise ValueError( + "List of filename and list of number of samples per file are not of " + "equal length." + ) + self._write_filenames = write_filenames + self._filename_iter = iter(write_filenames) + self.sample_cache = SampleCache(n_samples_per_file) + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: tuple[NDArray, list[TileInformation]], + batch_indices: Optional[Sequence[int]], + batch: tuple[NDArray, list[TileInformation]], + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + Cache tiles until the last tile is predicted; save the stitched prediction. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. 
+ batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + + Raises + ------ + ValueError + If `write_filenames` attribute is `None`. + """ + if self.sample_cache is None: + raise ValueError( + "`SampleCache` has not been created. Call `set_file_data` before " + "calling `write_batch`." + ) + # assert for mypy + assert ( + self._filename_iter is not None + ), "`filename_iter` is `None` should be set by `set_file_data`." + + # TODO: move dataset type check somewhere else + dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders + dataloader: DataLoader = ( + dataloaders[dataloader_idx] + if isinstance(dataloaders, list) + else dataloaders + ) + dataset: IterableTiledPredDataset = dataloader.dataset + if not isinstance(dataset, IterableTiledPredDataset): + raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.") + + self.tile_cache.add(prediction) + + # early return + if not self.tile_cache.has_last_tile(): + return + + # if has last tile + tiles, tile_infos = self.tile_cache.pop_image_tiles() + + # stitch prediction + prediction_image = stitch_prediction_single(tiles=tiles, tile_infos=tile_infos) + + self.sample_cache.add(prediction_image) + + # early return + if not self.sample_cache.has_all_file_samples(): + return + + # if has all samples in file + samples = self.sample_cache.pop_file_samples() + + # combine + data = np.concatenate(samples) + + # write prediction + file_name = next(self._filename_iter) + file_path = (dirpath / file_name).with_suffix(self.write_extension) + self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) + + def reset(self) -> None: + """Reset the internal attributes.""" + self._write_filenames = None + self._filename_iter = None + if self.tile_cache is not None: + self.tile_cache.reset() + if self.sample_cache is not None: + self.sample_cache.reset() diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py new file mode 100644 index 00000000..d3a6dcc3 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py @@ -0,0 +1,53 @@ +"""Module containing a write strategy for writing tiles directly to zarr datasets.""" + +from pathlib import Path +from typing import Any, Optional, Sequence + +from pytorch_lightning import LightningModule, Trainer + + +class WriteTilesZarr: + """Strategy to write tiles to Zarr file.""" + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: Any, + batch_indices: Optional[Sequence[int]], + batch: Any, + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + Write tiles to zarr file. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. 
+ + Raises + ------ + NotImplementedError + """ + raise NotImplementedError + + def reset(self) -> None: + """Reset internal attributes.""" + raise NotImplementedError diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py index a9eda4a4..771bc7e3 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py @@ -5,7 +5,7 @@ from careamics.config.support import SupportedData from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func -from .write_strategy import CacheTiles, WriteImage, WriteStrategy +from .write_strategy import WriteImage, WriteStrategy, WriteTiles def create_write_strategy( @@ -14,6 +14,8 @@ def create_write_strategy( write_func: Optional[WriteFunc] = None, write_extension: Optional[str] = None, write_func_kwargs: Optional[dict[str, Any]] = None, + write_filenames: Optional[list[str]] = None, + n_samples_per_file: Optional[list[int]] = None, ) -> WriteStrategy: """ Create a write strategy from convenient parameters. @@ -32,6 +34,12 @@ def create_write_strategy( `write_type` an extension to save the data with must be passed. write_func_kwargs : dict of {str: any}, optional Additional keyword arguments to be passed to the save function. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). Returns ------- @@ -61,16 +69,20 @@ def create_write_strategy( ) write_strategy = WriteImage( write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) else: - # select CacheTiles or WriteTilesZarr (when implemented) + # select WriteTiles or WriteTilesZarr (when implemented) write_strategy = _create_tiled_write_strategy( write_type=write_type, write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) return write_strategy @@ -81,11 +93,13 @@ def _create_tiled_write_strategy( write_func: Optional[WriteFunc], write_extension: Optional[str], write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], + n_samples_per_file: Optional[list[int]], ) -> WriteStrategy: """ Create a tiled write strategy. - Either `CacheTiles` for caching tiles until a whole image is predicted or + Either `WriteTiles` for caching tiles until a whole image is predicted or `WriteTilesZarr` for writing tiles directly to disk. Parameters @@ -100,6 +114,12 @@ def _create_tiled_write_strategy( `write_type` an extension to save the data with must be passed. write_func_kwargs : dict of {str: any} Additional keyword arguments to be passed to the save function. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). 
Returns ------- @@ -122,10 +142,12 @@ def _create_tiled_write_strategy( write_extension = select_write_extension( write_type=write_type, write_extension=write_extension ) - return CacheTiles( + return WriteTiles( write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) diff --git a/tests/lightning/callbacks/prediction_writer_callback/conftest.py b/tests/lightning/callbacks/prediction_writer_callback/conftest.py new file mode 100644 index 00000000..798cb11b --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/conftest.py @@ -0,0 +1,95 @@ +from typing import Callable +from unittest.mock import Mock + +import numpy as np +import pytest +from numpy.typing import NDArray + +from careamics.config.tile_information import TileInformation +from careamics.dataset.tiling import extract_tiles +from careamics.file_io import WriteFunc +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + + +@pytest.fixture +def write_func(): + """Mock `WriteFunc`.""" + return Mock(spec=WriteFunc) + + +@pytest.fixture +def patch_tile_cache() -> ( + Callable[[caches.TileCache, list[NDArray], list[TileInformation]], None] +): + def inner( + tile_cache: caches.TileCache, + tiles: list[NDArray], + tile_infos: list[TileInformation], + ) -> None: + """ + Patch simulated tile cache into `strategy`. + + Parameters + ---------- + tile_cache : TileCache + Tile cache used in `WriteTiles` write strategy class. + tiles : list of NDArray + Tiles to patch into `strategy.tile_cache`. + tile_infos : list of TileInformation + Corresponding tile information to patch into `strategy.tile_info_cache`. + """ + tile_cache.add((np.concatenate(tiles), tile_infos)) + + return inner + + +@pytest.fixture +def create_tiles() -> Callable[[int], tuple[list[NDArray], list[TileInformation]]]: + def inner(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: + """ + Create a set of tiles from `n_samples`. + + To create the tiles the following parameters, `tile_size=(4, 4)` and + `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this + results in 9 tiles per sample. + + Parameters + ---------- + n_samples : int + Number of samples to simulate the tiles from. + + Returns + ------- + tuple of (list of NDArray), list of TileInformation)) + Tuple where first element is the list of tiles and second element is a list + of corresponding tile information. 
+ """ + + input_shape = (n_samples, 1, 8, 8) + tile_size = (4, 4) + tile_overlap = (2, 2) + + arr = np.arange(np.prod(input_shape)).reshape(input_shape) + + all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) + tiles = [output[0] for output in all_tiles] + tile_infos = [output[1] for output in all_tiles] + + return tiles, tile_infos + + return inner + + +@pytest.fixture +def samples() -> tuple[NDArray, list[int]]: + n_samples_per_file = [3, 1, 2] + shapes = [(16, 16), (8, 8), (12, 12)] + sample_set = [] + for n_samples, spatial_shape in zip(n_samples_per_file, shapes): + shape = (n_samples, 1, *spatial_shape) + sample = np.arange(np.prod(shape)).reshape(shape) + sample_set.extend(np.split(sample, n_samples)) + + return sample_set, n_samples_per_file diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py deleted file mode 100644 index da58b4bc..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ /dev/null @@ -1,326 +0,0 @@ -"""Test `CacheTiles` class.""" - -from pathlib import Path -from unittest.mock import DEFAULT, Mock, patch - -import numpy as np -import pytest -from numpy.typing import NDArray -from pytorch_lightning import LightningModule, Trainer -from torch.utils.data import DataLoader - -from careamics.config.tile_information import TileInformation -from careamics.dataset import IterableTiledPredDataset -from careamics.dataset.tiling import extract_tiles -from careamics.file_io import WriteFunc -from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( - CacheTiles, -) - - -def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: - """ - Create a set of tiles from `n_samples`. - - To create the tiles the following parameters, `tile_size=(4, 4)` and - `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this - results in 9 tiles per sample. - - Parameters - ---------- - n_samples : int - Number of samples to simulate the tiles from. - - Returns - ------- - tuple of (list of NDArray), list of TileInformation)) - Tuple where first element is the list of tiles and second element is a list - of corresponding tile information. - """ - - input_shape = (n_samples, 1, 8, 8) - tile_size = (4, 4) - tile_overlap = (2, 2) - - arr = np.arange(np.prod(input_shape)).reshape(input_shape) - - all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) - tiles = [output[0] for output in all_tiles] - tile_infos = [output[1] for output in all_tiles] - - return tiles, tile_infos - - -def patch_tile_cache( - strategy: CacheTiles, tiles: list[NDArray], tile_infos: list[TileInformation] -) -> None: - """ - Patch simulated tile cache into `strategy`. - - Parameters - ---------- - strategy : CacheTiles - Write strategy `CacheTiles`. - tiles : list of NDArray - Tiles to patch into `strategy.tile_cache`. - tile_infos : list of TileInformation - Corresponding tile information to patch into `strategy.tile_info_cache`. - """ - strategy.tile_cache = tiles - strategy.tile_info_cache = tile_infos - - -@pytest.fixture -def write_func(): - """Mock `WriteFunc`.""" - return Mock(spec=WriteFunc) - - -@pytest.fixture -def cache_tiles_strategy(write_func) -> CacheTiles: - """ - Initialized `CacheTiles` class. - - Parameters - ---------- - write_func : WriteFunc - Write function. (Comes from fixture). 
- - Returns - ------- - CacheTiles - Initialized `CacheTiles` class. - """ - write_extension = ".ext" - write_func_kwargs = {} - return CacheTiles( - write_func=write_func, - write_extension=write_extension, - write_func_kwargs=write_func_kwargs, - ) - - -def test_cache_tiles_init(write_func, cache_tiles_strategy): - """ - Test `CacheTiles` initializes as expected. - """ - assert cache_tiles_strategy.write_func is write_func - assert cache_tiles_strategy.write_extension == ".ext" - assert cache_tiles_strategy.write_func_kwargs == {} - assert cache_tiles_strategy.tile_cache == [] - assert cache_tiles_strategy.tile_info_cache == [] - - -def test_last_tiles(cache_tiles_strategy): - """Test `CacheTiles.last_tile` property.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - - last_tile = [False, False, False, False, False, False, False, False, True] - assert cache_tiles_strategy.last_tiles == last_tile - - -def test_write_batch_no_last_tile(cache_tiles_strategy): - """ - Test `CacheTiles.write_batch` when there is no last tile added to the cache. - - Expected behaviour is that the batch is added to the cache. - """ - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - - # simulate adding a batch that will not contain the last tile - n_tiles = 4 - batch_size = 2 - patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) - next_batch = ( - np.concatenate(tiles[n_tiles : n_tiles + batch_size]), - tile_infos[n_tiles : n_tiles + batch_size], - ) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterableTiledPredDataset) - dataloader_idx = 0 - trainer.predict_dataloaders = [Mock(spec=DataLoader)] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - cache_tiles_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=next_batch, - batch_indices=Mock(), - batch=next_batch, # does not contain the last tile - batch_idx=3, - dataloader_idx=dataloader_idx, - dirpath="predictions", - ) - - extended_tiles = tiles[: n_tiles + batch_size] - extended_tile_infos = tile_infos[: n_tiles + batch_size] - - assert all( - np.array_equal(extended_tiles[i], cache_tiles_strategy.tile_cache[i]) - for i in range(n_tiles + batch_size) - ) - assert extended_tile_infos == cache_tiles_strategy.tile_info_cache - - -def test_write_batch_last_tile(cache_tiles_strategy): - """ - Test `CacheTiles.write_batch` when there is a last tile added to the cache. - - Expected behaviour is that the cache is cleared and the write func is called. - """ - - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - - # simulate adding a batch that will contain the last tile - n_tiles = 8 - batch_size = 2 - patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) - next_batch = ( - np.concatenate(tiles[n_tiles : n_tiles + batch_size]), - tile_infos[n_tiles : n_tiles + batch_size], - ) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterableTiledPredDataset) - dataloader_idx = 0 - trainer.predict_dataloaders = [Mock(spec=DataLoader)] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - # These functions have their own unit tests, - # so they do not need to be tested again here. 
- # This is a unit test to isolate functionality of `write_batch.` - with patch.multiple( - "careamics.lightning.callbacks.prediction_writer_callback.write_strategy", - stitch_prediction_single=DEFAULT, - get_sample_file_path=DEFAULT, - create_write_file_path=DEFAULT, - ) as values: - - # mocked functions - mock_stitch_prediction_single = values["stitch_prediction_single"] - mock_get_sample_file_path = values["get_sample_file_path"] - mock_create_write_file_path = values["create_write_file_path"] - - prediction_image = [Mock()] - in_file_path = Path("in_dir/file_path.ext") - out_file_path = Path("out_dir/file_path.in_ext") - mock_stitch_prediction_single.return_value = prediction_image - mock_get_sample_file_path.return_value = in_file_path - mock_create_write_file_path.return_value = out_file_path - - # call write batch - dirpath = "predictions" - cache_tiles_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=next_batch, - batch_indices=Mock(), - batch=next_batch, # contains the last tile - batch_idx=3, - dataloader_idx=dataloader_idx, - dirpath=dirpath, - ) - - # assert create_write_file_path is called as expected (TODO: necessary ?) - mock_create_write_file_path.assert_called_once_with( - dirpath=dirpath, - file_path=in_file_path, - write_extension=cache_tiles_strategy.write_extension, - ) - # assert write_func is called as expected - cache_tiles_strategy.write_func.assert_called_once_with( - file_path=out_file_path, img=prediction_image[0], **{} - ) - - # Tile of the next image (should remain in the cache) - remaining_tile = tiles[9] - remaining_tile_info = tile_infos[9] - - # assert cache cleared as expected - assert len(cache_tiles_strategy.tile_cache) == 1 - assert np.array_equal(remaining_tile, cache_tiles_strategy.tile_cache[0]) - assert remaining_tile_info == cache_tiles_strategy.tile_info_cache[0] - - -def test_have_last_tile_true(cache_tiles_strategy): - """Test `CacheTiles._have_last_tile` returns true when there is a last tile.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - - assert cache_tiles_strategy._has_last_tile() - - -def test_have_last_tile_false(cache_tiles_strategy): - """Test `CacheTiles._have_last_tile` returns false when there is not a last tile.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - # don't include last tile - patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - - assert not cache_tiles_strategy._has_last_tile() - - -def test_clear_cache(cache_tiles_strategy): - """ - Test `CacheTiles._clear_cache` removes the tiles up until the first "last tile". 
- """ - - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - # include first tile from next sample - patch_tile_cache(cache_tiles_strategy, tiles[:10], tile_infos[:10]) - - cache_tiles_strategy._clear_cache() - - assert len(cache_tiles_strategy.tile_cache) == 1 - assert np.array_equal(cache_tiles_strategy.tile_cache[0], tiles[9]) - assert cache_tiles_strategy.tile_info_cache[0] == tile_infos[9] - - -def test_last_tile_index(cache_tiles_strategy): - """Test `CacheTiles._last_tile_index` returns the index of the last tile.""" - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - - assert cache_tiles_strategy._last_tile_index() == 8 - - -def test_last_tile_index_error(cache_tiles_strategy): - """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - # don't include last tile - patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - - with pytest.raises(ValueError): - cache_tiles_strategy._last_tile_index() - - -def test_get_image_tiles(cache_tiles_strategy): - """Test `CacheTiles._get_image_tiles` returns the tiles of a single image.""" - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - # include first tile from next sample - patch_tile_cache(cache_tiles_strategy, tiles[:10], tile_infos[:10]) - - image_tiles, image_tile_infos = cache_tiles_strategy._get_image_tiles() - - assert len(image_tiles) == 9 - assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) - assert image_tile_infos == tile_infos[:9] diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py b/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py deleted file mode 100644 index a98c9cc2..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py +++ /dev/null @@ -1,54 +0,0 @@ -from pathlib import Path -from unittest.mock import Mock - -from careamics.config import InferenceConfig -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset -from careamics.lightning.callbacks.prediction_writer_callback.file_path_utils import ( - create_write_file_path, - get_sample_file_path, -) - - -def test_get_sample_file_path_tiled_ds(): - - # Create DS with mock InferenceConfig - src_files = [f"{i}.ext" for i in range(2)] - pred_config = Mock(spec=InferenceConfig) - # attrs used in DS initialization - pred_config.axes = Mock() - pred_config.tile_size = Mock() - pred_config.tile_overlap = Mock() - pred_config.image_means = [Mock()] - pred_config.image_stds = [Mock()] - ds = IterableTiledPredDataset(pred_config, src_files=src_files) - - for i in range(2): - file_path = get_sample_file_path(ds, sample_id=i) - assert file_path == f"{i}.ext" - - -def test_get_sample_file_path_untiled_ds(): - - # Create DS with mock InferenceConfig - src_files = [f"{i}.ext" for i in range(2)] - pred_config = Mock(spec=InferenceConfig) - # attrs used in DS initialization - pred_config.axes = Mock() - pred_config.image_means = [Mock()] - pred_config.image_stds = [Mock()] - ds = IterablePredDataset(pred_config, src_files=src_files) - - for i in range(2): - file_path = get_sample_file_path(ds, sample_id=i) - assert file_path == f"{i}.ext" - - -def test_create_write_file_path(): - dirpath = Path("output_directory") - file_path = 
Path("input_directory/file_name.in_ext") - write_extension = ".out_ext" - - write_file_path = create_write_file_path( - dirpath=dirpath, file_path=file_path, write_extension=write_extension - ) - assert write_file_path == Path("output_directory/file_name.out_ext") diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py index df6a71dc..48215b75 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py @@ -3,7 +3,7 @@ import os from pathlib import Path from typing import Union -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import pytest @@ -85,7 +85,13 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): ) # create prediction writer callback params - write_strategy = create_write_strategy(write_type="tiff", tiled=True) + write_strategy = create_write_strategy( + write_type="tiff", + tiled=True, + write_filenames=[file_name], + n_samples_per_file=[1], + ) + write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" # create trainer @@ -120,13 +126,16 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): predicted = trainer.predict(model, datamodule=predict_data) predicted_images = convert_outputs(predicted, tiled=True) + # filenames reset after predictions called + write_strategy.reset.assert_called_once() + # assert predicted file exists assert (dirpath / file_name).is_file() # open file save_data = tifffile.imread(dirpath / file_name) # save data has singleton channel axis - np.testing.assert_array_equal(save_data, predicted_images[0][0], verbose=True) + np.testing.assert_array_equal(save_data, predicted_images[0], verbose=True) def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): @@ -163,7 +172,13 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): ) # create prediction writer callback params - write_strategy = create_write_strategy(write_type="tiff", tiled=False) + write_strategy = create_write_strategy( + write_type="tiff", + tiled=False, + write_filenames=[file_name], + n_samples_per_file=[1], + ) + write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" # create trainer @@ -196,13 +211,16 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): predicted = trainer.predict(model, datamodule=predict_data) predicted_images = convert_outputs(predicted, tiled=False) + # filenames reset after predictions called + write_strategy.reset.assert_called_once() + # assert predicted file exists assert (dirpath / file_name).is_file() # open file save_data = tifffile.imread(dirpath / file_name) # save data has singleton channel axis - np.testing.assert_array_equal(save_data, predicted_images[0][0], verbose=True) + np.testing.assert_array_equal(save_data, predicted_images[0], verbose=True) def test_initialization(prediction_writer_callback, write_strategy, dirpath): diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py new file mode 100644 index 00000000..248fa5b7 --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py @@ -0,0 +1,99 @@ +"""Test the utility `SampleCache` class 
used by the WriteTile classes.""" + +import numpy as np +import pytest + +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + + +def test_add(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + np.testing.assert_array_equal(sample_cache.sample_cache[0], samples[0]) + + +def test_has_all_file_samples_true(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + for i in range(n_samples_per_file[0]): + sample_cache.add(samples[i]) + assert sample_cache.has_all_file_samples() + + +def test_has_all_file_samples_false(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + for i in range(n_samples_per_file[0] - 1): + sample_cache.add(samples[i]) + assert not sample_cache.has_all_file_samples() + + +def test_has_all_file_samples_error(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + # simulate iterating through all files + for sample in samples: + sample_cache.add(sample) + for _ in n_samples_per_file: + sample_cache.pop_file_samples() + + with pytest.raises(ValueError): + sample_cache.has_all_file_samples() + + +def test_pop_file_samples(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + + n = 4 + for i in range(n): + sample_cache.add(samples[i]) + + first_file_n_samples = n_samples_per_file[0] + + # assert popped samples correct + popped_samples = sample_cache.pop_file_samples() + for i in range(first_file_n_samples): + np.testing.assert_array_equal(popped_samples[i], samples[i]) + + # assert remaining sample correct + np.testing.assert_array_equal( + sample_cache.sample_cache[0], samples[first_file_n_samples] + ) + + +def test_pop_file_samples_error(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + + # should raise an error because sample cache does not contain all file samples + with pytest.raises(ValueError): + sample_cache.pop_file_samples() + + +def test_reset(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + + sample_cache.reset() + assert len(sample_cache.sample_cache) == 0 + assert next(sample_cache.n_samples_iter) == n_samples_per_file[0] diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py new file mode 100644 index 00000000..f919130a --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py @@ -0,0 +1,100 @@ +"""Test the utility `TileCache` class used by the WriteTile classes.""" + +import numpy as np +import pytest + +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + + +def test_add(create_tiles, patch_tile_cache): + """Test 
`TileCache.add` extends the internal caches as expected."""
+    tile_cache = caches.TileCache()
+    # all tiles of 1 samples with 9 tiles
+    tile_list, tile_infos = create_tiles(n_samples=1)
+
+    n = 4
+    # patch in tiles up to n into tile_cache
+    patch_tile_cache(tile_cache, tile_list[:n], tile_infos[:n])
+    # add remaining tiles to tile_cache
+    tile_cache.add((np.concatenate(tile_list[n:]), tile_infos[n:]))
+
+    np.testing.assert_array_equal(
+        np.concatenate(tile_cache.array_cache), np.concatenate(tile_list)
+    )
+    assert tile_cache.tile_info_cache == tile_infos
+
+
+def test_has_last_tile_true(create_tiles, patch_tile_cache):
+    """
+    Test `TileCache.has_last_tile` returns true when there is a last tile.
+    """
+
+    tile_cache = caches.TileCache()
+    # all tiles of 1 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=1)
+
+    patch_tile_cache(tile_cache, tiles, tile_infos)
+    assert tile_cache.has_last_tile()
+
+
+def test_has_last_tile_false(create_tiles, patch_tile_cache):
+    """Test `TileCache.has_last_tile` returns false when there is not a last tile."""
+
+    tile_cache = caches.TileCache()
+    # all tiles of 1 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=1)
+    # don't include last tile
+    patch_tile_cache(tile_cache, tiles[:-1], tile_infos[:-1])
+
+    assert not tile_cache.has_last_tile()
+
+
+def test_pop_image_tiles(create_tiles, patch_tile_cache):
+    """
+    Test `TileCache.pop_image_tiles` removes the tiles up until the first "last tile".
+    """
+
+    tile_cache = caches.TileCache()
+    # all tiles of 2 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=2)
+    # include first tile from next sample
+    patch_tile_cache(tile_cache, tiles[:10], tile_infos[:10])
+
+    image_tiles, image_tile_infos = tile_cache.pop_image_tiles()
+
+    # test popped tiles as expected
+    assert len(image_tiles) == 9
+    assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9))
+    assert image_tile_infos == tile_infos[:9]
+
+    # test tiles remaining in cache as expected
+    assert len(tile_cache.array_cache) == 1
+    assert np.array_equal(tile_cache.array_cache[0], tiles[9])
+    assert tile_cache.tile_info_cache[0] == tile_infos[9]
+
+
+def test_pop_image_tiles_error(create_tiles, patch_tile_cache):
+    """Test `TileCache.pop_image_tiles` raises an error when there is no last tile."""
+
+    tile_cache = caches.TileCache()
+    # all tiles of 1 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=1)
+    # don't include last tile
+    patch_tile_cache(tile_cache, tiles[:-1], tile_infos[:-1])
+    # should raise error if pop_image_tiles is called without last tile present
+    with pytest.raises(ValueError):
+        tile_cache.pop_image_tiles()
+
+
+def test_reset(create_tiles, patch_tile_cache):
+    """Test `TileCache.reset` resets the cached tiles as expected."""
+    tile_cache = caches.TileCache()
+    # all tiles of 1 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=1)
+    patch_tile_cache(tile_cache, tiles, tile_infos)
+
+    tile_cache.reset()
+    assert len(tile_cache.array_cache) == 0
+    assert len(tile_cache.tile_info_cache) == 0
diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_strategy.py
new file mode 100644
index 00000000..1f4823c2
--- /dev/null
+++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_strategy.py
@@ -0,0 +1,147 @@
+"""Test `WriteImage` class."""
+
+from pathlib import Path
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+from pytorch_lightning import LightningModule, Trainer
+from torch.utils.data import DataLoader
+
+from careamics.dataset import IterablePredDataset
+from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import (
+    WriteImage,
+)
+
+
+@pytest.fixture
+def write_image_strategy(write_func) -> WriteImage:
+    """
+    Initialized `WriteImage` class.
+
+    Parameters
+    ----------
+    write_func : WriteFunc
+        Write function. (Comes from fixture).
+
+    Returns
+    -------
+    WriteImage
+        Initialized `WriteImage` class.
+    """
+    write_extension = ".ext"
+    write_func_kwargs = {}
+    return WriteImage(
+        write_func=write_func,
+        write_extension=write_extension,
+        write_func_kwargs=write_func_kwargs,
+        write_filenames=None,
+        n_samples_per_file=None,
+    )
+
+
+def test_write_image_init(write_func, write_image_strategy):
+    """
+    Test `WriteImage` initializes as expected.
+    """
+    assert write_image_strategy.write_func is write_func
+    assert write_image_strategy.write_extension == ".ext"
+    assert write_image_strategy.write_func_kwargs == {}
+
+
+def test_write_batch(write_image_strategy: WriteImage, ordered_array):
+
+    n_batches = 1
+
+    prediction = ordered_array((n_batches, 1, 8, 8))
+
+    batch = prediction
+    batch_indices = np.arange(n_batches)
+    batch_idx = 0
+
+    # mock trainer and datasets
+    trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterablePredDataset)
+    dataloader_idx = 0
+    mock_dl = Mock(spec=DataLoader)
+    mock_dl.batch_size = 1
+    trainer.predict_dataloaders = [mock_dl]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    # set file data and call write batch
+    dirpath = Path("predictions")
+    write_image_strategy.set_file_data(
+        write_filenames=["file"], n_samples_per_file=[n_batches]
+    )
+    write_image_strategy.write_batch(
+        trainer=trainer,
+        pl_module=Mock(spec=LightningModule),
+        prediction=prediction,
+        batch_indices=batch_indices,
+        batch=batch,
+        batch_idx=batch_idx,
+        dataloader_idx=dataloader_idx,
+        dirpath=dirpath,
+    )
+
+    # assert write_func is called as expected
+    # cannot use `assert_called_once_with` because of numpy array
+    write_image_strategy.write_func.assert_called_once()
+    assert write_image_strategy.write_func.call_args.kwargs["file_path"] == Path(
+        "predictions/file.ext"
+    )
+    np.testing.assert_array_equal(
+        write_image_strategy.write_func.call_args.kwargs["img"], prediction
+    )
+
+
+def test_write_batch_raises(write_image_strategy: WriteImage, ordered_array):
+    """Test write batch raises a ValueError if the filenames have not been set."""
+    n_batches = 1
+
+    prediction = ordered_array((n_batches, 1, 8, 8))
+
+    batch = prediction
+    batch_indices = np.arange(n_batches)
+    batch_idx = 0
+
+    # mock trainer and datasets
+    trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterablePredDataset)
+    dataloader_idx = 0
+    mock_dl = Mock(spec=DataLoader)
+    mock_dl.batch_size = 1
+    trainer.predict_dataloaders = [mock_dl]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    # call write batch without setting file data
+    dirpath = Path("predictions")
+
+    with pytest.raises(ValueError):
+        # Make sure write_filenames is None
+        assert write_image_strategy._write_filenames is None
+        write_image_strategy.write_batch(
+            trainer=trainer,
+            pl_module=Mock(spec=LightningModule),
+            prediction=prediction,
+            batch_indices=batch_indices,
+            batch=batch,
+            batch_idx=batch_idx,
+            dataloader_idx=dataloader_idx,
+ dirpath=dirpath, + ) + + +def test_reset(write_image_strategy: WriteImage): + """Test WriteImage.reset works as expected""" + write_image_strategy._write_filenames = ["file"] + write_image_strategy.current_file_index = 1 + write_image_strategy.reset() + assert write_image_strategy._write_filenames is None + assert write_image_strategy.current_file_index == 0 diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py deleted file mode 100644 index 04134245..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Test `WriteImage` class.""" - -from pathlib import Path -from unittest.mock import DEFAULT, Mock, patch - -import numpy as np -import pytest -from pytorch_lightning import LightningModule, Trainer -from torch.utils.data import DataLoader - -from careamics.dataset import IterablePredDataset -from careamics.file_io import WriteFunc -from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( - WriteImage, -) - - -@pytest.fixture -def write_func(): - """Mock `WriteFunc`.""" - return Mock(spec=WriteFunc) - - -@pytest.fixture -def write_image_strategy(write_func) -> WriteImage: - """ - Initialized `CacheTiles` class. - - Parameters - ---------- - write_func : WriteFunc - Write function. (Comes from fixture). - - Returns - ------- - CacheTiles - Initialized `CacheTiles` class. - """ - write_extension = ".ext" - write_func_kwargs = {} - return WriteImage( - write_func=write_func, - write_extension=write_extension, - write_func_kwargs=write_func_kwargs, - ) - - -def test_cache_tiles_init(write_func, write_image_strategy): - """ - Test `WriteImage` initializes as expected. - """ - assert write_image_strategy.write_func is write_func - assert write_image_strategy.write_extension == ".ext" - assert write_image_strategy.write_func_kwargs == {} - - -def test_write_batch(write_image_strategy, ordered_array): - - n_batches = 1 - - prediction = ordered_array((n_batches, 1, 8, 8)) - - batch = prediction - batch_indices = np.arange(n_batches) - batch_idx = 0 - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterablePredDataset) - dataloader_idx = 0 - mock_dl = Mock(spec=DataLoader) - mock_dl.batch_size = 1 - trainer.predict_dataloaders = [mock_dl] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - # These functions have their own unit tests, - # so they do not need to be tested again here. 
- # This is a unit test to isolate functionality of `write_batch.` - with patch.multiple( - "careamics.lightning.callbacks.prediction_writer_callback.write_strategy", - get_sample_file_path=DEFAULT, - create_write_file_path=DEFAULT, - ) as values: - - # mocked functions - mock_get_sample_file_path = values["get_sample_file_path"] - mock_create_write_file_path = values["create_write_file_path"] - - # assign mock functions return value - in_file_path = Path("in_dir/file_path.ext") - out_file_path = Path("out_dir/file_path.out_ext") - mock_get_sample_file_path.return_value = in_file_path - mock_create_write_file_path.return_value = out_file_path - - # call write batch - dirpath = "predictions" - write_image_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=prediction, - batch_indices=batch_indices, - batch=batch, # contains the last tile - batch_idx=batch_idx, - dataloader_idx=dataloader_idx, - dirpath=dirpath, - ) - - # assert create_write_file_path is called as expected (TODO: necessary ?) - mock_create_write_file_path.assert_called_with( - dirpath=dirpath, - file_path=in_file_path, - write_extension=write_image_strategy.write_extension, - ) - # assert write_func is called as expectedÏ - # cannot use `assert_called_once_with` because of numpy array - write_image_strategy.write_func.assert_called_once() - assert ( - write_image_strategy.write_func.call_args.kwargs["file_path"] - == out_file_path - ) - assert np.array_equal( - write_image_strategy.write_func.call_args.kwargs["img"], prediction[0] - ) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py index f00cffde..72fc6e35 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py @@ -8,8 +8,8 @@ from careamics.file_io.write import write_tiff from careamics.lightning.callbacks.prediction_writer_callback import ( - CacheTiles, WriteImage, + WriteTiles, create_write_strategy, select_write_extension, select_write_func, @@ -25,7 +25,7 @@ def test_create_write_strategy_tiff_tiled(): """Test write strategy creation for tiled tiff.""" write_strategy = create_write_strategy(write_type="tiff", tiled=True) - assert isinstance(write_strategy, CacheTiles) + assert isinstance(write_strategy, WriteTiles) assert write_strategy.write_func is write_tiff assert write_strategy.write_extension == ".tiff" assert write_strategy.write_func_kwargs == {} @@ -46,7 +46,7 @@ def test_create_write_strategy_custom_tiled(): write_strategy = create_write_strategy( write_type="custom", tiled=True, write_func=save_numpy, write_extension=".npy" ) - assert isinstance(write_strategy, CacheTiles) + assert isinstance(write_strategy, WriteTiles) assert write_strategy.write_func is save_numpy assert write_strategy.write_extension == ".npy" assert write_strategy.write_func_kwargs == {} diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py new file mode 100644 index 00000000..5c32a728 --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py @@ -0,0 +1,294 @@ +"""Test `CacheTiles` class.""" + +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +import tifffile +from 
pytorch_lightning import LightningModule, Trainer
+from torch.utils.data import DataLoader
+
+from careamics.dataset import IterableTiledPredDataset
+from careamics.file_io.write import write_tiff
+from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import (
+    WriteTiles,
+)
+from careamics.prediction_utils import stitch_prediction
+
+
+@pytest.fixture
+def write_tiles_strategy() -> WriteTiles:
+    """
+    Initialized `WriteTiles` class.
+
+    Returns
+    -------
+    WriteTiles
+        Initialized `WriteTiles` class.
+    """
+    write_extension = ".tiff"
+    write_func_kwargs = {}
+    return WriteTiles(
+        write_func=write_tiff,
+        write_extension=write_extension,
+        write_func_kwargs=write_func_kwargs,
+        write_filenames=None,
+        n_samples_per_file=None,
+    )
+
+
+def test_write_batch_no_last_tile(
+    tmp_path,
+    write_tiles_strategy: WriteTiles,
+    create_tiles,
+    patch_tile_cache,
+):
+    """
+    Test behaviour when the last tile of a sample is not present.
+
+    `WriteTiles.write_batch` should cache the tiles in the given batch.
+    """
+    # all tiles of 1 samples with 9 tiles
+    n_samples = 1
+    tiles, tile_infos = create_tiles(n_samples=n_samples)
+
+    # simulate adding a batch that will not contain the last tile
+    n_tiles = 4
+    batch_size = 2
+    patch_tile_cache(
+        write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles]
+    )
+    next_batch = (
+        np.concatenate(tiles[n_tiles : n_tiles + batch_size]),
+        tile_infos[n_tiles : n_tiles + batch_size],
+    )
+
+    # mock trainer and datasets (difficult to set up true classes)
+    trainer: Trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterableTiledPredDataset)
+    dataloader_idx = 0
+    trainer.predict_dataloaders = [Mock(spec=DataLoader)]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    write_tiles_strategy.set_file_data(
+        write_filenames=["file_1.tif"], n_samples_per_file=[n_samples]
+    )
+    write_tiles_strategy.write_batch(
+        trainer=trainer,
+        pl_module=Mock(spec=LightningModule),
+        prediction=next_batch,
+        batch_indices=Mock(),
+        batch=next_batch,  # does not contain the last tile
+        batch_idx=3,
+        dataloader_idx=dataloader_idx,
+        dirpath=tmp_path / "predictions",
+    )
+
+    extended_tiles = tiles[: n_tiles + batch_size]
+    extended_tile_infos = tile_infos[: n_tiles + batch_size]
+
+    # assert tiles and tile infos are cached
+    for i in range(n_tiles + batch_size):
+        np.testing.assert_array_equal(
+            extended_tiles[i], write_tiles_strategy.tile_cache.array_cache[i]
+        )
+    assert extended_tile_infos == write_tiles_strategy.tile_cache.tile_info_cache
+    assert len(write_tiles_strategy.sample_cache.sample_cache) == 0
+
+
+def test_write_batch_has_last_tile_no_last_sample(
+    tmp_path,
+    write_tiles_strategy: WriteTiles,
+    create_tiles,
+    patch_tile_cache,
+):
+    """
+    Test behaviour when the last tile of a sample is present, but not the last sample.
+
+    `WriteTiles.write_batch` should cache the resulting stitched sample.
+    """
+    # all tiles of 2 samples with 9 tiles each
+    n_samples = 2
+    tiles, tile_infos = create_tiles(n_samples=n_samples)
+
+    # simulate adding a batch that will not contain the last sample
+    n_tiles = 8
+    batch_size = 2
+    patch_tile_cache(
+        write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles]
+    )
+    # next batch will include the last tile of the first sample
+    next_batch = (
+        np.concatenate(tiles[n_tiles : n_tiles + batch_size]),
+        tile_infos[n_tiles : n_tiles + batch_size],
+    )
+
+    # mock trainer and datasets (difficult to set up true classes)
+    trainer: Trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterableTiledPredDataset)
+    dataloader_idx = 0
+    trainer.predict_dataloaders = [Mock(spec=DataLoader)]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    write_tiles_strategy.set_file_data(
+        write_filenames=["file_1.tif"], n_samples_per_file=[n_samples]
+    )
+    write_tiles_strategy.write_batch(
+        trainer=trainer,
+        pl_module=Mock(spec=LightningModule),
+        prediction=next_batch,
+        batch_indices=Mock(),
+        batch=next_batch,  # contains the last tile of the first sample
+        batch_idx=3,
+        dataloader_idx=dataloader_idx,
+        dirpath=tmp_path / "predictions",
+    )
+
+    stitched_sample1 = stitch_prediction(tiles[:9], tile_infos[:9])[0]
+
+    # assert first tile from the second sample is in the tile cache
+    np.testing.assert_array_equal(
+        tiles[9], write_tiles_strategy.tile_cache.array_cache[0]
+    )
+    assert [tile_infos[9]] == write_tiles_strategy.tile_cache.tile_info_cache
+    # assert the stitched sample is saved in the sample cache
+    np.testing.assert_array_equal(
+        stitched_sample1, write_tiles_strategy.sample_cache.sample_cache[0]
+    )
+
+
+def test_write_batch_has_last_sample(
+    tmp_path,
+    write_tiles_strategy: WriteTiles,
+    create_tiles,
+    patch_tile_cache,
+):
+    """
+    Test behaviour when the last sample of a file is present.
+
+    `WriteTiles.write_batch` should write the resulting set of samples to disk.
+    """
+    # all tiles of 2 samples with 9 tiles each
+    n_samples = 2
+    tiles, tile_infos = create_tiles(n_samples=n_samples)
+
+    stitched_samples = stitch_prediction(tiles, tile_infos)
+    file_data = np.concatenate(stitched_samples)
+
+    # simulate adding a batch that will contain the last tile of the last sample
+    batch_size = 2
+    n_tiles = 16
+    write_tiles_strategy.set_file_data(
+        write_filenames=["file_1.tif"], n_samples_per_file=[n_samples]
+    )
+    patch_tile_cache(
+        write_tiles_strategy.tile_cache, tiles[9:n_tiles], tile_infos[9:n_tiles]
+    )
+    # also patch the sample cache with the first stitched sample
+    write_tiles_strategy.sample_cache.sample_cache.append(stitched_samples[0])
+    # next batch will include the last tile of the last sample
+    next_batch = (
+        np.concatenate(tiles[n_tiles : n_tiles + batch_size]),
+        tile_infos[n_tiles : n_tiles + batch_size],
+    )
+
+    # mock trainer and datasets (difficult to set up true classes)
+    trainer: Trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterableTiledPredDataset)
+    dataloader_idx = 0
+    trainer.predict_dataloaders = [Mock(spec=DataLoader)]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    # normally output directory creation is handled by PredictionWriterCallback
+    (tmp_path / "predictions").mkdir()
+    write_tiles_strategy.write_batch(
+        trainer=trainer,
+        pl_module=Mock(spec=LightningModule),
+        prediction=next_batch,
+        batch_indices=Mock(),
+        batch=next_batch,  # contains the last tile of the last sample
+        batch_idx=3,
+        dataloader_idx=dataloader_idx,
+        dirpath=tmp_path / "predictions",
+    )
+
+    # assert caches are now empty
+    assert len(write_tiles_strategy.tile_cache.array_cache) == 0
+    assert len(write_tiles_strategy.tile_cache.tile_info_cache) == 0
+    assert len(write_tiles_strategy.sample_cache.sample_cache) == 0
+
+    assert (tmp_path / "predictions" / "file_1.tiff").is_file()
+
+    load_file_data = tifffile.imread(tmp_path / "predictions" / "file_1.tiff")
+    np.testing.assert_array_equal(load_file_data, file_data)
+
+
+def test_write_batch_raises(
+    write_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache
+):
+    """Test write batch raises a ValueError if the filenames have not been set."""
+    # all tiles of 2 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=2)
+
+    # simulate adding a batch that will contain the last tile
+    n_tiles = 8
+    batch_size = 2
+    patch_tile_cache(
+        write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles]
+    )
+    next_batch = (
+        np.concatenate(tiles[n_tiles : n_tiles + batch_size]),
+        tile_infos[n_tiles : n_tiles + batch_size],
+    )
+
+    # mock trainer and datasets
+    trainer = Mock(spec=Trainer)
+    mock_dataset = Mock(spec=IterableTiledPredDataset)
+    dataloader_idx = 0
+    trainer.predict_dataloaders = [Mock(spec=DataLoader)]
+    trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset
+
+    with pytest.raises(ValueError):
+        assert write_tiles_strategy._write_filenames is None
+
+        # call write batch without setting file data
+        dirpath = Path("predictions")
+        write_tiles_strategy.write_batch(
+            trainer=trainer,
+            pl_module=Mock(spec=LightningModule),
+            prediction=next_batch,
+            batch_indices=Mock(),
+            batch=next_batch,  # contains the last tile
+            batch_idx=3,
+            dataloader_idx=dataloader_idx,
+            dirpath=dirpath,
+        )
+
+
+def test_reset(write_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache):
+    """Test `WriteTiles.reset` works as expected."""
+    # all tiles of 1 samples with 9 tiles
+    tiles, tile_infos = create_tiles(n_samples=1)
+    # don't include last tile
+    patch_tile_cache(write_tiles_strategy.tile_cache, tiles[:-1], tile_infos[:-1])
+
+    write_tiles_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[1])
+    write_tiles_strategy.reset()
+
+    assert write_tiles_strategy._write_filenames is None
+    assert write_tiles_strategy._filename_iter is None