From 091a518306759fcaa2344a5d989432870185e9e9 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 11:52:59 +0200 Subject: [PATCH 01/38] feat(write strategy): add write_filenmaes attribute --- .../prediction_writer_callback/write_strategy.py | 16 ++++++++++++++++ .../write_strategy_factory.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index 9b298da1..c0316ea2 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -65,6 +65,8 @@ class CacheTiles(WriteStrategy): ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -74,6 +76,8 @@ class CacheTiles(WriteStrategy): ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -87,6 +91,7 @@ class CacheTiles(WriteStrategy): def __init__( self, write_func: WriteFunc, + write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], ) -> None: @@ -100,6 +105,8 @@ def __init__( ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -108,6 +115,7 @@ def __init__( super().__init__() self.write_func: WriteFunc = write_func + self.write_filenames: Optional[list[str]] = write_filenames self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs @@ -301,6 +309,8 @@ class WriteImage(WriteStrategy): ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -310,6 +320,8 @@ class WriteImage(WriteStrategy): ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -319,6 +331,7 @@ class WriteImage(WriteStrategy): def __init__( self, write_func: WriteFunc, + write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], ) -> None: @@ -329,6 +342,8 @@ def __init__( ---------- write_func : WriteFunc Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} @@ -337,6 +352,7 @@ def __init__( super().__init__() self.write_func: WriteFunc = write_func + self.write_filenames: Optional[list[str]] = write_filenames self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py index a9eda4a4..dd35016b 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py @@ -12,6 +12,7 @@ def create_write_strategy( write_type: SupportedWriteType, tiled: bool, write_func: Optional[WriteFunc] = None, + write_filenames: Optional[list[str]] = None, write_extension: Optional[str] = None, write_func_kwargs: Optional[dict[str, Any]] = None, ) -> WriteStrategy: @@ -27,6 +28,8 @@ def create_write_strategy( write_func : WriteFunc, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` a function to save the data must be passed. See notes below. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` an extension to save the data with must be passed. @@ -61,6 +64,7 @@ def create_write_strategy( ) write_strategy = WriteImage( write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, ) @@ -69,6 +73,7 @@ def create_write_strategy( write_strategy = _create_tiled_write_strategy( write_type=write_type, write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, ) @@ -79,6 +84,7 @@ def create_write_strategy( def _create_tiled_write_strategy( write_type: SupportedWriteType, write_func: Optional[WriteFunc], + write_filenames: Optional[list[str]], write_extension: Optional[str], write_func_kwargs: dict[str, Any], ) -> WriteStrategy: @@ -95,6 +101,8 @@ def _create_tiled_write_strategy( write_func : WriteFunc, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` a function to save the data must be passed. See notes below. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. write_extension : str, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` an extension to save the data with must be passed. @@ -124,6 +132,7 @@ def _create_tiled_write_strategy( ) return CacheTiles( write_func=write_func, + write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, ) From c4213a2fe4091494459fc8e852a25d41327a4b33 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 12:27:07 +0200 Subject: [PATCH 02/38] feat(write_strategy): add reset method --- .../prediction_writer_callback.py | 21 +++++++++++++++- .../write_strategy.py | 25 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py index 8a73c12e..da945f11 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py @@ -127,7 +127,7 @@ def from_write_func_params( ) return cls(write_strategy=write_strategy, dirpath=dirpath) - def _init_dirpath(self, dirpath): + def _init_dirpath(self, dirpath: Union[Path, str]): """ Initialize directory path. Should only be called from `__init__`. @@ -161,6 +161,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non Stage of training e.g. 'predict', 'fit', 'validate'. """ super().setup(trainer, pl_module, stage) + # TODO: move directory creation to hook on_predict_epoch_start if stage == "predict": # make prediction output directory logger.info("Making prediction output directory.") @@ -231,3 +232,21 @@ def write_on_batch_end( dataloader_idx=dataloader_idx, dirpath=self.dirpath, ) + + def on_predict_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """ + Lightning hook called at the end of prediction. + + Resets write_strategy. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer. + pl_module : LightningModule + PyTorch Lightning module. + """ + # reset write_strategy to prevent bugs if trainer.predict is called twice. + self.write_strategy.reset() diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index c0316ea2..9442a54e 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -53,6 +53,13 @@ def write_batch( Path to directory to save predictions to. """ + def reset(self) -> None: + """ + Reset internal attributes of a `WriteStrategy` instance. + + This is to prevent bugs if a `WriteStrategy` instance is used twice. + """ + class CacheTiles(WriteStrategy): """ @@ -206,6 +213,16 @@ def write_batch( file_path=file_path, img=prediction_image[0], **self.write_func_kwargs ) + def reset(self) -> None: + """ + Reset the internal attributes. + + Attributes reset are: `write_filenames`, `tile_cache` and `tile_info_cache`. + """ + self.write_filenames = None + self.tile_cache = [] + self.tile_info_cache = [] + def _has_last_tile(self) -> bool: """ Whether a last tile is contained in the cached tiles. @@ -300,6 +317,10 @@ def write_batch( """ raise NotImplementedError + def reset(self) -> None: + """Reset internal attributes.""" + raise NotImplementedError + class WriteImage(WriteStrategy): """ @@ -412,3 +433,7 @@ def write_batch( self.write_func( file_path=file_path, img=prediction_image, **self.write_func_kwargs ) + + def reset(self) -> None: + """Reset internal attributes. Resets the `write_filename`.""" + self.write_filenames = None From dedfac6a123e429bd9b78c26be62c43d9b84ee96 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 12:48:43 +0200 Subject: [PATCH 03/38] fix(write on batch end): remove no longer needed ds instance type check --- .../prediction_writer_callback.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py index da945f11..030d6f61 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py @@ -7,7 +7,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import BasePredictionWriter -from torch.utils.data import DataLoader from careamics.dataset import ( IterablePredDataset, @@ -203,25 +202,6 @@ def write_on_batch_end( if not self.writing_predictions: return - dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders - dataloader: DataLoader = ( - dataloaders[dataloader_idx] - if isinstance(dataloaders, list) - else dataloaders - ) - dataset: ValidPredDatasets = dataloader.dataset - if not ( - isinstance(dataset, IterablePredDataset) - or isinstance(dataset, IterableTiledPredDataset) - ): - # Note: Error will be raised before here from the source type - # This is for extra redundancy of errors. - raise TypeError( - "Prediction dataset has to be `IterableTiledPredDataset` or " - "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because " - "filenames are taken from the original file." - ) - self.write_strategy.write_batch( trainer=trainer, pl_module=pl_module, From 14733cafc251d232c89215436aafba4e415d6042 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 13:12:33 +0200 Subject: [PATCH 04/38] feat(write_strategy): add current file index attr --- .../write_strategy.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index 9442a54e..3979b7d9 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -93,6 +93,8 @@ class CacheTiles(WriteStrategy): Tiles cached for stitching prediction. tile_info_cache : list of TileInformation Cached tile information for stitching prediction. + current_file_index : int + Index of current file, increments every time a file is written. """ def __init__( @@ -130,6 +132,8 @@ def __init__( self.tile_cache: list[NDArray] = [] self.tile_info_cache: list[TileInformation] = [] + self.current_file_index = 0 + @property def last_tiles(self) -> list[bool]: """ @@ -212,16 +216,19 @@ def write_batch( self.write_func( file_path=file_path, img=prediction_image[0], **self.write_func_kwargs ) + self.current_file_index += 1 def reset(self) -> None: """ Reset the internal attributes. - Attributes reset are: `write_filenames`, `tile_cache` and `tile_info_cache`. + Attributes reset are: `write_filenames`, `tile_cache`, `tile_info_cache` and + `current_file_index`. """ self.write_filenames = None self.tile_cache = [] self.tile_info_cache = [] + self.current_file_index = 0 def _has_last_tile(self) -> bool: """ @@ -347,6 +354,8 @@ class WriteImage(WriteStrategy): Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + current_file_index : int + Index of current file, increments every time a file is written. """ def __init__( @@ -377,6 +386,8 @@ def __init__( self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs + self.current_file_index: int = 0 + def write_batch( self, trainer: Trainer, @@ -433,7 +444,13 @@ def write_batch( self.write_func( file_path=file_path, img=prediction_image, **self.write_func_kwargs ) + self.current_file_index += 1 def reset(self) -> None: - """Reset internal attributes. Resets the `write_filename`.""" + """ + Reset internal attributes. + + Resets the `write_filename` and `current_file_index`. + """ self.write_filenames = None + self.current_file_index = 0 From 99459b7ce49ec6ce304b161981604f1584c37cf0 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 13:23:46 +0200 Subject: [PATCH 05/38] feat: update write_batch method to use write_filenames attribute --- .../write_strategy.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index 3979b7d9..45235af0 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -13,8 +13,6 @@ from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single -from .file_path_utils import create_write_file_path, get_sample_file_path - class WriteStrategy(Protocol): """Protocol for write strategy classes.""" @@ -178,7 +176,14 @@ def write_batch( Dataloader index. dirpath : Path Path to directory to save predictions to. + + Raises + ------ + ValueError + If `write_filenames` attribute is `None`. """ + if self.write_filenames is None: + raise ValueError("`write_filenames` attribute has not been set.") dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dataloader: DataLoader = ( dataloaders[dataloader_idx] @@ -206,13 +211,8 @@ def write_batch( ) # write prediction - sample_id = tile_infos[0].sample_id # need this to select correct file name - input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id) - file_path = create_write_file_path( - dirpath=dirpath, - file_path=input_file_path, - write_extension=self.write_extension, - ) + file_name = self.write_filenames[self.current_file_index] + file_path = (dirpath / file_name).with_suffix(self.write_extension) self.write_func( file_path=file_path, img=prediction_image[0], **self.write_func_kwargs ) @@ -425,32 +425,33 @@ def write_batch( ------ TypeError If trainer prediction dataset is not `IterablePredDataset`. + ValueError + If `write_filenames` attribute is `None`. """ + if self.write_filenames is None: + raise ValueError("`write_filenames` attribute has not been set.") + dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls ds: IterablePredDataset = dl.dataset if not isinstance(ds, IterablePredDataset): + # TODO: change to warning raise TypeError("Prediction dataset is not `IterablePredDataset`.") - for i in range(prediction.shape[0]): - prediction_image = prediction[0] - sample_id = batch_idx * dl.batch_size + i - input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id) - file_path = create_write_file_path( - dirpath=dirpath, - file_path=input_file_path, - write_extension=self.write_extension, - ) - self.write_func( - file_path=file_path, img=prediction_image, **self.write_func_kwargs - ) - self.current_file_index += 1 + # for i in range(prediction.shape[0]): + # prediction_image = prediction[0] + # sample_id = batch_idx * dl.batch_size + i + + file_name = self.write_filenames[self.current_file_index] + file_path = (dirpath / file_name).with_suffix(self.write_extension) + self.write_func(file_path=file_path, img=prediction, **self.write_func_kwargs) + self.current_file_index += 1 def reset(self) -> None: """ Reset internal attributes. - Resets the `write_filename` and `current_file_index`. + Resets the `write_filenames` and `current_file_index` attributes. """ self.write_filenames = None self.current_file_index = 0 From 3a1f031d3a982e9518e528146cb38f8acb7da952 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 13:27:31 +0200 Subject: [PATCH 06/38] feat(write prediction): remove no-longer used file_path_utils.py --- .../file_path_utils.py | 56 ------------------- 1 file changed, 56 deletions(-) delete mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py deleted file mode 100644 index 9da2ba6b..00000000 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Module containing file path utilities for `WriteStrategy` to use.""" - -from pathlib import Path -from typing import Union - -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset - - -# TODO: move to datasets package ? -def get_sample_file_path( - dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int -) -> Path: - """ - Get the file path for a particular sample. - - Parameters - ---------- - dataset : IterableTiledPredDataset or IterablePredDataset - Dataset. - sample_id : int - Sample ID, the index of the file in the dataset `dataset`. - - Returns - ------- - Path - The file path corresponding to the sample with the ID `sample_id`. - """ - return dataset.data_files[sample_id] - - -def create_write_file_path( - dirpath: Path, file_path: Path, write_extension: str -) -> Path: - """ - Create the file name for the output file. - - Takes the original file path, changes the directory to `dirpath` and changes - the extension to `write_extension`. - - Parameters - ---------- - dirpath : pathlib.Path - The output directory to write file to. - file_path : pathlib.Path - The original file path. - write_extension : str - The extension that output files should have. - - Returns - ------- - Path - The output file path. - """ - file_name = Path(file_path.stem).with_suffix(write_extension) - file_path = dirpath / file_name - return file_path From a82ed40e24188c51366dc1898806838502e29c1f Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 13:48:31 +0200 Subject: [PATCH 07/38] test: remove test_file_path_utils.py --- .../test_file_path_utils.py | 54 ------------------- 1 file changed, 54 deletions(-) delete mode 100644 tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py b/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py deleted file mode 100644 index a98c9cc2..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py +++ /dev/null @@ -1,54 +0,0 @@ -from pathlib import Path -from unittest.mock import Mock - -from careamics.config import InferenceConfig -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset -from careamics.lightning.callbacks.prediction_writer_callback.file_path_utils import ( - create_write_file_path, - get_sample_file_path, -) - - -def test_get_sample_file_path_tiled_ds(): - - # Create DS with mock InferenceConfig - src_files = [f"{i}.ext" for i in range(2)] - pred_config = Mock(spec=InferenceConfig) - # attrs used in DS initialization - pred_config.axes = Mock() - pred_config.tile_size = Mock() - pred_config.tile_overlap = Mock() - pred_config.image_means = [Mock()] - pred_config.image_stds = [Mock()] - ds = IterableTiledPredDataset(pred_config, src_files=src_files) - - for i in range(2): - file_path = get_sample_file_path(ds, sample_id=i) - assert file_path == f"{i}.ext" - - -def test_get_sample_file_path_untiled_ds(): - - # Create DS with mock InferenceConfig - src_files = [f"{i}.ext" for i in range(2)] - pred_config = Mock(spec=InferenceConfig) - # attrs used in DS initialization - pred_config.axes = Mock() - pred_config.image_means = [Mock()] - pred_config.image_stds = [Mock()] - ds = IterablePredDataset(pred_config, src_files=src_files) - - for i in range(2): - file_path = get_sample_file_path(ds, sample_id=i) - assert file_path == f"{i}.ext" - - -def test_create_write_file_path(): - dirpath = Path("output_directory") - file_path = Path("input_directory/file_name.in_ext") - write_extension = ".out_ext" - - write_file_path = create_write_file_path( - dirpath=dirpath, file_path=file_path, write_extension=write_extension - ) - assert write_file_path == Path("output_directory/file_name.out_ext") From 9ebecc17a5c8b203ef31b9ecf10fd61bd100d5f8 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 13:49:36 +0200 Subject: [PATCH 08/38] test: update CacheTile initialisation --- .../test_cache_tiles_write_strategy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index da58b4bc..f6edc794 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -95,6 +95,7 @@ def cache_tiles_strategy(write_func) -> CacheTiles: write_func_kwargs = {} return CacheTiles( write_func=write_func, + write_filenames=None, write_extension=write_extension, write_func_kwargs=write_func_kwargs, ) @@ -109,6 +110,8 @@ def test_cache_tiles_init(write_func, cache_tiles_strategy): assert cache_tiles_strategy.write_func_kwargs == {} assert cache_tiles_strategy.tile_cache == [] assert cache_tiles_strategy.tile_info_cache == [] + assert cache_tiles_strategy.write_filenames is None + assert cache_tiles_strategy.current_file_index == 0 def test_last_tiles(cache_tiles_strategy): From 42f01cde16e466c025aa0ca80c37e8c96df6697d Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 14:13:20 +0200 Subject: [PATCH 09/38] test(CacheTiles): update - set write_filenames --- .../test_cache_tiles_write_strategy.py | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index f6edc794..086afac2 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -1,7 +1,7 @@ """Test `CacheTiles` class.""" from pathlib import Path -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock, patch import numpy as np import pytest @@ -151,6 +151,7 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): trainer.predict_dataloaders = [Mock(spec=DataLoader)] trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + cache_tiles_strategy.write_filenames = ["file_1"] cache_tiles_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), @@ -204,27 +205,18 @@ def test_write_batch_last_tile(cache_tiles_strategy): # These functions have their own unit tests, # so they do not need to be tested again here. # This is a unit test to isolate functionality of `write_batch.` - with patch.multiple( - "careamics.lightning.callbacks.prediction_writer_callback.write_strategy", - stitch_prediction_single=DEFAULT, - get_sample_file_path=DEFAULT, - create_write_file_path=DEFAULT, - ) as values: - - # mocked functions - mock_stitch_prediction_single = values["stitch_prediction_single"] - mock_get_sample_file_path = values["get_sample_file_path"] - mock_create_write_file_path = values["create_write_file_path"] + with patch( + "careamics.lightning.callbacks.prediction_writer_callback.write_strategy" + + ".stitch_prediction_single", + ) as mock_stitch_prediction_single: prediction_image = [Mock()] - in_file_path = Path("in_dir/file_path.ext") - out_file_path = Path("out_dir/file_path.in_ext") + file_name = "file" mock_stitch_prediction_single.return_value = prediction_image - mock_get_sample_file_path.return_value = in_file_path - mock_create_write_file_path.return_value = out_file_path # call write batch - dirpath = "predictions" + dirpath = Path("predictions") + cache_tiles_strategy.write_filenames = [file_name] cache_tiles_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), @@ -236,15 +228,9 @@ def test_write_batch_last_tile(cache_tiles_strategy): dirpath=dirpath, ) - # assert create_write_file_path is called as expected (TODO: necessary ?) - mock_create_write_file_path.assert_called_once_with( - dirpath=dirpath, - file_path=in_file_path, - write_extension=cache_tiles_strategy.write_extension, - ) # assert write_func is called as expected cache_tiles_strategy.write_func.assert_called_once_with( - file_path=out_file_path, img=prediction_image[0], **{} + file_path=Path("predictions/file.ext"), img=prediction_image[0], **{} ) # Tile of the next image (should remain in the cache) From eb82fb31238b2b48471aa1644bd96205c8beb318 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 14:26:32 +0200 Subject: [PATCH 10/38] test: update write_image_strategy tests - write_filenames attribute --- .../test_write_image_write_strategy.py | 74 +++++++------------ 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index 04134245..98748dae 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -1,7 +1,7 @@ """Test `WriteImage` class.""" from pathlib import Path -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock import numpy as np import pytest @@ -40,6 +40,7 @@ def write_image_strategy(write_func) -> WriteImage: write_func_kwargs = {} return WriteImage( write_func=write_func, + write_filenames=None, write_extension=write_extension, write_func_kwargs=write_func_kwargs, ) @@ -76,51 +77,26 @@ def test_write_batch(write_image_strategy, ordered_array): trainer.predict_dataloaders = [mock_dl] trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - # These functions have their own unit tests, - # so they do not need to be tested again here. - # This is a unit test to isolate functionality of `write_batch.` - with patch.multiple( - "careamics.lightning.callbacks.prediction_writer_callback.write_strategy", - get_sample_file_path=DEFAULT, - create_write_file_path=DEFAULT, - ) as values: - - # mocked functions - mock_get_sample_file_path = values["get_sample_file_path"] - mock_create_write_file_path = values["create_write_file_path"] - - # assign mock functions return value - in_file_path = Path("in_dir/file_path.ext") - out_file_path = Path("out_dir/file_path.out_ext") - mock_get_sample_file_path.return_value = in_file_path - mock_create_write_file_path.return_value = out_file_path - - # call write batch - dirpath = "predictions" - write_image_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=prediction, - batch_indices=batch_indices, - batch=batch, # contains the last tile - batch_idx=batch_idx, - dataloader_idx=dataloader_idx, - dirpath=dirpath, - ) - - # assert create_write_file_path is called as expected (TODO: necessary ?) - mock_create_write_file_path.assert_called_with( - dirpath=dirpath, - file_path=in_file_path, - write_extension=write_image_strategy.write_extension, - ) - # assert write_func is called as expectedÏ - # cannot use `assert_called_once_with` because of numpy array - write_image_strategy.write_func.assert_called_once() - assert ( - write_image_strategy.write_func.call_args.kwargs["file_path"] - == out_file_path - ) - assert np.array_equal( - write_image_strategy.write_func.call_args.kwargs["img"], prediction[0] - ) + # call write batch + dirpath = Path("predictions") + write_image_strategy.write_filenames = ["file"] + write_image_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=prediction, + batch_indices=batch_indices, + batch=batch, # contains the last tile + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + dirpath=dirpath, + ) + + # assert write_func is called as expected + # cannot use `assert_called_once_with` because of numpy array + write_image_strategy.write_func.assert_called_once() + assert write_image_strategy.write_func.call_args.kwargs["file_path"] == Path( + "predictions/file.ext" + ) + np.testing.assert_array_equal( + write_image_strategy.write_func.call_args.kwargs["img"], prediction + ) From a7cfd5c845adc546fe8577a3cc0cd14540ac49ec Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 14:28:12 +0200 Subject: [PATCH 11/38] test: update prediction writer smoke tests - write_filenames attr --- .../test_prediction_writer_callback.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py index df6a71dc..314c3c7d 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py @@ -85,7 +85,9 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): ) # create prediction writer callback params - write_strategy = create_write_strategy(write_type="tiff", tiled=True) + write_strategy = create_write_strategy( + write_type="tiff", tiled=True, write_filenames=[file_name] + ) dirpath = tmp_path / "predictions" # create trainer @@ -120,6 +122,8 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): predicted = trainer.predict(model, datamodule=predict_data) predicted_images = convert_outputs(predicted, tiled=True) + # TODO: assert filenames reset after trainer.predict is called + # assert predicted file exists assert (dirpath / file_name).is_file() @@ -163,7 +167,9 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): ) # create prediction writer callback params - write_strategy = create_write_strategy(write_type="tiff", tiled=False) + write_strategy = create_write_strategy( + write_type="tiff", tiled=False, write_filenames=[file_name] + ) dirpath = tmp_path / "predictions" # create trainer @@ -202,7 +208,7 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): # open file save_data = tifffile.imread(dirpath / file_name) # save data has singleton channel axis - np.testing.assert_array_equal(save_data, predicted_images[0][0], verbose=True) + np.testing.assert_array_equal(save_data, predicted_images[0], verbose=True) def test_initialization(prediction_writer_callback, write_strategy, dirpath): From d5178499d785e63dda0c5de8372be09c589b23ff Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 18:23:11 +0200 Subject: [PATCH 12/38] feat: add write_filenames attribute to WriteStrategy Protocol --- .../callbacks/prediction_writer_callback/write_strategy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index 45235af0..3b0977ff 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -17,6 +17,8 @@ class WriteStrategy(Protocol): """Protocol for write strategy classes.""" + write_filenames: Optional[list[str]] + def write_batch( self, trainer: Trainer, @@ -439,8 +441,8 @@ def write_batch( raise TypeError("Prediction dataset is not `IterablePredDataset`.") # for i in range(prediction.shape[0]): - # prediction_image = prediction[0] - # sample_id = batch_idx * dl.batch_size + i + # prediction_image = prediction[0] + # sample_id = batch_idx * dl.batch_size + i file_name = self.write_filenames[self.current_file_index] file_path = (dirpath / file_name).with_suffix(self.write_extension) From 95750e3f420ed0503369400d1b92439faf0ba380 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 18:24:24 +0200 Subject: [PATCH 13/38] docs: attr doctstring --- .../callbacks/prediction_writer_callback/write_strategy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py index 3b0977ff..e3939011 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py @@ -18,6 +18,7 @@ class WriteStrategy(Protocol): """Protocol for write strategy classes.""" write_filenames: Optional[list[str]] + """Filenames to write to.""" def write_batch( self, From 927df498a3db33ceccf06bd81c60ecd202740bd9 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 18:29:25 +0200 Subject: [PATCH 14/38] test: write strategy reset --- .../test_cache_tiles_write_strategy.py | 16 ++++++++++++++++ .../test_write_image_write_strategy.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 086afac2..a8a06da2 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -313,3 +313,19 @@ def test_get_image_tiles(cache_tiles_strategy): assert len(image_tiles) == 9 assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) assert image_tile_infos == tile_infos[:9] + + +def test_reset(cache_tiles_strategy: CacheTiles): + """Test CacheTiles.reset works as expected""" + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + # don't include last tile + patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) + + cache_tiles_strategy.write_filenames = ["file"] + cache_tiles_strategy.current_file_index = 1 + cache_tiles_strategy.reset() + assert cache_tiles_strategy.write_filenames is None + assert cache_tiles_strategy.current_file_index == 0 + assert len(cache_tiles_strategy.tile_cache) == 0 + assert len(cache_tiles_strategy.tile_info_cache) == 0 diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index 98748dae..2f9ec426 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -100,3 +100,12 @@ def test_write_batch(write_image_strategy, ordered_array): np.testing.assert_array_equal( write_image_strategy.write_func.call_args.kwargs["img"], prediction ) + + +def test_reset(write_image_strategy: WriteImage): + """Test WriteImage.reset works as expected""" + write_image_strategy.write_filenames = ["file"] + write_image_strategy.current_file_index = 1 + write_image_strategy.reset() + assert write_image_strategy.write_filenames is None + assert write_image_strategy.current_file_index == 0 From 377c0cd8fa33eb79868eb26e193c6cc9a7b5da24 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 18:42:44 +0200 Subject: [PATCH 15/38] test: write_batch raises if write_filenames is None --- .../test_cache_tiles_write_strategy.py | 41 +++++++++++++++++++ .../test_write_image_write_strategy.py | 40 ++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index a8a06da2..b242172d 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -243,6 +243,47 @@ def test_write_batch_last_tile(cache_tiles_strategy): assert remaining_tile_info == cache_tiles_strategy.tile_info_cache[0] +def test_write_batch_raises(cache_tiles_strategy: CacheTiles): + """Test write batch raises a ValueError if the filenames have not been set.""" + # all tiles of 2 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=2) + + # simulate adding a batch that will contain the last tile + n_tiles = 8 + batch_size = 2 + patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) + next_batch = ( + np.concatenate(tiles[n_tiles : n_tiles + batch_size]), + tile_infos[n_tiles : n_tiles + batch_size], + ) + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterableTiledPredDataset) + dataloader_idx = 0 + trainer.predict_dataloaders = [Mock(spec=DataLoader)] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + with pytest.raises(ValueError): + assert cache_tiles_strategy.write_filenames is None + + # call write batch + dirpath = Path("predictions") + cache_tiles_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=next_batch, + batch_indices=Mock(), + batch=next_batch, # contains the last tile + batch_idx=3, + dataloader_idx=dataloader_idx, + dirpath=dirpath, + ) + + def test_have_last_tile_true(cache_tiles_strategy): """Test `CacheTiles._have_last_tile` returns true when there is a last tile.""" diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index 2f9ec426..c0ff91dd 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -102,6 +102,46 @@ def test_write_batch(write_image_strategy, ordered_array): ) +def test_write_batch_raises(write_image_strategy, ordered_array): + """Test write batch raises a ValueError if the filenames have not been set.""" + n_batches = 1 + + prediction = ordered_array((n_batches, 1, 8, 8)) + + batch = prediction + batch_indices = np.arange(n_batches) + batch_idx = 0 + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterablePredDataset) + dataloader_idx = 0 + mock_dl = Mock(spec=DataLoader) + mock_dl.batch_size = 1 + trainer.predict_dataloaders = [mock_dl] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + # call write batch + dirpath = Path("predictions") + + with pytest.raises(ValueError): + # Make sure write_filenames is None + assert write_image_strategy.write_filenames is None + write_image_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=prediction, + batch_indices=batch_indices, + batch=batch, # contains the last tile + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + dirpath=dirpath, + ) + + def test_reset(write_image_strategy: WriteImage): """Test WriteImage.reset works as expected""" write_image_strategy.write_filenames = ["file"] From 42371f8feba42ee2293b3cf2907174562fec2f07 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Mon, 23 Sep 2024 18:56:26 +0200 Subject: [PATCH 16/38] test: filenames reset after predict call in smoke --- .../test_prediction_writer_callback.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py index 314c3c7d..e8119111 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py @@ -3,7 +3,7 @@ import os from pathlib import Path from typing import Union -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import pytest @@ -88,6 +88,7 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): write_strategy = create_write_strategy( write_type="tiff", tiled=True, write_filenames=[file_name] ) + write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" # create trainer @@ -122,7 +123,9 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): predicted = trainer.predict(model, datamodule=predict_data) predicted_images = convert_outputs(predicted, tiled=True) - # TODO: assert filenames reset after trainer.predict is called + # filenames reset after predictions called + write_strategy.reset.assert_called_once() + assert write_strategy.write_filenames is None # assert predicted file exists assert (dirpath / file_name).is_file() @@ -170,6 +173,7 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): write_strategy = create_write_strategy( write_type="tiff", tiled=False, write_filenames=[file_name] ) + write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" # create trainer @@ -202,6 +206,10 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): predicted = trainer.predict(model, datamodule=predict_data) predicted_images = convert_outputs(predicted, tiled=False) + # filenames reset after predictions called + write_strategy.reset.assert_called_once() + assert write_strategy.write_filenames is None + # assert predicted file exists assert (dirpath / file_name).is_file() From 362412385d299b4bf51d18fdc9a21d3c5d0cd67c Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 24 Sep 2024 15:59:50 +0200 Subject: [PATCH 17/38] refac: split write strategies into seperate modules --- .../write_strategy/__init__.py | 13 + .../cache_tiles.py} | 229 +----------------- .../write_strategy/protocol.py | 54 +++++ .../write_strategy/write_image.py | 141 +++++++++++ .../write_strategy/write_tiles_zarr.py | 55 +++++ .../test_cache_tiles_write_strategy.py | 2 +- 6 files changed, 268 insertions(+), 226 deletions(-) create mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py rename src/careamics/lightning/callbacks/prediction_writer_callback/{write_strategy.py => write_strategy/cache_tiles.py} (52%) create mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py create mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py create mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py new file mode 100644 index 00000000..344bf2df --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py @@ -0,0 +1,13 @@ +"""Write strategies for the prediciton writer callback.""" + +__all__ = [ + "WriteStrategy", + "CacheTiles", + "WriteImage", + "WriteTilesZarr", +] + +from .cache_tiles import CacheTiles +from .protocol import WriteStrategy +from .write_image import WriteImage +from .write_tiles_zarr import WriteTilesZarr diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py similarity index 52% rename from src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py rename to src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py index e3939011..6f0a6a3a 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py @@ -1,7 +1,7 @@ -"""Module containing different strategies for writing predictions.""" +"""Module containing the "cache tiles" write strategy.""" from pathlib import Path -from typing import Any, Optional, Protocol, Sequence, Union +from typing import Any, Optional, Sequence, Union import numpy as np from numpy.typing import NDArray @@ -9,57 +9,11 @@ from torch.utils.data import DataLoader from careamics.config.tile_information import TileInformation -from careamics.dataset import IterablePredDataset, IterableTiledPredDataset +from careamics.dataset import IterableTiledPredDataset from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single - -class WriteStrategy(Protocol): - """Protocol for write strategy classes.""" - - write_filenames: Optional[list[str]] - """Filenames to write to.""" - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: Any, # TODO: change to expected type - batch_indices: Optional[Sequence[int]], - batch: Any, # TODO: change to expected type - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - WriteStrategy subclasses must contain this function to write a batch. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - """ - - def reset(self) -> None: - """ - Reset internal attributes of a `WriteStrategy` instance. - - This is to prevent bugs if a `WriteStrategy` instance is used twice. - """ +from .protocol import WriteStrategy class CacheTiles(WriteStrategy): @@ -283,178 +237,3 @@ def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: tiles = self.tile_cache[: index + 1] tile_infos = self.tile_info_cache[: index + 1] return tiles, tile_infos - - -class WriteTilesZarr(WriteStrategy): - """Strategy to write tiles to Zarr file.""" - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: Any, - batch_indices: Optional[Sequence[int]], - batch: Any, - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - Write tiles to zarr file. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - - Raises - ------ - NotImplementedError - """ - raise NotImplementedError - - def reset(self) -> None: - """Reset internal attributes.""" - raise NotImplementedError - - -class WriteImage(WriteStrategy): - """ - A strategy for writing image predictions (i.e. un-tiled predictions). - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - - Attributes - ---------- - write_func : WriteFunc - Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - current_file_index : int - Index of current file, increments every time a file is written. - """ - - def __init__( - self, - write_func: WriteFunc, - write_filenames: Optional[list[str]], - write_extension: str, - write_func_kwargs: dict[str, Any], - ) -> None: - """ - A strategy for writing image predictions (i.e. un-tiled predictions). - - Parameters - ---------- - write_func : WriteFunc - Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. - write_extension : str - Extension added to prediction file paths. - write_func_kwargs : dict of {str: Any} - Extra kwargs to pass to `write_func`. - """ - super().__init__() - - self.write_func: WriteFunc = write_func - self.write_filenames: Optional[list[str]] = write_filenames - self.write_extension: str = write_extension - self.write_func_kwargs: dict[str, Any] = write_func_kwargs - - self.current_file_index: int = 0 - - def write_batch( - self, - trainer: Trainer, - pl_module: LightningModule, - prediction: NDArray, - batch_indices: Optional[Sequence[int]], - batch: NDArray, - batch_idx: int, - dataloader_idx: int, - dirpath: Path, - ) -> None: - """ - Save full images. - - Parameters - ---------- - trainer : Trainer - PyTorch Lightning Trainer. - pl_module : LightningModule - PyTorch Lightning LightningModule. - prediction : Any - Predictions on `batch`. - batch_indices : sequence of int - Indices identifying the samples in the batch. - batch : Any - Input batch. - batch_idx : int - Batch index. - dataloader_idx : int - Dataloader index. - dirpath : Path - Path to directory to save predictions to. - - Raises - ------ - TypeError - If trainer prediction dataset is not `IterablePredDataset`. - ValueError - If `write_filenames` attribute is `None`. - """ - if self.write_filenames is None: - raise ValueError("`write_filenames` attribute has not been set.") - - dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders - dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls - ds: IterablePredDataset = dl.dataset - if not isinstance(ds, IterablePredDataset): - # TODO: change to warning - raise TypeError("Prediction dataset is not `IterablePredDataset`.") - - # for i in range(prediction.shape[0]): - # prediction_image = prediction[0] - # sample_id = batch_idx * dl.batch_size + i - - file_name = self.write_filenames[self.current_file_index] - file_path = (dirpath / file_name).with_suffix(self.write_extension) - self.write_func(file_path=file_path, img=prediction, **self.write_func_kwargs) - self.current_file_index += 1 - - def reset(self) -> None: - """ - Reset internal attributes. - - Resets the `write_filenames` and `current_file_index` attributes. - """ - self.write_filenames = None - self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py new file mode 100644 index 00000000..204e391e --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py @@ -0,0 +1,54 @@ +"""Module containing the protocol that defines the WriteStrategy interface.""" + +from pathlib import Path +from typing import Any, Optional, Protocol, Sequence + +from pytorch_lightning import LightningModule, Trainer + + +class WriteStrategy(Protocol): + """Protocol for write strategy classes.""" + + write_filenames: Optional[list[str]] + """Filenames to write to.""" + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: Any, # TODO: change to expected type + batch_indices: Optional[Sequence[int]], + batch: Any, # TODO: change to expected type + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + WriteStrategy subclasses must contain this function to write a batch. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + """ + + def reset(self) -> None: + """ + Reset internal attributes of a `WriteStrategy` instance. + + This is to prevent bugs if a `WriteStrategy` instance is used twice. + """ diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py new file mode 100644 index 00000000..618b5856 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -0,0 +1,141 @@ +"""Module containing write strategy for when batches contain full images.""" + +from pathlib import Path +from typing import Any, Optional, Sequence, Union + +from numpy.typing import NDArray +from pytorch_lightning import LightningModule, Trainer +from torch.utils.data import DataLoader + +from careamics.dataset import IterablePredDataset +from careamics.file_io import WriteFunc + +from .protocol import WriteStrategy + + +class WriteImage(WriteStrategy): + """ + A strategy for writing image predictions (i.e. un-tiled predictions). + + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + + Attributes + ---------- + write_func : WriteFunc + Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + current_file_index : int + Index of current file, increments every time a file is written. + """ + + def __init__( + self, + write_func: WriteFunc, + write_filenames: Optional[list[str]], + write_extension: str, + write_func_kwargs: dict[str, Any], + ) -> None: + """ + A strategy for writing image predictions (i.e. un-tiled predictions). + + Parameters + ---------- + write_func : WriteFunc + Function used to save predictions. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + write_extension : str + Extension added to prediction file paths. + write_func_kwargs : dict of {str: Any} + Extra kwargs to pass to `write_func`. + """ + super().__init__() + + self.write_func: WriteFunc = write_func + self.write_filenames: Optional[list[str]] = write_filenames + self.write_extension: str = write_extension + self.write_func_kwargs: dict[str, Any] = write_func_kwargs + + self.current_file_index: int = 0 + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: NDArray, + batch_indices: Optional[Sequence[int]], + batch: NDArray, + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + Save full images. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + + Raises + ------ + TypeError + If trainer prediction dataset is not `IterablePredDataset`. + ValueError + If `write_filenames` attribute is `None`. + """ + if self.write_filenames is None: + raise ValueError("`write_filenames` attribute has not been set.") + + dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders + dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls + ds: IterablePredDataset = dl.dataset + if not isinstance(ds, IterablePredDataset): + # TODO: change to warning + raise TypeError("Prediction dataset is not `IterablePredDataset`.") + + # for i in range(prediction.shape[0]): + # prediction_image = prediction[0] + # sample_id = batch_idx * dl.batch_size + i + + file_name = self.write_filenames[self.current_file_index] + file_path = (dirpath / file_name).with_suffix(self.write_extension) + self.write_func(file_path=file_path, img=prediction, **self.write_func_kwargs) + self.current_file_index += 1 + + def reset(self) -> None: + """ + Reset internal attributes. + + Resets the `write_filenames` and `current_file_index` attributes. + """ + self.write_filenames = None + self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py new file mode 100644 index 00000000..2dbd11da --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py @@ -0,0 +1,55 @@ +"""Module containing a write strategy for writing tiles directly to zarr datasets.""" + +from pathlib import Path +from typing import Any, Optional, Sequence + +from pytorch_lightning import LightningModule, Trainer + +from .protocol import WriteStrategy + + +class WriteTilesZarr(WriteStrategy): + """Strategy to write tiles to Zarr file.""" + + def write_batch( + self, + trainer: Trainer, + pl_module: LightningModule, + prediction: Any, + batch_indices: Optional[Sequence[int]], + batch: Any, + batch_idx: int, + dataloader_idx: int, + dirpath: Path, + ) -> None: + """ + Write tiles to zarr file. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning Trainer. + pl_module : LightningModule + PyTorch Lightning LightningModule. + prediction : Any + Predictions on `batch`. + batch_indices : sequence of int + Indices identifying the samples in the batch. + batch : Any + Input batch. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index. + dirpath : Path + Path to directory to save predictions to. + + Raises + ------ + NotImplementedError + """ + raise NotImplementedError + + def reset(self) -> None: + """Reset internal attributes.""" + raise NotImplementedError diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index b242172d..93ea0838 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -207,7 +207,7 @@ def test_write_batch_last_tile(cache_tiles_strategy): # This is a unit test to isolate functionality of `write_batch.` with patch( "careamics.lightning.callbacks.prediction_writer_callback.write_strategy" - + ".stitch_prediction_single", + + ".cache_tiles.stitch_prediction_single", ) as mock_stitch_prediction_single: prediction_image = [Mock()] From b4fecacae71b161ba13758804702edd47eac5f7a Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 24 Sep 2024 16:44:42 +0200 Subject: [PATCH 18/38] refac: remove inheritance from protocol --- .../prediction_writer_callback/write_strategy/cache_tiles.py | 4 +--- .../prediction_writer_callback/write_strategy/write_image.py | 4 +--- .../write_strategy/write_tiles_zarr.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py index 6f0a6a3a..92ec01e2 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py @@ -13,10 +13,8 @@ from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single -from .protocol import WriteStrategy - -class CacheTiles(WriteStrategy): +class CacheTiles: """ A write strategy that will cache tiles. diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 618b5856..a19ac65c 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -10,10 +10,8 @@ from careamics.dataset import IterablePredDataset from careamics.file_io import WriteFunc -from .protocol import WriteStrategy - -class WriteImage(WriteStrategy): +class WriteImage: """ A strategy for writing image predictions (i.e. un-tiled predictions). diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py index 2dbd11da..d3a6dcc3 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles_zarr.py @@ -5,10 +5,8 @@ from pytorch_lightning import LightningModule, Trainer -from .protocol import WriteStrategy - -class WriteTilesZarr(WriteStrategy): +class WriteTilesZarr: """Strategy to write tiles to Zarr file.""" def write_batch( From e062295ad79dd885478f348579c52ef32686b651 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 24 Sep 2024 16:49:16 +0200 Subject: [PATCH 19/38] refac: rename CacheTiles to WriteTiles --- .../callbacks/prediction_writer_callback/__init__.py | 4 ++-- .../write_strategy/__init__.py | 4 ++-- .../{cache_tiles.py => write_tiles.py} | 2 +- .../write_strategy_factory.py | 4 ++-- .../test_cache_tiles_write_strategy.py | 12 ++++++------ .../test_write_strategy_factory.py | 6 +++--- 6 files changed, 16 insertions(+), 16 deletions(-) rename src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/{cache_tiles.py => write_tiles.py} (99%) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py b/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py index ffbc1209..331b869b 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py @@ -5,14 +5,14 @@ "create_write_strategy", "WriteStrategy", "WriteImage", - "CacheTiles", + "WriteTiles", "WriteTilesZarr", "select_write_extension", "select_write_func", ] from .prediction_writer_callback import PredictionWriterCallback -from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr +from .write_strategy import WriteImage, WriteStrategy, WriteTiles, WriteTilesZarr from .write_strategy_factory import ( create_write_strategy, select_write_extension, diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py index 344bf2df..43feb7a9 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/__init__.py @@ -2,12 +2,12 @@ __all__ = [ "WriteStrategy", - "CacheTiles", + "WriteTiles", "WriteImage", "WriteTilesZarr", ] -from .cache_tiles import CacheTiles from .protocol import WriteStrategy from .write_image import WriteImage +from .write_tiles import WriteTiles from .write_tiles_zarr import WriteTilesZarr diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py similarity index 99% rename from src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py rename to src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 92ec01e2..00c73943 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/cache_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -14,7 +14,7 @@ from careamics.prediction_utils import stitch_prediction_single -class CacheTiles: +class WriteTiles: """ A write strategy that will cache tiles. diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py index dd35016b..a7427a4d 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py @@ -5,7 +5,7 @@ from careamics.config.support import SupportedData from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func -from .write_strategy import CacheTiles, WriteImage, WriteStrategy +from .write_strategy import WriteImage, WriteStrategy, WriteTiles def create_write_strategy( @@ -130,7 +130,7 @@ def _create_tiled_write_strategy( write_extension = select_write_extension( write_type=write_type, write_extension=write_extension ) - return CacheTiles( + return WriteTiles( write_func=write_func, write_filenames=write_filenames, write_extension=write_extension, diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 93ea0838..640b065e 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -14,7 +14,7 @@ from careamics.dataset.tiling import extract_tiles from careamics.file_io import WriteFunc from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( - CacheTiles, + WriteTiles, ) @@ -52,7 +52,7 @@ def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: def patch_tile_cache( - strategy: CacheTiles, tiles: list[NDArray], tile_infos: list[TileInformation] + strategy: WriteTiles, tiles: list[NDArray], tile_infos: list[TileInformation] ) -> None: """ Patch simulated tile cache into `strategy`. @@ -77,7 +77,7 @@ def write_func(): @pytest.fixture -def cache_tiles_strategy(write_func) -> CacheTiles: +def cache_tiles_strategy(write_func) -> WriteTiles: """ Initialized `CacheTiles` class. @@ -93,7 +93,7 @@ def cache_tiles_strategy(write_func) -> CacheTiles: """ write_extension = ".ext" write_func_kwargs = {} - return CacheTiles( + return WriteTiles( write_func=write_func, write_filenames=None, write_extension=write_extension, @@ -243,7 +243,7 @@ def test_write_batch_last_tile(cache_tiles_strategy): assert remaining_tile_info == cache_tiles_strategy.tile_info_cache[0] -def test_write_batch_raises(cache_tiles_strategy: CacheTiles): +def test_write_batch_raises(cache_tiles_strategy: WriteTiles): """Test write batch raises a ValueError if the filenames have not been set.""" # all tiles of 2 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=2) @@ -356,7 +356,7 @@ def test_get_image_tiles(cache_tiles_strategy): assert image_tile_infos == tile_infos[:9] -def test_reset(cache_tiles_strategy: CacheTiles): +def test_reset(cache_tiles_strategy: WriteTiles): """Test CacheTiles.reset works as expected""" # all tiles of 1 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=1) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py index f00cffde..72fc6e35 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py @@ -8,8 +8,8 @@ from careamics.file_io.write import write_tiff from careamics.lightning.callbacks.prediction_writer_callback import ( - CacheTiles, WriteImage, + WriteTiles, create_write_strategy, select_write_extension, select_write_func, @@ -25,7 +25,7 @@ def test_create_write_strategy_tiff_tiled(): """Test write strategy creation for tiled tiff.""" write_strategy = create_write_strategy(write_type="tiff", tiled=True) - assert isinstance(write_strategy, CacheTiles) + assert isinstance(write_strategy, WriteTiles) assert write_strategy.write_func is write_tiff assert write_strategy.write_extension == ".tiff" assert write_strategy.write_func_kwargs == {} @@ -46,7 +46,7 @@ def test_create_write_strategy_custom_tiled(): write_strategy = create_write_strategy( write_type="custom", tiled=True, write_func=save_numpy, write_extension=".npy" ) - assert isinstance(write_strategy, CacheTiles) + assert isinstance(write_strategy, WriteTiles) assert write_strategy.write_func is save_numpy assert write_strategy.write_extension == ".npy" assert write_strategy.write_func_kwargs == {} From b6fa4efcdabed3e4fb0a31b6aaaea6a4daaeaeed Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 24 Sep 2024 17:20:48 +0200 Subject: [PATCH 20/38] refac: extract tile cache to seperate class --- .../write_strategy/utils.py | 40 +++++++++++ .../write_strategy/write_tiles.py | 71 +++---------------- 2 files changed, 49 insertions(+), 62 deletions(-) create mode 100644 src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py new file mode 100644 index 00000000..212416f9 --- /dev/null +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py @@ -0,0 +1,40 @@ + +import numpy as np +from numpy.typing import NDArray + +from careamics.config.tile_information import TileInformation + +class TileCache: + """ + Cache tiles; logic to pop tiles when tiles from a full image have been stored. + """ + + def __init__(self): + self.array_cache: list[NDArray] = [] + self.tile_info_cache: list[TileInformation] = [] + + def add(self, item: tuple[NDArray, list[TileInformation]]): + self.array_cache.extend(np.split(item[0]), item[0].shape[0]) + self.tile_info_cache.extend(item[1]) + + def has_last_tile(self) -> bool: + return any(tile_info.last_tile for tile_info in self.tile_info_cache) + + def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: + is_last_tile = [tile_info.last_tile for tile_info in self.tile_info_cache] + if not any(is_last_tile): + raise ValueError("No last tile in cache.") + + index = np.where(is_last_tile)[0][0] + # get image tiles + tiles = self.array_cache[: index + 1] + tile_infos = self.tile_info_cache[: index + 1] + # remove image tiles from list + self.array_cache = self.array_cache[index + 1 :] + self.tile_info_cache = self.tile_info_cache[index + 1 :] + + return tiles, tile_infos + + def reset(self): + self.array_cache = [] + self.tile_info_cache = [] \ No newline at end of file diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 00c73943..161b09bd 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -13,6 +13,7 @@ from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single +from .utils import TileCache class WriteTiles: """ @@ -82,8 +83,7 @@ def __init__( self.write_func_kwargs: dict[str, Any] = write_func_kwargs # where tiles will be cached until a whole image has been predicted - self.tile_cache: list[NDArray] = [] - self.tile_info_cache: list[TileInformation] = [] + self.tile_cache = TileCache() self.current_file_index = 0 @@ -139,6 +139,8 @@ def write_batch( """ if self.write_filenames is None: raise ValueError("`write_filenames` attribute has not been set.") + + # TODO: move dataset type check somewhere else dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dataloader: DataLoader = ( dataloaders[dataloader_idx] @@ -149,16 +151,13 @@ def write_batch( if not isinstance(dataset, IterableTiledPredDataset): raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.") - # cache tiles (batches are split into single samples) - self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0])) - self.tile_info_cache.extend(prediction[1]) + self.tile_cache.add(prediction) # save stitched prediction - if self._has_last_tile(): + if self.tile_cache.has_last_tile(): # get image tiles and remove them from the cache - tiles, tile_infos = self._get_image_tiles() - self._clear_cache() + tiles, tile_infos = self.tile_cache.pop_image_tiles() # stitch prediction prediction_image = stitch_prediction_single( @@ -177,61 +176,9 @@ def reset(self) -> None: """ Reset the internal attributes. - Attributes reset are: `write_filenames`, `tile_cache`, `tile_info_cache` and - `current_file_index`. + Attributes reset are: `write_filenames`, `tile_cache`, and `current_file_index`. """ self.write_filenames = None - self.tile_cache = [] - self.tile_info_cache = [] self.current_file_index = 0 + self.tile_cache.reset() - def _has_last_tile(self) -> bool: - """ - Whether a last tile is contained in the cached tiles. - - Returns - ------- - bool - Whether a last tile is contained in the cached tiles. - """ - return any(self.last_tiles) - - def _clear_cache(self) -> None: - """Remove the tiles in the cache up to the first last tile.""" - index = self._last_tile_index() - self.tile_cache = self.tile_cache[index + 1 :] - self.tile_info_cache = self.tile_info_cache[index + 1 :] - - def _last_tile_index(self) -> int: - """ - Find the index of the last tile in the tile cache. - - Returns - ------- - int - Index of last tile. - - Raises - ------ - ValueError - If there is no last tile in the tile cache. - """ - last_tiles = self.last_tiles - if not any(last_tiles): - raise ValueError("No last tile in the tile cache.") - index = np.where(last_tiles)[0][0] - return index - - def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: - """ - Get the tiles corresponding to a single image. - - Returns - ------- - tuple of (list of numpy.ndarray, list of TileInformation) - Tiles and tile information to stitch together a full image. - """ - index = self._last_tile_index() - tiles = self.tile_cache[: index + 1] - tile_infos = self.tile_info_cache[: index + 1] - return tiles, tile_infos From 88e8e9d37c8f2ceee85f575e6b6b1d80f617c37a Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 25 Sep 2024 12:16:21 +0200 Subject: [PATCH 21/38] feat: add SampleCache class --- .../write_strategy/utils.py | 53 +++++++++++++++++-- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py index 212416f9..d86fa86a 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py @@ -1,9 +1,11 @@ +from typing import Optional import numpy as np from numpy.typing import NDArray from careamics.config.tile_information import TileInformation + class TileCache: """ Cache tiles; logic to pop tiles when tiles from a full image have been stored. @@ -14,17 +16,17 @@ def __init__(self): self.tile_info_cache: list[TileInformation] = [] def add(self, item: tuple[NDArray, list[TileInformation]]): - self.array_cache.extend(np.split(item[0]), item[0].shape[0]) + self.array_cache.extend(np.split(item[0], item[0].shape[0])) self.tile_info_cache.extend(item[1]) def has_last_tile(self) -> bool: return any(tile_info.last_tile for tile_info in self.tile_info_cache) - + def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: is_last_tile = [tile_info.last_tile for tile_info in self.tile_info_cache] if not any(is_last_tile): raise ValueError("No last tile in cache.") - + index = np.where(is_last_tile)[0][0] # get image tiles tiles = self.array_cache[: index + 1] @@ -34,7 +36,48 @@ def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: self.tile_info_cache = self.tile_info_cache[index + 1 :] return tiles, tile_infos - + def reset(self): self.array_cache = [] - self.tile_info_cache = [] \ No newline at end of file + self.tile_info_cache = [] + + +class SampleCache: + + def __init__(self, n_samples_per_file: list[int]): + + self.n_samples_per_file: list[int] = n_samples_per_file + self.n_samples_iter = iter(self.n_samples_per_file) + self.n_samples: Optional[int] = next(self.n_samples_iter) + self.sample_cache: list[NDArray] = [] + + def add(self, item: NDArray): + self.sample_cache.extend(np.split(item, item.shape[0])) + + def has_all_file_samples(self) -> bool: + if self.n_samples is None: + raise ValueError( + "Number of samples for current file is unknown. Reached the end of the " + "given list of samples per file." + ) + return len(self.sample_cache) >= self.n_samples + + def pop_file_samples(self) -> list[NDArray]: + if not self.has_all_file_samples(): + raise ValueError( + "Do not have all the samples belonging to the current file." + ) + + samples = self.sample_cache[: self.n_samples] + self.sample_cache = self.sample_cache[self.n_samples :] + + try: + self.n_samples = next(self.n_samples_iter) + except StopIteration: + self.n_samples = None + + return samples + + def reset(self): + self.n_samples_iter = iter(self.n_samples_per_file) + self.sample_cache: list[NDArray] = [] From 39182d61d0be826fdd5cdd71d253c930c839e74e Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 25 Sep 2024 13:58:15 +0200 Subject: [PATCH 22/38] feat: add sample caching to write strategies --- .../write_strategy/utils.py | 6 +- .../write_strategy/write_image.py | 20 +++++- .../write_strategy/write_tiles.py | 61 ++++++++++--------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py index d86fa86a..b7603333 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py @@ -44,7 +44,7 @@ def reset(self): class SampleCache: - def __init__(self, n_samples_per_file: list[int]): + def __init__(self, n_samples_per_file: Optional[list[int]]): self.n_samples_per_file: list[int] = n_samples_per_file self.n_samples_iter = iter(self.n_samples_per_file) @@ -58,7 +58,7 @@ def has_all_file_samples(self) -> bool: if self.n_samples is None: raise ValueError( "Number of samples for current file is unknown. Reached the end of the " - "given list of samples per file." + "given list of samples per file, or a list has not been given." ) return len(self.sample_cache) >= self.n_samples @@ -77,7 +77,7 @@ def pop_file_samples(self) -> list[NDArray]: self.n_samples = None return samples - + def reset(self): self.n_samples_iter = iter(self.n_samples_per_file) self.sample_cache: list[NDArray] = [] diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index a19ac65c..f6bbb123 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Optional, Sequence, Union +import numpy as np from numpy.typing import NDArray from pytorch_lightning import LightningModule, Trainer from torch.utils.data import DataLoader @@ -10,6 +11,8 @@ from careamics.dataset import IterablePredDataset from careamics.file_io import WriteFunc +from .utils import SampleCache + class WriteImage: """ @@ -46,6 +49,7 @@ def __init__( write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], + n_samples_per_file: Optional[list[int]], ) -> None: """ A strategy for writing image predictions (i.e. un-tiled predictions). @@ -68,6 +72,9 @@ def __init__( self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs + # where samples are stored until a whole file has been predicted + self.sample_cache = SampleCache(n_samples_per_file) + self.current_file_index: int = 0 def write_batch( @@ -124,9 +131,20 @@ def write_batch( # prediction_image = prediction[0] # sample_id = batch_idx * dl.batch_size + i + self.sample_cache.add(prediction) + # early return + if not self.sample_cache.has_all_file_samples(): + return + + # if has all samples in file + samples = self.sample_cache.pop_file_samples() + + # combine + data = np.concatenate(samples) + file_name = self.write_filenames[self.current_file_index] file_path = (dirpath / file_name).with_suffix(self.write_extension) - self.write_func(file_path=file_path, img=prediction, **self.write_func_kwargs) + self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) self.current_file_index += 1 def reset(self) -> None: diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 161b09bd..9d35913f 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -13,7 +13,8 @@ from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single -from .utils import TileCache +from .utils import SampleCache, TileCache + class WriteTiles: """ @@ -57,6 +58,7 @@ def __init__( write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], + n_samples_per_file: Optional[list[int]], ) -> None: """ A write strategy that will cache tiles. @@ -84,21 +86,11 @@ def __init__( # where tiles will be cached until a whole image has been predicted self.tile_cache = TileCache() + # where samples are stored until a whole file has been predicted + self.sample_cache = SampleCache(n_samples_per_file) self.current_file_index = 0 - @property - def last_tiles(self) -> list[bool]: - """ - List of bool to determine whether each tile in the cache is the last tile. - - Returns - ------- - list of bool - Whether each tile in the tile cache is the last tile. - """ - return [tile_info.last_tile for tile_info in self.tile_info_cache] - def write_batch( self, trainer: Trainer, @@ -139,7 +131,7 @@ def write_batch( """ if self.write_filenames is None: raise ValueError("`write_filenames` attribute has not been set.") - + # TODO: move dataset type check somewhere else dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dataloader: DataLoader = ( @@ -153,24 +145,33 @@ def write_batch( self.tile_cache.add(prediction) - # save stitched prediction - if self.tile_cache.has_last_tile(): + # early return + if not self.tile_cache.has_last_tile(): + return - # get image tiles and remove them from the cache - tiles, tile_infos = self.tile_cache.pop_image_tiles() + # if has last tile + tiles, tile_infos = self.tile_cache.pop_image_tiles() - # stitch prediction - prediction_image = stitch_prediction_single( - tiles=tiles, tile_infos=tile_infos - ) + # stitch prediction + prediction_image = stitch_prediction_single(tiles=tiles, tile_infos=tile_infos) - # write prediction - file_name = self.write_filenames[self.current_file_index] - file_path = (dirpath / file_name).with_suffix(self.write_extension) - self.write_func( - file_path=file_path, img=prediction_image[0], **self.write_func_kwargs - ) - self.current_file_index += 1 + self.sample_cache.add(prediction_image) + + # early return + if not self.sample_cache.has_all_file_samples(): + return + + # if has all samples in file + samples = self.sample_cache.pop_file_samples() + + # combine + data = np.concatenate(samples) + + # write prediction + file_name = self.write_filenames[self.current_file_index] + file_path = (dirpath / file_name).with_suffix(self.write_extension) + self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) + self.current_file_index += 1 def reset(self) -> None: """ @@ -181,4 +182,4 @@ def reset(self) -> None: self.write_filenames = None self.current_file_index = 0 self.tile_cache.reset() - + self.sample_cache.reset() From 67966b335a2e8a72ae4587c3b45831238bd104c0 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 25 Sep 2024 15:00:12 +0200 Subject: [PATCH 23/38] feat(write strategies): method to set file data --- .../write_strategy/write_image.py | 14 +++++++++++++- .../write_strategy/write_tiles.py | 14 +++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index f6bbb123..5b104678 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -46,9 +46,9 @@ class WriteImage: def __init__( self, write_func: WriteFunc, - write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], n_samples_per_file: Optional[list[int]], ) -> None: """ @@ -77,6 +77,18 @@ def __init__( self.current_file_index: int = 0 + if write_filenames is not None and n_samples_per_file is not None: + self.set_file_data(write_filenames, n_samples_per_file) + + def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + if len(write_filenames) != len(n_samples_per_file): + raise ValueError( + "List of filename and list of number of samples per file are not of " + "equal length." + ) + self.write_filenames = write_filenames + self.sample_cache.n_samples_per_file = n_samples_per_file + def write_batch( self, trainer: Trainer, diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 9d35913f..d7c37986 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -55,9 +55,9 @@ class WriteTiles: def __init__( self, write_func: WriteFunc, - write_filenames: Optional[list[str]], write_extension: str, write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], n_samples_per_file: Optional[list[int]], ) -> None: """ @@ -91,6 +91,18 @@ def __init__( self.current_file_index = 0 + if write_filenames is not None and n_samples_per_file is not None: + self.set_file_data(write_filenames, n_samples_per_file) + + def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + if len(write_filenames) != len(n_samples_per_file): + raise ValueError( + "List of filename and list of number of samples per file are not of " + "equal length." + ) + self.write_filenames = write_filenames + self.sample_cache.n_samples_per_file = n_samples_per_file + def write_batch( self, trainer: Trainer, From bb86d1fbfdb279ebd7cb4c36188bc2c76ada5c87 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 25 Sep 2024 15:10:33 +0200 Subject: [PATCH 24/38] feat(write strategie protocol): method to set file data --- .../prediction_writer_callback/write_strategy/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py index 204e391e..cb5dcd1b 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py @@ -9,9 +9,6 @@ class WriteStrategy(Protocol): """Protocol for write strategy classes.""" - write_filenames: Optional[list[str]] - """Filenames to write to.""" - def write_batch( self, trainer: Trainer, @@ -46,6 +43,9 @@ def write_batch( Path to directory to save predictions to. """ + def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]) -> None: + ... + def reset(self) -> None: """ Reset internal attributes of a `WriteStrategy` instance. From ee9be04f61fc563f5a7a44dae4dac0a913e27408 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 25 Sep 2024 15:44:59 +0200 Subject: [PATCH 25/38] feat: replace filename indexing with iterator --- .../write_strategy/write_image.py | 25 +++++++++---------- .../write_strategy/write_tiles.py | 21 ++++++++-------- .../test_cache_tiles_write_strategy.py | 6 ++--- .../test_write_image_write_strategy.py | 4 +-- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 5b104678..077d1b81 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -1,7 +1,7 @@ """Module containing write strategy for when batches contain full images.""" from pathlib import Path -from typing import Any, Optional, Sequence, Union +from typing import Any, Iterator, Optional, Sequence, Union import numpy as np from numpy.typing import NDArray @@ -68,15 +68,16 @@ def __init__( super().__init__() self.write_func: WriteFunc = write_func - self.write_filenames: Optional[list[str]] = write_filenames self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs # where samples are stored until a whole file has been predicted self.sample_cache = SampleCache(n_samples_per_file) - self.current_file_index: int = 0 - + self._write_filenames: Optional[list[str]] = write_filenames + self.filename_iter: Optional[Iterator[str]] = ( + iter(write_filenames) if write_filenames is not None else None + ) if write_filenames is not None and n_samples_per_file is not None: self.set_file_data(write_filenames, n_samples_per_file) @@ -86,7 +87,8 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int "List of filename and list of number of samples per file are not of " "equal length." ) - self.write_filenames = write_filenames + self._write_filenames = write_filenames + self.filename_iter = iter(write_filenames) self.sample_cache.n_samples_per_file = n_samples_per_file def write_batch( @@ -129,7 +131,7 @@ def write_batch( ValueError If `write_filenames` attribute is `None`. """ - if self.write_filenames is None: + if self._write_filenames is None: raise ValueError("`write_filenames` attribute has not been set.") dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders @@ -139,10 +141,6 @@ def write_batch( # TODO: change to warning raise TypeError("Prediction dataset is not `IterablePredDataset`.") - # for i in range(prediction.shape[0]): - # prediction_image = prediction[0] - # sample_id = batch_idx * dl.batch_size + i - self.sample_cache.add(prediction) # early return if not self.sample_cache.has_all_file_samples(): @@ -154,10 +152,10 @@ def write_batch( # combine data = np.concatenate(samples) - file_name = self.write_filenames[self.current_file_index] + # write prediction + file_name = next(self.filename_iter) file_path = (dirpath / file_name).with_suffix(self.write_extension) self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) - self.current_file_index += 1 def reset(self) -> None: """ @@ -165,5 +163,6 @@ def reset(self) -> None: Resets the `write_filenames` and `current_file_index` attributes. """ - self.write_filenames = None + self._write_filenames = None + self.filename_iter = None self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index d7c37986..3bae5d3b 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -1,7 +1,7 @@ """Module containing the "cache tiles" write strategy.""" from pathlib import Path -from typing import Any, Optional, Sequence, Union +from typing import Any, Iterator, Optional, Sequence, Union import numpy as np from numpy.typing import NDArray @@ -80,7 +80,6 @@ def __init__( super().__init__() self.write_func: WriteFunc = write_func - self.write_filenames: Optional[list[str]] = write_filenames self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs @@ -89,8 +88,10 @@ def __init__( # where samples are stored until a whole file has been predicted self.sample_cache = SampleCache(n_samples_per_file) - self.current_file_index = 0 - + self._write_filenames: Optional[list[str]] = write_filenames + self.filename_iter: Optional[Iterator[str]] = ( + iter(write_filenames) if write_filenames is not None else None + ) if write_filenames is not None and n_samples_per_file is not None: self.set_file_data(write_filenames, n_samples_per_file) @@ -100,7 +101,8 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int "List of filename and list of number of samples per file are not of " "equal length." ) - self.write_filenames = write_filenames + self._write_filenames = write_filenames + self.filename_iter = iter(write_filenames) self.sample_cache.n_samples_per_file = n_samples_per_file def write_batch( @@ -141,7 +143,7 @@ def write_batch( ValueError If `write_filenames` attribute is `None`. """ - if self.write_filenames is None: + if self._write_filenames is None: raise ValueError("`write_filenames` attribute has not been set.") # TODO: move dataset type check somewhere else @@ -180,10 +182,9 @@ def write_batch( data = np.concatenate(samples) # write prediction - file_name = self.write_filenames[self.current_file_index] + file_name = next(self.filename_iter) file_path = (dirpath / file_name).with_suffix(self.write_extension) self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) - self.current_file_index += 1 def reset(self) -> None: """ @@ -191,7 +192,7 @@ def reset(self) -> None: Attributes reset are: `write_filenames`, `tile_cache`, and `current_file_index`. """ - self.write_filenames = None - self.current_file_index = 0 + self._write_filenames = None + self.filename_iter = None self.tile_cache.reset() self.sample_cache.reset() diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 640b065e..222fdffa 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -268,7 +268,7 @@ def test_write_batch_raises(cache_tiles_strategy: WriteTiles): trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset with pytest.raises(ValueError): - assert cache_tiles_strategy.write_filenames is None + assert cache_tiles_strategy._write_filenames is None # call write batch dirpath = Path("predictions") @@ -363,10 +363,10 @@ def test_reset(cache_tiles_strategy: WriteTiles): # don't include last tile patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - cache_tiles_strategy.write_filenames = ["file"] + cache_tiles_strategy._write_filenames = ["file"] cache_tiles_strategy.current_file_index = 1 cache_tiles_strategy.reset() - assert cache_tiles_strategy.write_filenames is None + assert cache_tiles_strategy._write_filenames is None assert cache_tiles_strategy.current_file_index == 0 assert len(cache_tiles_strategy.tile_cache) == 0 assert len(cache_tiles_strategy.tile_info_cache) == 0 diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index c0ff91dd..04611863 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -144,8 +144,8 @@ def test_write_batch_raises(write_image_strategy, ordered_array): def test_reset(write_image_strategy: WriteImage): """Test WriteImage.reset works as expected""" - write_image_strategy.write_filenames = ["file"] + write_image_strategy._write_filenames = ["file"] write_image_strategy.current_file_index = 1 write_image_strategy.reset() - assert write_image_strategy.write_filenames is None + assert write_image_strategy._write_filenames is None assert write_image_strategy.current_file_index == 0 From 7de654d1d8c9a151fefcf0f82ceae6629ccfb236 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 12 Nov 2024 11:31:49 +0100 Subject: [PATCH 26/38] test(WriteImage strategy): fix tests for updated classes; fix: bugs --- .../write_strategy/write_image.py | 17 ++++++++---- .../prediction_writer_callback/conftest.py | 11 ++++++++ .../test_write_image_write_strategy.py | 27 ++++++++----------- 3 files changed, 34 insertions(+), 21 deletions(-) create mode 100644 tests/lightning/callbacks/prediction_writer_callback/conftest.py diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 077d1b81..b5c6a79a 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -64,6 +64,9 @@ def __init__( Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + n_samples_per_file : list of int + The number of samples in each file, (controls which samples will be + grouped together in each file). """ super().__init__() @@ -71,15 +74,19 @@ def __init__( self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs - # where samples are stored until a whole file has been predicted - self.sample_cache = SampleCache(n_samples_per_file) - self._write_filenames: Optional[list[str]] = write_filenames self.filename_iter: Optional[Iterator[str]] = ( iter(write_filenames) if write_filenames is not None else None ) - if write_filenames is not None and n_samples_per_file is not None: + + # where samples are stored until a whole file has been predicted + self.sample_cache: Optional[SampleCache] + + if not ((write_filenames is None) or (n_samples_per_file is None)): + # also creates sample cache self.set_file_data(write_filenames, n_samples_per_file) + else: + self.sample_cache = None def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): if len(write_filenames) != len(n_samples_per_file): @@ -89,7 +96,7 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int ) self._write_filenames = write_filenames self.filename_iter = iter(write_filenames) - self.sample_cache.n_samples_per_file = n_samples_per_file + self.sample_cache = SampleCache(n_samples_per_file=n_samples_per_file) def write_batch( self, diff --git a/tests/lightning/callbacks/prediction_writer_callback/conftest.py b/tests/lightning/callbacks/prediction_writer_callback/conftest.py new file mode 100644 index 00000000..cee80ffd --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/conftest.py @@ -0,0 +1,11 @@ +from unittest.mock import Mock + +import pytest + +from careamics.file_io import WriteFunc + + +@pytest.fixture +def write_func(): + """Mock `WriteFunc`.""" + return Mock(spec=WriteFunc) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index 04611863..fb6c1346 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -9,22 +9,16 @@ from torch.utils.data import DataLoader from careamics.dataset import IterablePredDataset -from careamics.file_io import WriteFunc + from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteImage, ) -@pytest.fixture -def write_func(): - """Mock `WriteFunc`.""" - return Mock(spec=WriteFunc) - - @pytest.fixture def write_image_strategy(write_func) -> WriteImage: """ - Initialized `CacheTiles` class. + Initialized `WriteImage` class. Parameters ---------- @@ -33,20 +27,21 @@ def write_image_strategy(write_func) -> WriteImage: Returns ------- - CacheTiles - Initialized `CacheTiles` class. + WriteImage + Initialized `WriteImage` class. """ write_extension = ".ext" write_func_kwargs = {} return WriteImage( write_func=write_func, - write_filenames=None, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + write_filenames=None, + n_samples_per_file=None, ) -def test_cache_tiles_init(write_func, write_image_strategy): +def test_write_image_init(write_func, write_image_strategy): """ Test `WriteImage` initializes as expected. """ @@ -55,7 +50,7 @@ def test_cache_tiles_init(write_func, write_image_strategy): assert write_image_strategy.write_func_kwargs == {} -def test_write_batch(write_image_strategy, ordered_array): +def test_write_batch(write_image_strategy: WriteImage, ordered_array): n_batches = 1 @@ -79,7 +74,7 @@ def test_write_batch(write_image_strategy, ordered_array): # call write batch dirpath = Path("predictions") - write_image_strategy.write_filenames = ["file"] + write_image_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[n_batches]) write_image_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), @@ -102,7 +97,7 @@ def test_write_batch(write_image_strategy, ordered_array): ) -def test_write_batch_raises(write_image_strategy, ordered_array): +def test_write_batch_raises(write_image_strategy: WriteImage, ordered_array): """Test write batch raises a ValueError if the filenames have not been set.""" n_batches = 1 @@ -129,7 +124,7 @@ def test_write_batch_raises(write_image_strategy, ordered_array): with pytest.raises(ValueError): # Make sure write_filenames is None - assert write_image_strategy.write_filenames is None + assert write_image_strategy._write_filenames is None write_image_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), From 5f5109d8b8533348960a6c6f37380129932b3a98 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 12 Nov 2024 11:33:20 +0100 Subject: [PATCH 27/38] test: update prediction callback write strategy unit tests --- .../test_cache_tiles_write_strategy.py | 7 ------- .../test_write_image_write_strategy.py | 5 +++-- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 222fdffa..d0ea8faf 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -12,7 +12,6 @@ from careamics.config.tile_information import TileInformation from careamics.dataset import IterableTiledPredDataset from careamics.dataset.tiling import extract_tiles -from careamics.file_io import WriteFunc from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteTiles, ) @@ -70,12 +69,6 @@ def patch_tile_cache( strategy.tile_info_cache = tile_infos -@pytest.fixture -def write_func(): - """Mock `WriteFunc`.""" - return Mock(spec=WriteFunc) - - @pytest.fixture def cache_tiles_strategy(write_func) -> WriteTiles: """ diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py index fb6c1346..1f4823c2 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py @@ -9,7 +9,6 @@ from torch.utils.data import DataLoader from careamics.dataset import IterablePredDataset - from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteImage, ) @@ -74,7 +73,9 @@ def test_write_batch(write_image_strategy: WriteImage, ordered_array): # call write batch dirpath = Path("predictions") - write_image_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[n_batches]) + write_image_strategy.set_file_data( + write_filenames=["file"], n_samples_per_file=[n_batches] + ) write_image_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), From a7793168629dce7366c5bb47319359758ffb85a8 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 22 Nov 2024 16:30:55 +0100 Subject: [PATCH 28/38] fix: bugs and tests since requiring n_samples_per_file param --- .../write_strategy/protocol.py | 7 +- .../write_strategy/utils.py | 3 +- .../write_strategy/write_image.py | 11 +- .../write_strategy/write_tiles.py | 23 ++- .../write_strategy_factory.py | 14 +- .../test_cache_tiles_write_strategy.py | 143 ++++++++---------- .../test_prediction_writer_callback.py | 14 +- 7 files changed, 117 insertions(+), 98 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py index cb5dcd1b..d84f3b88 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py @@ -43,12 +43,13 @@ def write_batch( Path to directory to save predictions to. """ - def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]) -> None: - ... + def set_file_data( + self, write_filenames: list[str], n_samples_per_file: list[int] + ) -> None: ... def reset(self) -> None: """ Reset internal attributes of a `WriteStrategy` instance. - This is to prevent bugs if a `WriteStrategy` instance is used twice. + This is to unexpected behaviour if a `WriteStrategy` instance is used twice. """ diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py index b7603333..0ba90811 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py @@ -44,10 +44,11 @@ def reset(self): class SampleCache: - def __init__(self, n_samples_per_file: Optional[list[int]]): + def __init__(self, n_samples_per_file: list[int]): self.n_samples_per_file: list[int] = n_samples_per_file self.n_samples_iter = iter(self.n_samples_per_file) + # n_samples will be set to None once iterated through each element self.n_samples: Optional[int] = next(self.n_samples_iter) self.sample_cache: list[NDArray] = [] diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index b5c6a79a..3d4199bb 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -138,8 +138,15 @@ def write_batch( ValueError If `write_filenames` attribute is `None`. """ - if self._write_filenames is None: - raise ValueError("`write_filenames` attribute has not been set.") + if self.sample_cache is None: + raise ValueError( + "`SampleCache` has not been created. Call `set_file_data` before " + "calling `write_batch`." + ) + # assert for mypy + assert self.filename_iter is not None, ( + "`filename_iter` is `None` should be set by `set_file_data`." + ) dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 3bae5d3b..01399838 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -86,7 +86,7 @@ def __init__( # where tiles will be cached until a whole image has been predicted self.tile_cache = TileCache() # where samples are stored until a whole file has been predicted - self.sample_cache = SampleCache(n_samples_per_file) + self.sample_cache: Optional[SampleCache] self._write_filenames: Optional[list[str]] = write_filenames self.filename_iter: Optional[Iterator[str]] = ( @@ -94,6 +94,8 @@ def __init__( ) if write_filenames is not None and n_samples_per_file is not None: self.set_file_data(write_filenames, n_samples_per_file) + else: + self.sample_cache = None def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): if len(write_filenames) != len(n_samples_per_file): @@ -103,7 +105,7 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int ) self._write_filenames = write_filenames self.filename_iter = iter(write_filenames) - self.sample_cache.n_samples_per_file = n_samples_per_file + self.sample_cache = SampleCache(n_samples_per_file) def write_batch( self, @@ -143,8 +145,15 @@ def write_batch( ValueError If `write_filenames` attribute is `None`. """ - if self._write_filenames is None: - raise ValueError("`write_filenames` attribute has not been set.") + if self.sample_cache is None: + raise ValueError( + "`SampleCache` has not been created. Call `set_file_data` before " + "calling `write_batch`." + ) + # assert for mypy + assert self.filename_iter is not None, ( + "`filename_iter` is `None` should be set by `set_file_data`." + ) # TODO: move dataset type check somewhere else dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders @@ -194,5 +203,7 @@ def reset(self) -> None: """ self._write_filenames = None self.filename_iter = None - self.tile_cache.reset() - self.sample_cache.reset() + if self.tile_cache is not None: + self.tile_cache.reset() + if self.sample_cache is not None: + self.sample_cache.reset() diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py index a7427a4d..207ac5c5 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py @@ -12,9 +12,10 @@ def create_write_strategy( write_type: SupportedWriteType, tiled: bool, write_func: Optional[WriteFunc] = None, - write_filenames: Optional[list[str]] = None, write_extension: Optional[str] = None, write_func_kwargs: Optional[dict[str, Any]] = None, + write_filenames: Optional[list[str]] = None, + n_samples_per_file: Optional[list[int]] = None, ) -> WriteStrategy: """ Create a write strategy from convenient parameters. @@ -28,13 +29,16 @@ def create_write_strategy( write_func : WriteFunc, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` a function to save the data must be passed. See notes below. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` an extension to save the data with must be passed. write_func_kwargs : dict of {str: any}, optional Additional keyword arguments to be passed to the save function. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list of int + The number of samples in each file, (controls which samples will be grouped + together in each file). Returns ------- @@ -67,6 +71,7 @@ def create_write_strategy( write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) else: # select CacheTiles or WriteTilesZarr (when implemented) @@ -76,6 +81,7 @@ def create_write_strategy( write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) return write_strategy @@ -87,6 +93,7 @@ def _create_tiled_write_strategy( write_filenames: Optional[list[str]], write_extension: Optional[str], write_func_kwargs: dict[str, Any], + n_samples_per_file: Optional[list[int]], ) -> WriteStrategy: """ Create a tiled write strategy. @@ -135,6 +142,7 @@ def _create_tiled_write_strategy( write_filenames=write_filenames, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + n_samples_per_file=n_samples_per_file, ) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index d0ea8faf..54e2d821 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -15,6 +15,9 @@ from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteTiles, ) +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy.utils import ( + TileCache, +) def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: @@ -65,8 +68,8 @@ def patch_tile_cache( tile_infos : list of TileInformation Corresponding tile information to patch into `strategy.tile_info_cache`. """ - strategy.tile_cache = tiles - strategy.tile_info_cache = tile_infos + strategy.tile_cache = TileCache() + strategy.tile_cache.add((np.concatenate(tiles), tile_infos)) @pytest.fixture @@ -88,25 +91,14 @@ def cache_tiles_strategy(write_func) -> WriteTiles: write_func_kwargs = {} return WriteTiles( write_func=write_func, - write_filenames=None, write_extension=write_extension, write_func_kwargs=write_func_kwargs, + write_filenames=None, + n_samples_per_file=None, ) -def test_cache_tiles_init(write_func, cache_tiles_strategy): - """ - Test `CacheTiles` initializes as expected. - """ - assert cache_tiles_strategy.write_func is write_func - assert cache_tiles_strategy.write_extension == ".ext" - assert cache_tiles_strategy.write_func_kwargs == {} - assert cache_tiles_strategy.tile_cache == [] - assert cache_tiles_strategy.tile_info_cache == [] - assert cache_tiles_strategy.write_filenames is None - assert cache_tiles_strategy.current_file_index == 0 - - +# TODO: Move to test tile cache def test_last_tiles(cache_tiles_strategy): """Test `CacheTiles.last_tile` property.""" @@ -114,8 +106,12 @@ def test_last_tiles(cache_tiles_strategy): tiles, tile_infos = create_tiles(n_samples=1) patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - last_tile = [False, False, False, False, False, False, False, False, True] - assert cache_tiles_strategy.last_tiles == last_tile + last_tiles = [False, False, False, False, False, False, False, False, True] + cached_last_tiles = [ + tile_info.last_tile + for tile_info in cache_tiles_strategy.tile_cache.tile_info_cache + ] + assert cached_last_tiles == last_tiles def test_write_batch_no_last_tile(cache_tiles_strategy): @@ -126,7 +122,8 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): """ # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) + n_samples = 1 + tiles, tile_infos = create_tiles(n_samples=n_samples) # simulate adding a batch that will not contain the last tile n_tiles = 4 @@ -144,7 +141,9 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): trainer.predict_dataloaders = [Mock(spec=DataLoader)] trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - cache_tiles_strategy.write_filenames = ["file_1"] + cache_tiles_strategy.set_file_data( + write_filenames=["file_1.tiff"], n_samples_per_file=[n_samples] + ) cache_tiles_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), @@ -160,10 +159,12 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): extended_tile_infos = tile_infos[: n_tiles + batch_size] assert all( - np.array_equal(extended_tiles[i], cache_tiles_strategy.tile_cache[i]) + np.array_equal( + extended_tiles[i], cache_tiles_strategy.tile_cache.array_cache[i] + ) for i in range(n_tiles + batch_size) ) - assert extended_tile_infos == cache_tiles_strategy.tile_info_cache + assert extended_tile_infos == cache_tiles_strategy.tile_cache.tile_info_cache def test_write_batch_last_tile(cache_tiles_strategy): @@ -174,7 +175,8 @@ def test_write_batch_last_tile(cache_tiles_strategy): """ # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) + n_samples = 2 + tiles, tile_infos = create_tiles(n_samples=n_samples) # simulate adding a batch that will contain the last tile n_tiles = 8 @@ -195,21 +197,22 @@ def test_write_batch_last_tile(cache_tiles_strategy): trainer.predict_dataloaders = [Mock(spec=DataLoader)] trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - # These functions have their own unit tests, - # so they do not need to be tested again here. - # This is a unit test to isolate functionality of `write_batch.` + file_names = [f"file_{i}" for i in range(n_samples)] + n_samples_per_file = [1 for _ in range(n_samples)] + + # call write batch + dirpath = Path("predictions") + cache_tiles_strategy.set_file_data( + write_filenames=file_names, n_samples_per_file=n_samples_per_file + ) + + # mocking stitch_prediction_single because to assert if WriteFunc was called with patch( "careamics.lightning.callbacks.prediction_writer_callback.write_strategy" - + ".cache_tiles.stitch_prediction_single", + + ".write_tiles.stitch_prediction_single", ) as mock_stitch_prediction_single: - - prediction_image = [Mock()] - file_name = "file" - mock_stitch_prediction_single.return_value = prediction_image - - # call write batch - dirpath = Path("predictions") - cache_tiles_strategy.write_filenames = [file_name] + mock_prediction = np.arange(64).reshape(1, 1, 8, 8) + mock_stitch_prediction_single.return_value = mock_prediction cache_tiles_strategy.write_batch( trainer=trainer, pl_module=Mock(spec=LightningModule), @@ -221,19 +224,21 @@ def test_write_batch_last_tile(cache_tiles_strategy): dirpath=dirpath, ) - # assert write_func is called as expected - cache_tiles_strategy.write_func.assert_called_once_with( - file_path=Path("predictions/file.ext"), img=prediction_image[0], **{} - ) + # assert write_func is called as expected + write_func_call_args = cache_tiles_strategy.write_func.call_args.kwargs + assert write_func_call_args["file_path"] == Path("predictions/file_0.ext") + np.testing.assert_array_equal(write_func_call_args["img"], mock_prediction) # Tile of the next image (should remain in the cache) remaining_tile = tiles[9] remaining_tile_info = tile_infos[9] # assert cache cleared as expected - assert len(cache_tiles_strategy.tile_cache) == 1 - assert np.array_equal(remaining_tile, cache_tiles_strategy.tile_cache[0]) - assert remaining_tile_info == cache_tiles_strategy.tile_info_cache[0] + assert len(cache_tiles_strategy.tile_cache.array_cache) == 1 + assert np.array_equal( + remaining_tile, cache_tiles_strategy.tile_cache.array_cache[0] + ) + assert remaining_tile_info == cache_tiles_strategy.tile_cache.tile_info_cache[0] def test_write_batch_raises(cache_tiles_strategy: WriteTiles): @@ -277,6 +282,7 @@ def test_write_batch_raises(cache_tiles_strategy: WriteTiles): ) +# TODO: move to tile cache tests def test_have_last_tile_true(cache_tiles_strategy): """Test `CacheTiles._have_last_tile` returns true when there is a last tile.""" @@ -284,7 +290,7 @@ def test_have_last_tile_true(cache_tiles_strategy): tiles, tile_infos = create_tiles(n_samples=1) patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - assert cache_tiles_strategy._has_last_tile() + assert cache_tiles_strategy.tile_cache.has_last_tile() def test_have_last_tile_false(cache_tiles_strategy): @@ -295,10 +301,11 @@ def test_have_last_tile_false(cache_tiles_strategy): # don't include last tile patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - assert not cache_tiles_strategy._has_last_tile() + assert not cache_tiles_strategy.tile_cache.has_last_tile() -def test_clear_cache(cache_tiles_strategy): +# TODO: move to test tile cache +def test_pop_image_tiles(cache_tiles_strategy): """ Test `CacheTiles._clear_cache` removes the tiles up until the first "last tile". """ @@ -308,23 +315,19 @@ def test_clear_cache(cache_tiles_strategy): # include first tile from next sample patch_tile_cache(cache_tiles_strategy, tiles[:10], tile_infos[:10]) - cache_tiles_strategy._clear_cache() + image_tiles, image_tile_infos = cache_tiles_strategy.tile_cache.pop_image_tiles() - assert len(cache_tiles_strategy.tile_cache) == 1 - assert np.array_equal(cache_tiles_strategy.tile_cache[0], tiles[9]) - assert cache_tiles_strategy.tile_info_cache[0] == tile_infos[9] + assert len(cache_tiles_strategy.tile_cache.array_cache) == 1 + assert np.array_equal(cache_tiles_strategy.tile_cache.array_cache[0], tiles[9]) + assert cache_tiles_strategy.tile_cache.tile_info_cache[0] == tile_infos[9] - -def test_last_tile_index(cache_tiles_strategy): - """Test `CacheTiles._last_tile_index` returns the index of the last tile.""" - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - - assert cache_tiles_strategy._last_tile_index() == 8 + assert len(image_tiles) == 9 + assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) + assert image_tile_infos == tile_infos[:9] -def test_last_tile_index_error(cache_tiles_strategy): +# TODO: move to test tile cache +def test_pop_image_tiles_error(cache_tiles_strategy: WriteTiles): """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" # all tiles of 1 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=1) @@ -332,21 +335,7 @@ def test_last_tile_index_error(cache_tiles_strategy): patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) with pytest.raises(ValueError): - cache_tiles_strategy._last_tile_index() - - -def test_get_image_tiles(cache_tiles_strategy): - """Test `CacheTiles._get_image_tiles` returns the tiles of a single image.""" - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - # include first tile from next sample - patch_tile_cache(cache_tiles_strategy, tiles[:10], tile_infos[:10]) - - image_tiles, image_tile_infos = cache_tiles_strategy._get_image_tiles() - - assert len(image_tiles) == 9 - assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) - assert image_tile_infos == tile_infos[:9] + cache_tiles_strategy.tile_cache.pop_image_tiles() def test_reset(cache_tiles_strategy: WriteTiles): @@ -356,10 +345,8 @@ def test_reset(cache_tiles_strategy: WriteTiles): # don't include last tile patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - cache_tiles_strategy._write_filenames = ["file"] - cache_tiles_strategy.current_file_index = 1 + cache_tiles_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[1]) cache_tiles_strategy.reset() + assert cache_tiles_strategy._write_filenames is None - assert cache_tiles_strategy.current_file_index == 0 - assert len(cache_tiles_strategy.tile_cache) == 0 - assert len(cache_tiles_strategy.tile_info_cache) == 0 + assert cache_tiles_strategy.filename_iter is None diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py index e8119111..48215b75 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py @@ -86,7 +86,10 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): # create prediction writer callback params write_strategy = create_write_strategy( - write_type="tiff", tiled=True, write_filenames=[file_name] + write_type="tiff", + tiled=True, + write_filenames=[file_name], + n_samples_per_file=[1], ) write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" @@ -125,7 +128,6 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): # filenames reset after predictions called write_strategy.reset.assert_called_once() - assert write_strategy.write_filenames is None # assert predicted file exists assert (dirpath / file_name).is_file() @@ -133,7 +135,7 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): # open file save_data = tifffile.imread(dirpath / file_name) # save data has singleton channel axis - np.testing.assert_array_equal(save_data, predicted_images[0][0], verbose=True) + np.testing.assert_array_equal(save_data, predicted_images[0], verbose=True) def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): @@ -171,7 +173,10 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): # create prediction writer callback params write_strategy = create_write_strategy( - write_type="tiff", tiled=False, write_filenames=[file_name] + write_type="tiff", + tiled=False, + write_filenames=[file_name], + n_samples_per_file=[1], ) write_strategy.reset = MagicMock(side_effect=write_strategy.reset) dirpath = tmp_path / "predictions" @@ -208,7 +213,6 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): # filenames reset after predictions called write_strategy.reset.assert_called_once() - assert write_strategy.write_filenames is None # assert predicted file exists assert (dirpath / file_name).is_file() From c5305e3c3cda76087daa427ba211c239dfa7b210 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 22 Nov 2024 17:38:34 +0100 Subject: [PATCH 29/38] docs: add and update docs --- .../write_strategy/protocol.py | 25 ++++- .../write_strategy/utils.py | 95 ++++++++++++++++++- .../write_strategy/write_image.py | 42 ++++---- .../write_strategy/write_tiles.py | 44 ++++++--- .../write_strategy_factory.py | 21 ++-- 5 files changed, 183 insertions(+), 44 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py index d84f3b88..ded4376e 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/protocol.py @@ -7,7 +7,14 @@ class WriteStrategy(Protocol): - """Protocol for write strategy classes.""" + """ + Protocol for write strategy classes. + + A `WriteStrategy` is an object that will be an attribute in the + `PredictionWriterCallback`; it will determine how predictions will be saved. + `WriteStrategy`s must be interchangeable so they must follow the interface set out + in this `Protocol` class. + """ def write_batch( self, @@ -45,11 +52,23 @@ def write_batch( def set_file_data( self, write_filenames: list[str], n_samples_per_file: list[int] - ) -> None: ... + ) -> None: + """ + Set file information after the `WriteStrategy` has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ def reset(self) -> None: """ - Reset internal attributes of a `WriteStrategy` instance. + Reset internal state (attributes) of a `WriteStrategy` instance. This is to unexpected behaviour if a `WriteStrategy` instance is used twice. """ diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py index 0ba90811..1cf1d642 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py @@ -1,3 +1,5 @@ +"""Utility classes and funtions used in the write strategies.""" + from typing import Optional import numpy as np @@ -8,21 +10,65 @@ class TileCache: """ - Cache tiles; logic to pop tiles when tiles from a full image have been stored. + Logic to cache tiles, then pop tiles when tiles from a full image have been stored. + + Attributes + ---------- + array_cache : list[numpy.ndarray] + The tile arrays with the dimensions SC(Z)YX. + tile_info_cache : list[TileInformation] + The corresponding tile information for each tile. """ def __init__(self): + """Logic to cache tiles, and pop tiles when a full set of have been stored.""" self.array_cache: list[NDArray] = [] self.tile_info_cache: list[TileInformation] = [] def add(self, item: tuple[NDArray, list[TileInformation]]): + """ + Add another batch to the cache. + + Parameters + ---------- + item : tuple of (numpy.ndarray, list[TileInformation]) + Tuple where the first element is a concatenated set of tiles, and the + second element is a list of each corresponding `TileInformation`. + """ self.array_cache.extend(np.split(item[0], item[0].shape[0])) self.tile_info_cache.extend(item[1]) def has_last_tile(self) -> bool: + """ + Determine whether the current cache contains the last tile of a sample. + + Returns + ------- + bool + Whether the last tile is contained in the cache. + """ return any(tile_info.last_tile for tile_info in self.tile_info_cache) def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: + """ + Pop the tiles that will create a full image from the cache. + + I.e. The tiles belonging to a full image will be removed from the cache but + returned by this function call. + + Returns + ------- + list of numpy.ndarray + A list of tiles with the dimensions SC(Z)YX. + list of TileInformation + A list of corresponding tile information. + + Raises + ------ + ValueError + If the tiles belonging to a full image are not contained in the cache, i.e. + if the cache does not contain the last tile of an image. + """ is_last_tile = [tile_info.last_tile for tile_info in self.tile_info_cache] if not any(is_last_tile): raise ValueError("No last tile in cache.") @@ -38,14 +84,34 @@ def pop_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]: return tiles, tile_infos def reset(self): + """Reset the cache. Remove all tiles and tile information from the cache.""" self.array_cache = [] self.tile_info_cache = [] class SampleCache: + """ + Logic to cache samples until they can be concatenated together to create a file. + + Parameters + ---------- + n_samples_per_file : list[int] + A list that contains the number of samples that will be contained in each + file. There should be `n` elements in the list for `n` files intended to + be created. + """ def __init__(self, n_samples_per_file: list[int]): - + """ + Logic to cache samples until they can be concatenated together to create a file. + + Parameters + ---------- + n_samples_per_file : list[int] + A list that contains the number of samples that will be contained in each + file. There should be `n` elements in the list for `n` files intended to + be created. + """ self.n_samples_per_file: list[int] = n_samples_per_file self.n_samples_iter = iter(self.n_samples_per_file) # n_samples will be set to None once iterated through each element @@ -53,9 +119,25 @@ def __init__(self, n_samples_per_file: list[int]): self.sample_cache: list[NDArray] = [] def add(self, item: NDArray): + """ + Add a sample to the cache. + + Parameters + ---------- + item : numpy.ndarray + A single predicted sample. + """ self.sample_cache.extend(np.split(item, item.shape[0])) def has_all_file_samples(self) -> bool: + """ + Determine if all the samples for the current file are contained in the cache. + + Returns + ------- + bool + Whether all the samples are contained in the cache. + """ if self.n_samples is None: raise ValueError( "Number of samples for current file is unknown. Reached the end of the " @@ -64,6 +146,14 @@ def has_all_file_samples(self) -> bool: return len(self.sample_cache) >= self.n_samples def pop_file_samples(self) -> list[NDArray]: + """ + Pop from the cache the samples required for the current file to be created. + + Returns + ------- + list of numpy.ndarray + A list of samples to concatenate together into a file. + """ if not self.has_all_file_samples(): raise ValueError( "Do not have all the samples belonging to the current file." @@ -80,5 +170,6 @@ def pop_file_samples(self) -> list[NDArray]: return samples def reset(self): + """Reset the cache. Remove all the samples from the cache.""" self.n_samples_iter = iter(self.n_samples_per_file) self.sample_cache: list[NDArray] = [] diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 3d4199bb..33ed78b6 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -16,18 +16,21 @@ class WriteImage: """ - A strategy for writing image predictions (i.e. un-tiled predictions). + A strategy for writing image predictions (i.e. not tiled predictions). Parameters ---------- write_func : WriteFunc Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list of int + The number of samples in each file, (controls which samples will be + grouped together in each file). Attributes ---------- @@ -39,8 +42,7 @@ class WriteImage: Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. - current_file_index : int - Index of current file, increments every time a file is written. + """ def __init__( @@ -58,12 +60,12 @@ def __init__( ---------- write_func : WriteFunc Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. n_samples_per_file : list of int The number of samples in each file, (controls which samples will be grouped together in each file). @@ -89,6 +91,18 @@ def __init__( self.sample_cache = None def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + """ + Set file information after the `WriteImage` strategy has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ if len(write_filenames) != len(n_samples_per_file): raise ValueError( "List of filename and list of number of samples per file are not of " @@ -141,12 +155,12 @@ def write_batch( if self.sample_cache is None: raise ValueError( "`SampleCache` has not been created. Call `set_file_data` before " - "calling `write_batch`." + "calling `write_batch`." ) # assert for mypy - assert self.filename_iter is not None, ( - "`filename_iter` is `None` should be set by `set_file_data`." - ) + assert ( + self.filename_iter is not None + ), "`filename_iter` is `None` should be set by `set_file_data`." dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls @@ -172,11 +186,7 @@ def write_batch( self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) def reset(self) -> None: - """ - Reset internal attributes. - - Resets the `write_filenames` and `current_file_index` attributes. - """ + """Reset internal attributes.""" self._write_filenames = None self.filename_iter = None self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 01399838..5607caee 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -27,12 +27,16 @@ class WriteTiles: ---------- write_func : WriteFunc Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). Attributes ---------- @@ -48,8 +52,6 @@ class WriteTiles: Tiles cached for stitching prediction. tile_info_cache : list of TileInformation Cached tile information for stitching prediction. - current_file_index : int - Index of current file, increments every time a file is written. """ def __init__( @@ -70,12 +72,16 @@ def __init__( ---------- write_func : WriteFunc Function used to save predictions. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str Extension added to prediction file paths. write_func_kwargs : dict of {str: Any} Extra kwargs to pass to `write_func`. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). """ super().__init__() @@ -98,6 +104,18 @@ def __init__( self.sample_cache = None def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int]): + """ + Set file information after the `WriteTiles` has been initialized. + + Parameters + ---------- + write_filenames : list[str] + A list of filenames to save to. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). + """ if len(write_filenames) != len(n_samples_per_file): raise ValueError( "List of filename and list of number of samples per file are not of " @@ -148,12 +166,12 @@ def write_batch( if self.sample_cache is None: raise ValueError( "`SampleCache` has not been created. Call `set_file_data` before " - "calling `write_batch`." + "calling `write_batch`." ) # assert for mypy - assert self.filename_iter is not None, ( - "`filename_iter` is `None` should be set by `set_file_data`." - ) + assert ( + self.filename_iter is not None + ), "`filename_iter` is `None` should be set by `set_file_data`." # TODO: move dataset type check somewhere else dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders @@ -196,11 +214,7 @@ def write_batch( self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) def reset(self) -> None: - """ - Reset the internal attributes. - - Attributes reset are: `write_filenames`, `tile_cache`, and `current_file_index`. - """ + """Reset the internal attributes.""" self._write_filenames = None self.filename_iter = None if self.tile_cache is not None: diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py index 207ac5c5..771bc7e3 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py @@ -36,9 +36,10 @@ def create_write_strategy( Additional keyword arguments to be passed to the save function. write_filenames : list of str, optional A list of filenames in the order that predictions will be written in. - n_samples_per_file : list of int - The number of samples in each file, (controls which samples will be grouped - together in each file). + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). Returns ------- @@ -74,7 +75,7 @@ def create_write_strategy( n_samples_per_file=n_samples_per_file, ) else: - # select CacheTiles or WriteTilesZarr (when implemented) + # select WriteTiles or WriteTilesZarr (when implemented) write_strategy = _create_tiled_write_strategy( write_type=write_type, write_func=write_func, @@ -90,15 +91,15 @@ def create_write_strategy( def _create_tiled_write_strategy( write_type: SupportedWriteType, write_func: Optional[WriteFunc], - write_filenames: Optional[list[str]], write_extension: Optional[str], write_func_kwargs: dict[str, Any], + write_filenames: Optional[list[str]], n_samples_per_file: Optional[list[int]], ) -> WriteStrategy: """ Create a tiled write strategy. - Either `CacheTiles` for caching tiles until a whole image is predicted or + Either `WriteTiles` for caching tiles until a whole image is predicted or `WriteTilesZarr` for writing tiles directly to disk. Parameters @@ -108,13 +109,17 @@ def _create_tiled_write_strategy( write_func : WriteFunc, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` a function to save the data must be passed. See notes below. - write_filenames : list of str, optional - A list of filenames in the order that predictions will be written in. write_extension : str, optional If a known `write_type` is selected this argument is ignored. For a custom `write_type` an extension to save the data with must be passed. write_func_kwargs : dict of {str: any} Additional keyword arguments to be passed to the save function. + write_filenames : list of str, optional + A list of filenames in the order that predictions will be written in. + n_samples_per_file : list[int] + The number of samples that will be saved within each file. Each element in + the list will correspond to the equivelant file in `write_filenames`. + (Should most likely mirror the input file structure). Returns ------- From 481c5271fbed93ea500a3860b0f6b4b37d1d01e4 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 29 Nov 2024 10:06:44 +0100 Subject: [PATCH 30/38] refac(prediction writer): rename utils.py to caches.py --- .../write_strategy/{utils.py => caches.py} | 2 +- .../write_strategy/write_image.py | 2 +- .../write_strategy/write_tiles.py | 2 +- .../test_cache_tiles_write_strategy.py | 6 ++---- 4 files changed, 5 insertions(+), 7 deletions(-) rename src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/{utils.py => caches.py} (98%) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py similarity index 98% rename from src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py rename to src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py index 1cf1d642..b17b70a6 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/utils.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py @@ -1,4 +1,4 @@ -"""Utility classes and funtions used in the write strategies.""" +"""Utility classes, for caching data, used in the write strategies.""" from typing import Optional diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 33ed78b6..86150e13 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -11,7 +11,7 @@ from careamics.dataset import IterablePredDataset from careamics.file_io import WriteFunc -from .utils import SampleCache +from .caches import SampleCache class WriteImage: diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 5607caee..52055f40 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -13,7 +13,7 @@ from careamics.file_io import WriteFunc from careamics.prediction_utils import stitch_prediction_single -from .utils import SampleCache, TileCache +from .caches import SampleCache, TileCache class WriteTiles: diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 54e2d821..0df45296 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -14,9 +14,7 @@ from careamics.dataset.tiling import extract_tiles from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteTiles, -) -from careamics.lightning.callbacks.prediction_writer_callback.write_strategy.utils import ( - TileCache, + caches, ) @@ -68,7 +66,7 @@ def patch_tile_cache( tile_infos : list of TileInformation Corresponding tile information to patch into `strategy.tile_info_cache`. """ - strategy.tile_cache = TileCache() + strategy.tile_cache = caches.TileCache() strategy.tile_cache.add((np.concatenate(tiles), tile_infos)) From 12de6d679e28031bc181e4b3098c57c99374fa94 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 29 Nov 2024 12:29:02 +0100 Subject: [PATCH 31/38] test: placeholder funcs for cache tests; move existing tests relating to caches --- .../test_cache_tiles_write_strategy.py | 128 ++---------------- .../test_sample_cache.py | 25 ++++ .../test_tile_cache.py | 78 +++++++++++ .../prediction_writer_callback/utils.py | 63 +++++++++ 4 files changed, 178 insertions(+), 116 deletions(-) create mode 100644 tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py create mode 100644 tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py create mode 100644 tests/lightning/callbacks/prediction_writer_callback/utils.py diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 0df45296..aab22bfc 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -5,69 +5,15 @@ import numpy as np import pytest -from numpy.typing import NDArray from pytorch_lightning import LightningModule, Trainer from torch.utils.data import DataLoader -from careamics.config.tile_information import TileInformation from careamics.dataset import IterableTiledPredDataset -from careamics.dataset.tiling import extract_tiles from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( WriteTiles, - caches, ) - -def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: - """ - Create a set of tiles from `n_samples`. - - To create the tiles the following parameters, `tile_size=(4, 4)` and - `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this - results in 9 tiles per sample. - - Parameters - ---------- - n_samples : int - Number of samples to simulate the tiles from. - - Returns - ------- - tuple of (list of NDArray), list of TileInformation)) - Tuple where first element is the list of tiles and second element is a list - of corresponding tile information. - """ - - input_shape = (n_samples, 1, 8, 8) - tile_size = (4, 4) - tile_overlap = (2, 2) - - arr = np.arange(np.prod(input_shape)).reshape(input_shape) - - all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) - tiles = [output[0] for output in all_tiles] - tile_infos = [output[1] for output in all_tiles] - - return tiles, tile_infos - - -def patch_tile_cache( - strategy: WriteTiles, tiles: list[NDArray], tile_infos: list[TileInformation] -) -> None: - """ - Patch simulated tile cache into `strategy`. - - Parameters - ---------- - strategy : CacheTiles - Write strategy `CacheTiles`. - tiles : list of NDArray - Tiles to patch into `strategy.tile_cache`. - tile_infos : list of TileInformation - Corresponding tile information to patch into `strategy.tile_info_cache`. - """ - strategy.tile_cache = caches.TileCache() - strategy.tile_cache.add((np.concatenate(tiles), tile_infos)) +from .utils import create_tiles, patch_tile_cache @pytest.fixture @@ -102,7 +48,7 @@ def test_last_tiles(cache_tiles_strategy): # all tiles of 1 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) + patch_tile_cache(cache_tiles_strategy.tile_cache, tiles, tile_infos) last_tiles = [False, False, False, False, False, False, False, False, True] cached_last_tiles = [ @@ -126,7 +72,9 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): # simulate adding a batch that will not contain the last tile n_tiles = 4 batch_size = 2 - patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) + patch_tile_cache( + cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) next_batch = ( np.concatenate(tiles[n_tiles : n_tiles + batch_size]), tile_infos[n_tiles : n_tiles + batch_size], @@ -179,7 +127,9 @@ def test_write_batch_last_tile(cache_tiles_strategy): # simulate adding a batch that will contain the last tile n_tiles = 8 batch_size = 2 - patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) + patch_tile_cache( + cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) next_batch = ( np.concatenate(tiles[n_tiles : n_tiles + batch_size]), tile_infos[n_tiles : n_tiles + batch_size], @@ -247,7 +197,9 @@ def test_write_batch_raises(cache_tiles_strategy: WriteTiles): # simulate adding a batch that will contain the last tile n_tiles = 8 batch_size = 2 - patch_tile_cache(cache_tiles_strategy, tiles[:n_tiles], tile_infos[:n_tiles]) + patch_tile_cache( + cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) next_batch = ( np.concatenate(tiles[n_tiles : n_tiles + batch_size]), tile_infos[n_tiles : n_tiles + batch_size], @@ -280,68 +232,12 @@ def test_write_batch_raises(cache_tiles_strategy: WriteTiles): ) -# TODO: move to tile cache tests -def test_have_last_tile_true(cache_tiles_strategy): - """Test `CacheTiles._have_last_tile` returns true when there is a last tile.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy, tiles, tile_infos) - - assert cache_tiles_strategy.tile_cache.has_last_tile() - - -def test_have_last_tile_false(cache_tiles_strategy): - """Test `CacheTiles._have_last_tile` returns false when there is not a last tile.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - # don't include last tile - patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - - assert not cache_tiles_strategy.tile_cache.has_last_tile() - - -# TODO: move to test tile cache -def test_pop_image_tiles(cache_tiles_strategy): - """ - Test `CacheTiles._clear_cache` removes the tiles up until the first "last tile". - """ - - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - # include first tile from next sample - patch_tile_cache(cache_tiles_strategy, tiles[:10], tile_infos[:10]) - - image_tiles, image_tile_infos = cache_tiles_strategy.tile_cache.pop_image_tiles() - - assert len(cache_tiles_strategy.tile_cache.array_cache) == 1 - assert np.array_equal(cache_tiles_strategy.tile_cache.array_cache[0], tiles[9]) - assert cache_tiles_strategy.tile_cache.tile_info_cache[0] == tile_infos[9] - - assert len(image_tiles) == 9 - assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) - assert image_tile_infos == tile_infos[:9] - - -# TODO: move to test tile cache -def test_pop_image_tiles_error(cache_tiles_strategy: WriteTiles): - """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - # don't include last tile - patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) - - with pytest.raises(ValueError): - cache_tiles_strategy.tile_cache.pop_image_tiles() - - def test_reset(cache_tiles_strategy: WriteTiles): """Test CacheTiles.reset works as expected""" # all tiles of 1 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=1) # don't include last tile - patch_tile_cache(cache_tiles_strategy, tiles[:-1], tile_infos[:-1]) + patch_tile_cache(cache_tiles_strategy.tile_cache, tiles[:-1], tile_infos[:-1]) cache_tiles_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[1]) cache_tiles_strategy.reset() diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py new file mode 100644 index 00000000..234600ee --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py @@ -0,0 +1,25 @@ +"""Test the utility `SampleCache` class used by the WriteTile classes.""" + + +def test_add(): ... + + +def test_has_all_file_samples_true(): ... + + +def test_has_all_file_samples_false(): ... + + +def test_has_all_file_samples_error(): ... + + +def test_pop_file_samples(): ... + + +def test_pop_file_samples_last_sample(): ... + + +def test_pop_file_samples_error(): ... + + +def test_reset(): ... diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py new file mode 100644 index 00000000..988aabe2 --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py @@ -0,0 +1,78 @@ +"""Test the utility `TileCache` class used by the WriteTile classes.""" + +import numpy as np +import pytest + +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + +from .utils import create_tiles, patch_tile_cache + + +def test_add(): ... + + +def test_has_last_tile_true(): + """ + Test `TileCache.has_last_tile` returns true when there is a last tile. + """ + + tile_cache = caches.TileCache() + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + + patch_tile_cache(tile_cache, tiles, tile_infos) + assert tile_cache.has_last_tile() + + +def test_has_last_tile_false(): + """Test `TileCache.has_last_tile` returns false when there is not a last tile.""" + + tile_cache = caches.TileCache() + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + # don't include last tile + patch_tile_cache(tile_cache, tiles[:-1], tile_infos[:-1]) + + assert not tile_cache.has_last_tile() + + +def test_pop_image_tiles(): + """ + Test `TileCache.has_last_tile` removes the tiles up until the first "last tile". + """ + + tile_cache = caches.TileCache() + # all tiles of 2 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=2) + # include first tile from next sample + patch_tile_cache(tile_cache, tiles[:10], tile_infos[:10]) + + image_tiles, image_tile_infos = tile_cache.pop_image_tiles() + + # test popped tiles as expected + assert len(image_tiles) == 9 + assert all(np.array_equal(image_tiles[i], tiles[i]) for i in range(9)) + assert image_tile_infos == tile_infos[:9] + + # test tiles remaining in cache as expected + assert len(tile_cache.array_cache) == 1 + assert np.array_equal(tile_cache.array_cache[0], tiles[9]) + assert tile_cache.tile_info_cache[0] == tile_infos[9] + + +def test_pop_image_tiles_error(): + """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" + + tile_cache = caches.TileCache() + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + # don't include last tile + patch_tile_cache(tile_cache, tiles[:-1], tile_infos[:-1]) + # should raise error if pop_image_tiles is called without last tile present + with pytest.raises(ValueError): + tile_cache.pop_image_tiles() + + +def test_reset(): ... diff --git a/tests/lightning/callbacks/prediction_writer_callback/utils.py b/tests/lightning/callbacks/prediction_writer_callback/utils.py new file mode 100644 index 00000000..8cc4047d --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/utils.py @@ -0,0 +1,63 @@ +"""Utility functions to be used in tests relating to the `PredictionWriterCallback`.""" + +import numpy as np +from numpy.typing import NDArray + +from careamics.config.tile_information import TileInformation +from careamics.dataset.tiling import extract_tiles +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + + +def patch_tile_cache( + tile_cache: caches.TileCache, + tiles: list[NDArray], + tile_infos: list[TileInformation], +) -> None: + """ + Patch simulated tile cache into `strategy`. + + Parameters + ---------- + tile_cache : TileCache + Tile cache used in `WriteTiles` write strategy class. + tiles : list of NDArray + Tiles to patch into `strategy.tile_cache`. + tile_infos : list of TileInformation + Corresponding tile information to patch into `strategy.tile_info_cache`. + """ + tile_cache.add((np.concatenate(tiles), tile_infos)) + + +def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: + """ + Create a set of tiles from `n_samples`. + + To create the tiles the following parameters, `tile_size=(4, 4)` and + `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this + results in 9 tiles per sample. + + Parameters + ---------- + n_samples : int + Number of samples to simulate the tiles from. + + Returns + ------- + tuple of (list of NDArray), list of TileInformation)) + Tuple where first element is the list of tiles and second element is a list + of corresponding tile information. + """ + + input_shape = (n_samples, 1, 8, 8) + tile_size = (4, 4) + tile_overlap = (2, 2) + + arr = np.arange(np.prod(input_shape)).reshape(input_shape) + + all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) + tiles = [output[0] for output in all_tiles] + tile_infos = [output[1] for output in all_tiles] + + return tiles, tile_infos From f05af8ab8f1e03fdf31fc10cf808c1f59375e0b0 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 29 Nov 2024 15:43:33 +0100 Subject: [PATCH 32/38] style: make some attributes private --- .../write_strategy/write_image.py | 16 ++++++++-------- .../write_strategy/write_tiles.py | 10 +++++----- .../test_cache_tiles_write_strategy.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py index 86150e13..bdd44834 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_image.py @@ -76,14 +76,14 @@ def __init__( self.write_extension: str = write_extension self.write_func_kwargs: dict[str, Any] = write_func_kwargs + # where samples are stored until a whole file has been predicted + self.sample_cache: Optional[SampleCache] + self._write_filenames: Optional[list[str]] = write_filenames - self.filename_iter: Optional[Iterator[str]] = ( + self._filename_iter: Optional[Iterator[str]] = ( iter(write_filenames) if write_filenames is not None else None ) - # where samples are stored until a whole file has been predicted - self.sample_cache: Optional[SampleCache] - if not ((write_filenames is None) or (n_samples_per_file is None)): # also creates sample cache self.set_file_data(write_filenames, n_samples_per_file) @@ -109,7 +109,7 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int "equal length." ) self._write_filenames = write_filenames - self.filename_iter = iter(write_filenames) + self._filename_iter = iter(write_filenames) self.sample_cache = SampleCache(n_samples_per_file=n_samples_per_file) def write_batch( @@ -159,7 +159,7 @@ def write_batch( ) # assert for mypy assert ( - self.filename_iter is not None + self._filename_iter is not None ), "`filename_iter` is `None` should be set by `set_file_data`." dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders @@ -181,12 +181,12 @@ def write_batch( data = np.concatenate(samples) # write prediction - file_name = next(self.filename_iter) + file_name = next(self._filename_iter) file_path = (dirpath / file_name).with_suffix(self.write_extension) self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) def reset(self) -> None: """Reset internal attributes.""" self._write_filenames = None - self.filename_iter = None + self._filename_iter = None self.current_file_index = 0 diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py index 52055f40..ade75cd0 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/write_tiles.py @@ -95,7 +95,7 @@ def __init__( self.sample_cache: Optional[SampleCache] self._write_filenames: Optional[list[str]] = write_filenames - self.filename_iter: Optional[Iterator[str]] = ( + self._filename_iter: Optional[Iterator[str]] = ( iter(write_filenames) if write_filenames is not None else None ) if write_filenames is not None and n_samples_per_file is not None: @@ -122,7 +122,7 @@ def set_file_data(self, write_filenames: list[str], n_samples_per_file: list[int "equal length." ) self._write_filenames = write_filenames - self.filename_iter = iter(write_filenames) + self._filename_iter = iter(write_filenames) self.sample_cache = SampleCache(n_samples_per_file) def write_batch( @@ -170,7 +170,7 @@ def write_batch( ) # assert for mypy assert ( - self.filename_iter is not None + self._filename_iter is not None ), "`filename_iter` is `None` should be set by `set_file_data`." # TODO: move dataset type check somewhere else @@ -209,14 +209,14 @@ def write_batch( data = np.concatenate(samples) # write prediction - file_name = next(self.filename_iter) + file_name = next(self._filename_iter) file_path = (dirpath / file_name).with_suffix(self.write_extension) self.write_func(file_path=file_path, img=data, **self.write_func_kwargs) def reset(self) -> None: """Reset the internal attributes.""" self._write_filenames = None - self.filename_iter = None + self._filename_iter = None if self.tile_cache is not None: self.tile_cache.reset() if self.sample_cache is not None: diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index aab22bfc..565c9c3f 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -243,4 +243,4 @@ def test_reset(cache_tiles_strategy: WriteTiles): cache_tiles_strategy.reset() assert cache_tiles_strategy._write_filenames is None - assert cache_tiles_strategy.filename_iter is None + assert cache_tiles_strategy._filename_iter is None From 4c3b4ba3fce4b841083d71a3d4c276876525611c Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 29 Nov 2024 17:30:16 +0100 Subject: [PATCH 33/38] feat: README with class and sequence diagrams --- .../prediction_writer_callback/README.md | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/lightning/callbacks/prediction_writer_callback/README.md diff --git a/tests/lightning/callbacks/prediction_writer_callback/README.md b/tests/lightning/callbacks/prediction_writer_callback/README.md new file mode 100644 index 00000000..7f39794a --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/README.md @@ -0,0 +1,119 @@ +# `PredictionWriterCallback` + +## Class diagram for `PredictionWriteCallback` related classes +```mermaid +classDiagram + PredictionWriterCallback*--WriteStrategy : composition + WriteStrategy<--WriteTiles : implements + WriteStrategy<--WriteImage : implements + WriteStrategy<--WriteZarrTiles : implements + WriteTiles*--TileCache : composition + WriteTiles*--SampleCache : composition + WriteImage*--SampleCache : composition + + class PredictionWriterCallback + PredictionWriterCallback : +bool writing_predictions + PredictionWriterCallback : +WriteStrategy write_strategy + PredictionWriterCallback : +write_on_batch_end(...) + PredictionWriterCallback : +on_predict_epoch_end(...) + + class WriteStrategy + <> WriteStrategy + WriteStrategy : +write_batch(...)* + WriteStrategy : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file)* + WriteStrategy : +reset()* + + class WriteTiles + WriteTiles : +WriteFunc write_func + + WriteTiles : +TileCache tile_cache + WriteTiles : +SampleCache sample_cache + WriteTiles : +write_batch(...) + WriteTiles : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) + WriteTiles : +reset() + + class WriteImage + WriteImage : +WriteFunc write_func + WriteImage : +SampleCache sample_cache + WriteImage : +write_batch(...) + WriteImage : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) + WriteImage : +reset() + + class WriteZarrTiles + WriteZarrTiles : +write_batch(...) NotImplemented + WriteZarrTiles : +set_file_data(lists~str~ write_filenames, list~int~ n_samples_per_file) NotImplemented + WriteZarrTiles : +reset() NotImplemented + + class TileCache + TileCache : +list~NDArray~ array_cache + TileCache : +list~TileInformation~ tile_info_cache + TileCache : +add(NDArray, list~TileInformation~ item) + TileCache : +has_last_tile() bool + TileCache : +pop_image_tiles() NDArray, list~TileInformation~ + TileCache : +reset() + + class SampleCache + SampleCache : +list~int~ n_samples_per_file + SampleCache : +Iterator n_samples_iter + SampleCache : +int n_samples + SampleCache : +sample_cache list~NDArray~ + SampleCache : +add(NDArray item) + SampleCache : +has_all_file_samples() bool + SampleCache : +pop_file_samples() list~NDArray~ + SampleCache : +reset() +``` + +## Sequence diagram for writing tiles + +```mermaid +sequenceDiagram + participant Trainer + participant PredictionWriterCallback + participant WriteTiles + participant TileCache + + Trainer->>PredictionWriterCallback: write_on_batch_end(batch, ...) + activate PredictionWriterCallback + activate PredictionWriterCallback + PredictionWriterCallback->>WriteTiles: write_batch(batch, ...) + activate WriteTiles + activate WriteTiles + activate WriteTiles + WriteTiles->>TileCache: add(batch) + activate TileCache + WriteTiles ->> TileCache: has_last_tile() + TileCache -->> WriteTiles: True/False + deactivate TileCache + alt If does not have last tile + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + PredictionWriterCallback -->> Trainer: return + deactivate PredictionWriterCallback + else If has last tile + WriteTiles ->> TileCache: pop_image_tiles() + activate TileCache + TileCache -->> WriteTiles: tiles, tile_infos + deactivate TileCache + Note right of WriteTiles: Tiles are stitched to create prediction_image. + WriteTiles ->> SampleCache: add(prediction_image) + activate SampleCache + WriteTiles ->> SampleCache: has_all_file_samples() + SampleCache -->> WriteTiles: True/False + deactivate SampleCache + alt If does not have all samples + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + else If has all samples + WriteTiles ->> SampleCache: pop_file_samples() + activate SampleCache + SampleCache -->> WriteTiles: samples + deactivate SampleCache + Note right of WriteTiles: Concatenated samples are written to disk. + WriteTiles -->> PredictionWriterCallback: return + deactivate WriteTiles + end + PredictionWriterCallback -->> Trainer: return + deactivate PredictionWriterCallback + end + +``` \ No newline at end of file From cb9e288f68dc4a2893d7c996a5f72f8fed932447 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Fri, 29 Nov 2024 17:32:23 +0100 Subject: [PATCH 34/38] fix: mv readme to src from tests --- .../lightning/callbacks/prediction_writer_callback/README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {tests => src/careamics}/lightning/callbacks/prediction_writer_callback/README.md (100%) diff --git a/tests/lightning/callbacks/prediction_writer_callback/README.md b/src/careamics/lightning/callbacks/prediction_writer_callback/README.md similarity index 100% rename from tests/lightning/callbacks/prediction_writer_callback/README.md rename to src/careamics/lightning/callbacks/prediction_writer_callback/README.md From 14f7f1ed638fc58ff4bc1a3431fde57af31ba894 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 10 Dec 2024 12:26:56 +0100 Subject: [PATCH 35/38] test: fix import error --- .../prediction_writer_callback/conftest.py | 71 +++++++++++++++++++ .../test_cache_tiles_write_strategy.py | 14 ++-- .../test_tile_cache.py | 10 ++- .../prediction_writer_callback/utils.py | 63 ---------------- 4 files changed, 82 insertions(+), 76 deletions(-) delete mode 100644 tests/lightning/callbacks/prediction_writer_callback/utils.py diff --git a/tests/lightning/callbacks/prediction_writer_callback/conftest.py b/tests/lightning/callbacks/prediction_writer_callback/conftest.py index cee80ffd..281bd8b4 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/conftest.py +++ b/tests/lightning/callbacks/prediction_writer_callback/conftest.py @@ -1,11 +1,82 @@ +from typing import Callable from unittest.mock import Mock +import numpy as np import pytest +from numpy.typing import NDArray +from careamics.config.tile_information import TileInformation +from careamics.dataset.tiling import extract_tiles from careamics.file_io import WriteFunc +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) @pytest.fixture def write_func(): """Mock `WriteFunc`.""" return Mock(spec=WriteFunc) + + +@pytest.fixture +def patch_tile_cache() -> ( + Callable[[caches.TileCache, NDArray, list[TileInformation]], None] +): + def inner( + tile_cache: caches.TileCache, + tiles: list[NDArray], + tile_infos: list[TileInformation], + ) -> None: + """ + Patch simulated tile cache into `strategy`. + + Parameters + ---------- + tile_cache : TileCache + Tile cache used in `WriteTiles` write strategy class. + tiles : list of NDArray + Tiles to patch into `strategy.tile_cache`. + tile_infos : list of TileInformation + Corresponding tile information to patch into `strategy.tile_info_cache`. + """ + tile_cache.add((np.concatenate(tiles), tile_infos)) + + return inner + + +@pytest.fixture +def create_tiles() -> Callable[[int], tuple[list[NDArray], list[TileInformation]]]: + def inner(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: + """ + Create a set of tiles from `n_samples`. + + To create the tiles the following parameters, `tile_size=(4, 4)` and + `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this + results in 9 tiles per sample. + + Parameters + ---------- + n_samples : int + Number of samples to simulate the tiles from. + + Returns + ------- + tuple of (list of NDArray), list of TileInformation)) + Tuple where first element is the list of tiles and second element is a list + of corresponding tile information. + """ + + input_shape = (n_samples, 1, 8, 8) + tile_size = (4, 4) + tile_overlap = (2, 2) + + arr = np.arange(np.prod(input_shape)).reshape(input_shape) + + all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) + tiles = [output[0] for output in all_tiles] + tile_infos = [output[1] for output in all_tiles] + + return tiles, tile_infos + + return inner diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py index 565c9c3f..191001d8 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py @@ -13,8 +13,6 @@ WriteTiles, ) -from .utils import create_tiles, patch_tile_cache - @pytest.fixture def cache_tiles_strategy(write_func) -> WriteTiles: @@ -43,7 +41,7 @@ def cache_tiles_strategy(write_func) -> WriteTiles: # TODO: Move to test tile cache -def test_last_tiles(cache_tiles_strategy): +def test_last_tiles(cache_tiles_strategy, create_tiles, patch_tile_cache): """Test `CacheTiles.last_tile` property.""" # all tiles of 1 samples with 9 tiles @@ -58,7 +56,7 @@ def test_last_tiles(cache_tiles_strategy): assert cached_last_tiles == last_tiles -def test_write_batch_no_last_tile(cache_tiles_strategy): +def test_write_batch_no_last_tile(cache_tiles_strategy, create_tiles, patch_tile_cache): """ Test `CacheTiles.write_batch` when there is no last tile added to the cache. @@ -113,7 +111,7 @@ def test_write_batch_no_last_tile(cache_tiles_strategy): assert extended_tile_infos == cache_tiles_strategy.tile_cache.tile_info_cache -def test_write_batch_last_tile(cache_tiles_strategy): +def test_write_batch_last_tile(cache_tiles_strategy, create_tiles, patch_tile_cache): """ Test `CacheTiles.write_batch` when there is a last tile added to the cache. @@ -189,7 +187,9 @@ def test_write_batch_last_tile(cache_tiles_strategy): assert remaining_tile_info == cache_tiles_strategy.tile_cache.tile_info_cache[0] -def test_write_batch_raises(cache_tiles_strategy: WriteTiles): +def test_write_batch_raises( + cache_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache +): """Test write batch raises a ValueError if the filenames have not been set.""" # all tiles of 2 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=2) @@ -232,7 +232,7 @@ def test_write_batch_raises(cache_tiles_strategy: WriteTiles): ) -def test_reset(cache_tiles_strategy: WriteTiles): +def test_reset(cache_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache): """Test CacheTiles.reset works as expected""" # all tiles of 1 samples with 9 tiles tiles, tile_infos = create_tiles(n_samples=1) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py index 988aabe2..bd64a8fc 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py @@ -7,13 +7,11 @@ caches, ) -from .utils import create_tiles, patch_tile_cache - def test_add(): ... -def test_has_last_tile_true(): +def test_has_last_tile_true(create_tiles, patch_tile_cache): """ Test `TileCache.has_last_tile` returns true when there is a last tile. """ @@ -26,7 +24,7 @@ def test_has_last_tile_true(): assert tile_cache.has_last_tile() -def test_has_last_tile_false(): +def test_has_last_tile_false(create_tiles, patch_tile_cache): """Test `TileCache.has_last_tile` returns false when there is not a last tile.""" tile_cache = caches.TileCache() @@ -38,7 +36,7 @@ def test_has_last_tile_false(): assert not tile_cache.has_last_tile() -def test_pop_image_tiles(): +def test_pop_image_tiles(create_tiles, patch_tile_cache): """ Test `TileCache.has_last_tile` removes the tiles up until the first "last tile". """ @@ -62,7 +60,7 @@ def test_pop_image_tiles(): assert tile_cache.tile_info_cache[0] == tile_infos[9] -def test_pop_image_tiles_error(): +def test_pop_image_tiles_error(create_tiles, patch_tile_cache): """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" tile_cache = caches.TileCache() diff --git a/tests/lightning/callbacks/prediction_writer_callback/utils.py b/tests/lightning/callbacks/prediction_writer_callback/utils.py deleted file mode 100644 index 8cc4047d..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/utils.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Utility functions to be used in tests relating to the `PredictionWriterCallback`.""" - -import numpy as np -from numpy.typing import NDArray - -from careamics.config.tile_information import TileInformation -from careamics.dataset.tiling import extract_tiles -from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( - caches, -) - - -def patch_tile_cache( - tile_cache: caches.TileCache, - tiles: list[NDArray], - tile_infos: list[TileInformation], -) -> None: - """ - Patch simulated tile cache into `strategy`. - - Parameters - ---------- - tile_cache : TileCache - Tile cache used in `WriteTiles` write strategy class. - tiles : list of NDArray - Tiles to patch into `strategy.tile_cache`. - tile_infos : list of TileInformation - Corresponding tile information to patch into `strategy.tile_info_cache`. - """ - tile_cache.add((np.concatenate(tiles), tile_infos)) - - -def create_tiles(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: - """ - Create a set of tiles from `n_samples`. - - To create the tiles the following parameters, `tile_size=(4, 4)` and - `overlaps=(2, 2)`, on an input array with shape (`n_samples`, 1, 8, 8); this - results in 9 tiles per sample. - - Parameters - ---------- - n_samples : int - Number of samples to simulate the tiles from. - - Returns - ------- - tuple of (list of NDArray), list of TileInformation)) - Tuple where first element is the list of tiles and second element is a list - of corresponding tile information. - """ - - input_shape = (n_samples, 1, 8, 8) - tile_size = (4, 4) - tile_overlap = (2, 2) - - arr = np.arange(np.prod(input_shape)).reshape(input_shape) - - all_tiles = list(extract_tiles(arr, tile_size, tile_overlap)) - tiles = [output[0] for output in all_tiles] - tile_infos = [output[1] for output in all_tiles] - - return tiles, tile_infos From 32f45afa345acfde8b60f5fce50624187bae6d0b Mon Sep 17 00:00:00 2001 From: melisande-c Date: Tue, 10 Dec 2024 17:00:51 +0100 Subject: [PATCH 36/38] test: add missing TileCache and SampleCache tests --- .../write_strategy/caches.py | 2 +- .../prediction_writer_callback/conftest.py | 15 ++- .../test_sample_cache.py | 120 ++++++++++++++---- .../test_tile_cache.py | 30 ++++- 4 files changed, 139 insertions(+), 28 deletions(-) diff --git a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py index b17b70a6..43b973f4 100644 --- a/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py +++ b/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy/caches.py @@ -125,7 +125,7 @@ def add(self, item: NDArray): Parameters ---------- item : numpy.ndarray - A single predicted sample. + A set of predicted samples. """ self.sample_cache.extend(np.split(item, item.shape[0])) diff --git a/tests/lightning/callbacks/prediction_writer_callback/conftest.py b/tests/lightning/callbacks/prediction_writer_callback/conftest.py index 281bd8b4..798cb11b 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/conftest.py +++ b/tests/lightning/callbacks/prediction_writer_callback/conftest.py @@ -21,7 +21,7 @@ def write_func(): @pytest.fixture def patch_tile_cache() -> ( - Callable[[caches.TileCache, NDArray, list[TileInformation]], None] + Callable[[caches.TileCache, list[NDArray], list[TileInformation]], None] ): def inner( tile_cache: caches.TileCache, @@ -80,3 +80,16 @@ def inner(n_samples: int) -> tuple[list[NDArray], list[TileInformation]]: return tiles, tile_infos return inner + + +@pytest.fixture +def samples() -> tuple[NDArray, list[int]]: + n_samples_per_file = [3, 1, 2] + shapes = [(16, 16), (8, 8), (12, 12)] + sample_set = [] + for n_samples, spatial_shape in zip(n_samples_per_file, shapes): + shape = (n_samples, 1, *spatial_shape) + sample = np.arange(np.prod(shape)).reshape(shape) + sample_set.extend(np.split(sample, n_samples)) + + return sample_set, n_samples_per_file diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py index 234600ee..248fa5b7 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_sample_cache.py @@ -1,25 +1,99 @@ """Test the utility `SampleCache` class used by the WriteTile classes.""" - -def test_add(): ... - - -def test_has_all_file_samples_true(): ... - - -def test_has_all_file_samples_false(): ... - - -def test_has_all_file_samples_error(): ... - - -def test_pop_file_samples(): ... - - -def test_pop_file_samples_last_sample(): ... - - -def test_pop_file_samples_error(): ... - - -def test_reset(): ... +import numpy as np +import pytest + +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + caches, +) + + +def test_add(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + np.testing.assert_array_equal(sample_cache.sample_cache[0], samples[0]) + + +def test_has_all_file_samples_true(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + for i in range(n_samples_per_file[0]): + sample_cache.add(samples[i]) + assert sample_cache.has_all_file_samples() + + +def test_has_all_file_samples_false(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + for i in range(n_samples_per_file[0] - 1): + sample_cache.add(samples[i]) + assert not sample_cache.has_all_file_samples() + + +def test_has_all_file_samples_error(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + # simulate iterating through all files + for sample in samples: + sample_cache.add(sample) + for _ in n_samples_per_file: + sample_cache.pop_file_samples() + + with pytest.raises(ValueError): + sample_cache.has_all_file_samples() + + +def test_pop_file_samples(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + + n = 4 + for i in range(n): + sample_cache.add(samples[i]) + + first_file_n_samples = n_samples_per_file[0] + + # assert popped samples correct + popped_samples = sample_cache.pop_file_samples() + for i in range(first_file_n_samples): + np.testing.assert_array_equal(popped_samples[i], samples[i]) + + # assert remaining sample correct + np.testing.assert_array_equal( + sample_cache.sample_cache[0], samples[first_file_n_samples] + ) + + +def test_pop_file_samples_error(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + + # should raise an error because sample cache does not contain all file samples + with pytest.raises(ValueError): + sample_cache.pop_file_samples() + + +def test_reset(samples): + # n_samples_per_file = [3, 1, 2] + # simulated batch size = 1 + samples, n_samples_per_file = samples + sample_cache = caches.SampleCache(n_samples_per_file) + sample_cache.add(samples[0]) + + sample_cache.reset() + assert len(sample_cache.sample_cache) == 0 + assert next(sample_cache.n_samples_iter) == n_samples_per_file[0] diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py index bd64a8fc..f919130a 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_tile_cache.py @@ -8,7 +8,22 @@ ) -def test_add(): ... +def test_add(create_tiles, patch_tile_cache): + """Test `TileCache.add` extends the internal caches as expected.""" + tile_cache = caches.TileCache() + # all tiles of 1 samples with 9 tiles + tile_list, tile_infos = create_tiles(n_samples=1) + + n = 4 + # patch in tiles up to n into tile_cache + patch_tile_cache(tile_cache, tile_list[:n], tile_infos[:n]) + # add remaining tiles to tile_cache + tile_cache.add((np.concatenate(tile_list[n:]), tile_infos[n:])) + + np.testing.assert_array_equal( + np.concatenate(tile_cache.array_cache), np.concatenate(tile_list) + ) + assert tile_cache.tile_info_cache == tile_infos def test_has_last_tile_true(create_tiles, patch_tile_cache): @@ -61,7 +76,7 @@ def test_pop_image_tiles(create_tiles, patch_tile_cache): def test_pop_image_tiles_error(create_tiles, patch_tile_cache): - """Test `CacheTiles._last_tile_index` raises an error when there is no last tile.""" + """Test `TileCache.pop_image_tiles` raises an error when there is no last tile.""" tile_cache = caches.TileCache() # all tiles of 1 samples with 9 tiles @@ -73,4 +88,13 @@ def test_pop_image_tiles_error(create_tiles, patch_tile_cache): tile_cache.pop_image_tiles() -def test_reset(): ... +def test_reset(create_tiles, patch_tile_cache): + """Test `TileCache.reset resets the cached tiles as expected.""" + tile_cache = caches.TileCache() + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + patch_tile_cache(tile_cache, tiles, tile_infos) + + tile_cache.reset() + assert len(tile_cache.array_cache) == 0 + assert len(tile_cache.tile_info_cache) == 0 From f6dc8818c6abc0c543fc302d457ac1f2529e18d1 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 11 Dec 2024 17:16:31 +0100 Subject: [PATCH 37/38] test: complete tests from WriteTiles strategy --- .../test_cache_tiles_write_strategy.py | 246 --------------- .../test_write_tiles_strategy.py | 294 ++++++++++++++++++ 2 files changed, 294 insertions(+), 246 deletions(-) delete mode 100644 tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py create mode 100644 tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py deleted file mode 100644 index 191001d8..00000000 --- a/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Test `CacheTiles` class.""" - -from pathlib import Path -from unittest.mock import Mock, patch - -import numpy as np -import pytest -from pytorch_lightning import LightningModule, Trainer -from torch.utils.data import DataLoader - -from careamics.dataset import IterableTiledPredDataset -from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( - WriteTiles, -) - - -@pytest.fixture -def cache_tiles_strategy(write_func) -> WriteTiles: - """ - Initialized `CacheTiles` class. - - Parameters - ---------- - write_func : WriteFunc - Write function. (Comes from fixture). - - Returns - ------- - CacheTiles - Initialized `CacheTiles` class. - """ - write_extension = ".ext" - write_func_kwargs = {} - return WriteTiles( - write_func=write_func, - write_extension=write_extension, - write_func_kwargs=write_func_kwargs, - write_filenames=None, - n_samples_per_file=None, - ) - - -# TODO: Move to test tile cache -def test_last_tiles(cache_tiles_strategy, create_tiles, patch_tile_cache): - """Test `CacheTiles.last_tile` property.""" - - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - patch_tile_cache(cache_tiles_strategy.tile_cache, tiles, tile_infos) - - last_tiles = [False, False, False, False, False, False, False, False, True] - cached_last_tiles = [ - tile_info.last_tile - for tile_info in cache_tiles_strategy.tile_cache.tile_info_cache - ] - assert cached_last_tiles == last_tiles - - -def test_write_batch_no_last_tile(cache_tiles_strategy, create_tiles, patch_tile_cache): - """ - Test `CacheTiles.write_batch` when there is no last tile added to the cache. - - Expected behaviour is that the batch is added to the cache. - """ - - # all tiles of 1 samples with 9 tiles - n_samples = 1 - tiles, tile_infos = create_tiles(n_samples=n_samples) - - # simulate adding a batch that will not contain the last tile - n_tiles = 4 - batch_size = 2 - patch_tile_cache( - cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] - ) - next_batch = ( - np.concatenate(tiles[n_tiles : n_tiles + batch_size]), - tile_infos[n_tiles : n_tiles + batch_size], - ) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterableTiledPredDataset) - dataloader_idx = 0 - trainer.predict_dataloaders = [Mock(spec=DataLoader)] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - cache_tiles_strategy.set_file_data( - write_filenames=["file_1.tiff"], n_samples_per_file=[n_samples] - ) - cache_tiles_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=next_batch, - batch_indices=Mock(), - batch=next_batch, # does not contain the last tile - batch_idx=3, - dataloader_idx=dataloader_idx, - dirpath="predictions", - ) - - extended_tiles = tiles[: n_tiles + batch_size] - extended_tile_infos = tile_infos[: n_tiles + batch_size] - - assert all( - np.array_equal( - extended_tiles[i], cache_tiles_strategy.tile_cache.array_cache[i] - ) - for i in range(n_tiles + batch_size) - ) - assert extended_tile_infos == cache_tiles_strategy.tile_cache.tile_info_cache - - -def test_write_batch_last_tile(cache_tiles_strategy, create_tiles, patch_tile_cache): - """ - Test `CacheTiles.write_batch` when there is a last tile added to the cache. - - Expected behaviour is that the cache is cleared and the write func is called. - """ - - # all tiles of 2 samples with 9 tiles - n_samples = 2 - tiles, tile_infos = create_tiles(n_samples=n_samples) - - # simulate adding a batch that will contain the last tile - n_tiles = 8 - batch_size = 2 - patch_tile_cache( - cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] - ) - next_batch = ( - np.concatenate(tiles[n_tiles : n_tiles + batch_size]), - tile_infos[n_tiles : n_tiles + batch_size], - ) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterableTiledPredDataset) - dataloader_idx = 0 - trainer.predict_dataloaders = [Mock(spec=DataLoader)] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - file_names = [f"file_{i}" for i in range(n_samples)] - n_samples_per_file = [1 for _ in range(n_samples)] - - # call write batch - dirpath = Path("predictions") - cache_tiles_strategy.set_file_data( - write_filenames=file_names, n_samples_per_file=n_samples_per_file - ) - - # mocking stitch_prediction_single because to assert if WriteFunc was called - with patch( - "careamics.lightning.callbacks.prediction_writer_callback.write_strategy" - + ".write_tiles.stitch_prediction_single", - ) as mock_stitch_prediction_single: - mock_prediction = np.arange(64).reshape(1, 1, 8, 8) - mock_stitch_prediction_single.return_value = mock_prediction - cache_tiles_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=next_batch, - batch_indices=Mock(), - batch=next_batch, # contains the last tile - batch_idx=3, - dataloader_idx=dataloader_idx, - dirpath=dirpath, - ) - - # assert write_func is called as expected - write_func_call_args = cache_tiles_strategy.write_func.call_args.kwargs - assert write_func_call_args["file_path"] == Path("predictions/file_0.ext") - np.testing.assert_array_equal(write_func_call_args["img"], mock_prediction) - - # Tile of the next image (should remain in the cache) - remaining_tile = tiles[9] - remaining_tile_info = tile_infos[9] - - # assert cache cleared as expected - assert len(cache_tiles_strategy.tile_cache.array_cache) == 1 - assert np.array_equal( - remaining_tile, cache_tiles_strategy.tile_cache.array_cache[0] - ) - assert remaining_tile_info == cache_tiles_strategy.tile_cache.tile_info_cache[0] - - -def test_write_batch_raises( - cache_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache -): - """Test write batch raises a ValueError if the filenames have not been set.""" - # all tiles of 2 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=2) - - # simulate adding a batch that will contain the last tile - n_tiles = 8 - batch_size = 2 - patch_tile_cache( - cache_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] - ) - next_batch = ( - np.concatenate(tiles[n_tiles : n_tiles + batch_size]), - tile_infos[n_tiles : n_tiles + batch_size], - ) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - - # mock trainer and datasets - trainer = Mock(spec=Trainer) - mock_dataset = Mock(spec=IterableTiledPredDataset) - dataloader_idx = 0 - trainer.predict_dataloaders = [Mock(spec=DataLoader)] - trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset - - with pytest.raises(ValueError): - assert cache_tiles_strategy._write_filenames is None - - # call write batch - dirpath = Path("predictions") - cache_tiles_strategy.write_batch( - trainer=trainer, - pl_module=Mock(spec=LightningModule), - prediction=next_batch, - batch_indices=Mock(), - batch=next_batch, # contains the last tile - batch_idx=3, - dataloader_idx=dataloader_idx, - dirpath=dirpath, - ) - - -def test_reset(cache_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache): - """Test CacheTiles.reset works as expected""" - # all tiles of 1 samples with 9 tiles - tiles, tile_infos = create_tiles(n_samples=1) - # don't include last tile - patch_tile_cache(cache_tiles_strategy.tile_cache, tiles[:-1], tile_infos[:-1]) - - cache_tiles_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[1]) - cache_tiles_strategy.reset() - - assert cache_tiles_strategy._write_filenames is None - assert cache_tiles_strategy._filename_iter is None diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py new file mode 100644 index 00000000..5c32a728 --- /dev/null +++ b/tests/lightning/callbacks/prediction_writer_callback/test_write_tiles_strategy.py @@ -0,0 +1,294 @@ +"""Test `CacheTiles` class.""" + +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +import tifffile +from pytorch_lightning import LightningModule, Trainer +from torch.utils.data import DataLoader + +from careamics.dataset import IterableTiledPredDataset +from careamics.file_io.write import write_tiff +from careamics.lightning.callbacks.prediction_writer_callback.write_strategy import ( + WriteTiles, +) +from careamics.prediction_utils import stitch_prediction + + +@pytest.fixture +def write_tiles_strategy() -> WriteTiles: + """ + Initialized `WriteTiles` class. + + Parameters + ---------- + write_func : WriteFunc + Write function. (Comes from fixture). + + Returns + ------- + CacheTiles + Initialized `CacheTiles` class. + """ + write_extension = ".tiff" + write_func_kwargs = {} + return WriteTiles( + write_func=write_tiff, + write_extension=write_extension, + write_func_kwargs=write_func_kwargs, + write_filenames=None, + n_samples_per_file=None, + ) + + +def test_write_batch_no_last_tile( + tmp_path, + write_tiles_strategy: WriteTiles, + create_tiles, + patch_tile_cache, +): + """ + Test behaviour when the last tile of a sample is not present. + + `WriteTiles.write_batch` should cache the tiles in the given batch. + """ + # all tiles of 1 samples with 9 tiles + n_samples = 1 + tiles, tile_infos = create_tiles(n_samples=n_samples) + + # simulate adding a batch that will not contain the last tile + n_tiles = 4 + batch_size = 2 + patch_tile_cache( + write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) + next_batch = ( + np.concatenate(tiles[n_tiles : n_tiles + batch_size]), + tile_infos[n_tiles : n_tiles + batch_size], + ) + + # mock trainer and datasets (difficult to set up true classes) + trainer: Trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterableTiledPredDataset) + dataloader_idx = 0 + trainer.predict_dataloaders = [Mock(spec=DataLoader)] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + write_tiles_strategy.set_file_data( + write_filenames=["file_1.tif"], n_samples_per_file=[n_samples] + ) + write_tiles_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=next_batch, + batch_indices=Mock(), + batch=next_batch, # does not contain the last tile + batch_idx=3, + dataloader_idx=dataloader_idx, + dirpath=tmp_path / "predictions", + ) + + extended_tiles = tiles[: n_tiles + batch_size] + extended_tile_infos = tile_infos[: n_tiles + batch_size] + + # assert tiles and tile infos are caches + for i in range(n_tiles + batch_size): + np.testing.assert_array_equal( + extended_tiles[i], write_tiles_strategy.tile_cache.array_cache[i] + ) + assert extended_tile_infos == write_tiles_strategy.tile_cache.tile_info_cache + assert len(write_tiles_strategy.sample_cache.sample_cache) == 0 + + +def test_write_batch_has_last_tile_no_last_sample( + tmp_path, + write_tiles_strategy: WriteTiles, + create_tiles, + patch_tile_cache, +): + """ + Test behaviour when the last tile of a sample is present, but not the last sample. + + `WriteTiles.write_batch` should cache the resulting stitched sample. + """ + # all tiles of 2 samples with 9 tiles each + n_samples = 2 + tiles, tile_infos = create_tiles(n_samples=n_samples) + + # simulate adding a batch that will not contain the last sample + n_tiles = 8 + batch_size = 2 + patch_tile_cache( + write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) + # (ext batch will include the last tile of the first sample + next_batch = ( + np.concatenate(tiles[n_tiles : n_tiles + batch_size]), + tile_infos[n_tiles : n_tiles + batch_size], + ) + + # mock trainer and datasets (difficult to set up true classes) + trainer: Trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterableTiledPredDataset) + dataloader_idx = 0 + trainer.predict_dataloaders = [Mock(spec=DataLoader)] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + write_tiles_strategy.set_file_data( + write_filenames=["file_1.tif"], n_samples_per_file=[n_samples] + ) + write_tiles_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=next_batch, + batch_indices=Mock(), + batch=next_batch, # contains the last tile of the first sample + batch_idx=3, + dataloader_idx=dataloader_idx, + dirpath=tmp_path / "predictions", + ) + + stitched_sample1 = stitch_prediction(tiles[:9], tile_infos[:9])[0] + + # assert first tile from the second sample is in the tile cache + np.testing.assert_array_equal( + tiles[9], write_tiles_strategy.tile_cache.array_cache[0] + ) + assert [tile_infos[9]] == write_tiles_strategy.tile_cache.tile_info_cache + # assert the stitched sample is saved in the sample cache + np.testing.assert_array_equal( + stitched_sample1, write_tiles_strategy.sample_cache.sample_cache[0] + ) + + +def test_write_batch_has_last_sample( + tmp_path, + write_tiles_strategy: WriteTiles, + create_tiles, + patch_tile_cache, +): + """ + Test behaviour when the last sample of a file is present. + + `WriteTiles.write_batch` should write the resulting set of samples to disk. + """ + """ + Test behaviour when the last tile of a sample is present, but not the last sample. + + `WriteTiles.write_batch` should cache the resulting stitched sample. + """ + + # all tiles of 2 samples with 9 tiles each + n_samples = 2 + tiles, tile_infos = create_tiles(n_samples=n_samples) + + stitched_samples = stitch_prediction(tiles, tile_infos) + file_data = np.concatenate(stitched_samples) + + # simulate adding a batch that will not contain the last sample + batch_size = 2 + n_tiles = 16 + write_tiles_strategy.set_file_data( + write_filenames=["file_1.tif"], n_samples_per_file=[n_samples] + ) + patch_tile_cache( + write_tiles_strategy.tile_cache, tiles[9:n_tiles], tile_infos[9:n_tiles] + ) + # also patch_sample + write_tiles_strategy.sample_cache.sample_cache.append(stitched_samples[0]) + # (ext batch will include the last tile of the first sample + next_batch = ( + np.concatenate(tiles[n_tiles : n_tiles + batch_size]), + tile_infos[n_tiles : n_tiles + batch_size], + ) + + # mock trainer and datasets (difficult to set up true classes) + trainer: Trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterableTiledPredDataset) + dataloader_idx = 0 + trainer.predict_dataloaders = [Mock(spec=DataLoader)] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + # normally output directory creation is handled by PredictionWriterCallback + (tmp_path / "predictions").mkdir() + write_tiles_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=next_batch, + batch_indices=Mock(), + batch=next_batch, # contains the last tile of the last sample + batch_idx=3, + dataloader_idx=dataloader_idx, + dirpath=tmp_path / "predictions", + ) + + # assert caches are now empty + assert len(write_tiles_strategy.tile_cache.array_cache) == 0 + assert len(write_tiles_strategy.tile_cache.tile_info_cache) == 0 + assert len(write_tiles_strategy.sample_cache.sample_cache) == 0 + + assert (tmp_path / "predictions" / "file_1.tiff").is_file() + + load_file_data = tifffile.imread(tmp_path / "predictions" / "file_1.tiff") + np.testing.assert_array_equal(load_file_data, file_data) + + +def test_write_batch_raises( + write_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache +): + """Test write batch raises a ValueError if the filenames have not been set.""" + # all tiles of 2 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=2) + + # simulate adding a batch that will contain the last tile + n_tiles = 8 + batch_size = 2 + patch_tile_cache( + write_tiles_strategy.tile_cache, tiles[:n_tiles], tile_infos[:n_tiles] + ) + next_batch = ( + np.concatenate(tiles[n_tiles : n_tiles + batch_size]), + tile_infos[n_tiles : n_tiles + batch_size], + ) + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + + # mock trainer and datasets + trainer = Mock(spec=Trainer) + mock_dataset = Mock(spec=IterableTiledPredDataset) + dataloader_idx = 0 + trainer.predict_dataloaders = [Mock(spec=DataLoader)] + trainer.predict_dataloaders[dataloader_idx].dataset = mock_dataset + + with pytest.raises(ValueError): + assert write_tiles_strategy._write_filenames is None + + # call write batch + dirpath = Path("predictions") + write_tiles_strategy.write_batch( + trainer=trainer, + pl_module=Mock(spec=LightningModule), + prediction=next_batch, + batch_indices=Mock(), + batch=next_batch, # contains the last tile + batch_idx=3, + dataloader_idx=dataloader_idx, + dirpath=dirpath, + ) + + +def test_reset(write_tiles_strategy: WriteTiles, create_tiles, patch_tile_cache): + """Test CacheTiles.reset works as expected""" + # all tiles of 1 samples with 9 tiles + tiles, tile_infos = create_tiles(n_samples=1) + # don't include last tile + patch_tile_cache(write_tiles_strategy.tile_cache, tiles[:-1], tile_infos[:-1]) + + write_tiles_strategy.set_file_data(write_filenames=["file"], n_samples_per_file=[1]) + write_tiles_strategy.reset() + + assert write_tiles_strategy._write_filenames is None + assert write_tiles_strategy._filename_iter is None From f1c4b47f2513f0c73278ef1231738acd95e630d1 Mon Sep 17 00:00:00 2001 From: melisande-c Date: Wed, 11 Dec 2024 17:18:56 +0100 Subject: [PATCH 38/38] style: rename test file --- ...write_image_write_strategy.py => test_write_image_strategy.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/lightning/callbacks/prediction_writer_callback/{test_write_image_write_strategy.py => test_write_image_strategy.py} (100%) diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py b/tests/lightning/callbacks/prediction_writer_callback/test_write_image_strategy.py similarity index 100% rename from tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py rename to tests/lightning/callbacks/prediction_writer_callback/test_write_image_strategy.py