Feature(Write predictions to disk): Write strategies cache samples, to mirror source files organisation #324

Closed
wants to merge 43 commits into from
43 commits
091a518
feat(write strategy): add write_filenmaes attribute
melisande-c Sep 23, 2024
c4213a2
feat(write_strategy): add reset method
melisande-c Sep 23, 2024
dedfac6
fix(write on batch end): remove no longer needed ds instance type check
melisande-c Sep 23, 2024
14733ca
feat(write_strategy): add current file index attr
melisande-c Sep 23, 2024
99459b7
feat: update write_batch method to use write_filenames attribute
melisande-c Sep 23, 2024
3a1f031
feat(write prediction): remove no-longer used file_path_utils.py
melisande-c Sep 23, 2024
a82ed40
test: remove test_file_path_utils.py
melisande-c Sep 23, 2024
9ebecc1
test: update CacheTile initialisation
melisande-c Sep 23, 2024
42f01cd
test(CacheTiles): update - set write_filenames
melisande-c Sep 23, 2024
eb82fb3
test: update write_image_strategy tests - write_filenames attribute
melisande-c Sep 23, 2024
a7cfd5c
test: update prediction writer smoke tests - write_filenames attr
melisande-c Sep 23, 2024
d517849
feat: add write_filenames attribute to WriteStrategy Protocol
melisande-c Sep 23, 2024
95750e3
docs: attr doctstring
melisande-c Sep 23, 2024
927df49
test: write strategy reset
melisande-c Sep 23, 2024
377c0cd
test: write_batch raises if write_filenames is None
melisande-c Sep 23, 2024
42371f8
test: filenames reset after predict call in smoke
melisande-c Sep 23, 2024
3624123
refac: split write strategies into seperate modules
melisande-c Sep 24, 2024
b4fecac
refac: remove inheritance from protocol
melisande-c Sep 24, 2024
e062295
refac: rename CacheTiles to WriteTiles
melisande-c Sep 24, 2024
b6fa4ef
refac: extract tile cache to seperate class
melisande-c Sep 24, 2024
88e8e9d
feat: add SampleCache class
melisande-c Sep 25, 2024
39182d6
feat: add sample caching to write strategies
melisande-c Sep 25, 2024
67966b3
feat(write strategies): method to set file data
melisande-c Sep 25, 2024
bb86d1f
feat(write strategie protocol): method to set file data
melisande-c Sep 25, 2024
ee9be04
feat: replace filename indexing with iterator
melisande-c Sep 25, 2024
f92d9f5
Merge branch 'main' into mc/feat/cache_pred_samples
melisande-c Nov 5, 2024
7de654d
test(WriteImage strategy): fix tests for updated classes; fix: bugs
melisande-c Nov 12, 2024
5f5109d
test: update prediction callback write strategy unit tests
melisande-c Nov 12, 2024
2dd9e78
Merge branch 'main' into mc/feat/cache_pred_samples
melisande-c Nov 15, 2024
a779316
fix: bugs and tests since requiring n_samples_per_file param
melisande-c Nov 22, 2024
c5305e3
docs: add and update docs
melisande-c Nov 22, 2024
663f3a8
Merge branch 'main' into mc/feat/cache_pred_samples
melisande-c Nov 22, 2024
481c527
refac(prediction writer): rename utils.py to caches.py
melisande-c Nov 29, 2024
12de6d6
test: placeholder funcs for cache tests; move existing tests relating…
melisande-c Nov 29, 2024
f05af8a
style: make some attributes private
melisande-c Nov 29, 2024
4c3b4ba
feat: README with class and sequence diagrams
melisande-c Nov 29, 2024
cb9e288
fix: mv readme to src from tests
melisande-c Nov 29, 2024
14f7f1e
test: fix import error
melisande-c Dec 10, 2024
32f45af
test: add missing TileCache and SampleCache tests
melisande-c Dec 10, 2024
1dc5f90
Merge branch 'main' into mc/feat/cache_pred_samples
melisande-c Dec 10, 2024
b6fe0a8
Merge branch 'main' into mc/feat/cache_pred_samples
melisande-c Dec 10, 2024
f6dc881
test: complete tests from WriteTiles strategy
melisande-c Dec 11, 2024
f1c4b47
style: rename test file
melisande-c Dec 11, 2024
119 changes: 119 additions & 0 deletions src/careamics/lightning/callbacks/prediction_writer_callback/README.md
@@ -0,0 +1,119 @@
# `PredictionWriterCallback`

## Class diagram for `PredictionWriterCallback` related classes
```mermaid
classDiagram
PredictionWriterCallback*--WriteStrategy : composition
WriteStrategy<--WriteTiles : implements
WriteStrategy<--WriteImage : implements
WriteStrategy<--WriteZarrTiles : implements
WriteTiles*--TileCache : composition
WriteTiles*--SampleCache : composition
WriteImage*--SampleCache : composition

class PredictionWriterCallback
PredictionWriterCallback : +bool writing_predictions
PredictionWriterCallback : +WriteStrategy write_strategy
PredictionWriterCallback : +write_on_batch_end(...)
PredictionWriterCallback : +on_predict_epoch_end(...)

class WriteStrategy
<<interface>> WriteStrategy
WriteStrategy : +write_batch(...)*
WriteStrategy : +set_file_data(list~str~ write_filenames, list~int~ n_samples_per_file)*
WriteStrategy : +reset()*

class WriteTiles
WriteTiles : +WriteFunc write_func

WriteTiles : +TileCache tile_cache
WriteTiles : +SampleCache sample_cache
WriteTiles : +write_batch(...)
WriteTiles : +set_file_data(list~str~ write_filenames, list~int~ n_samples_per_file)
WriteTiles : +reset()

class WriteImage
WriteImage : +WriteFunc write_func
WriteImage : +SampleCache sample_cache
WriteImage : +write_batch(...)
WriteImage : +set_file_data(list~str~ write_filenames, list~int~ n_samples_per_file)
WriteImage : +reset()

class WriteZarrTiles
WriteZarrTiles : +write_batch(...) NotImplemented
WriteZarrTiles : +set_file_data(list~str~ write_filenames, list~int~ n_samples_per_file) NotImplemented
WriteZarrTiles : +reset() NotImplemented

class TileCache
TileCache : +list~NDArray~ array_cache
TileCache : +list~TileInformation~ tile_info_cache
TileCache : +add(NDArray, list~TileInformation~ item)
TileCache : +has_last_tile() bool
TileCache : +pop_image_tiles() NDArray, list~TileInformation~
TileCache : +reset()

class SampleCache
SampleCache : +list~int~ n_samples_per_file
SampleCache : +Iterator n_samples_iter
SampleCache : +int n_samples
SampleCache : +list~NDArray~ sample_cache
SampleCache : +add(NDArray item)
SampleCache : +has_all_file_samples() bool
SampleCache : +pop_file_samples() list~NDArray~
SampleCache : +reset()
```
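
To make the `SampleCache` box in the diagram more concrete, here is a minimal sketch of how such a cache could behave. It is an assumption built only from the attribute and method names in the diagram, not the implementation in this PR: the class name `SampleCacheSketch`, the generic item type (NumPy arrays in the PR), the empty-list check, and the choice to keep the last sample count once the iterator is exhausted are all hypothetical.

```python
from collections.abc import Iterator
from typing import Generic, TypeVar

T = TypeVar("T")  # assumption: in the PR the cached items are NumPy arrays


class SampleCacheSketch(Generic[T]):
    """Hypothetical stand-in for the `SampleCache` in the class diagram."""

    def __init__(self, n_samples_per_file: list[int]) -> None:
        if not n_samples_per_file:
            raise ValueError("n_samples_per_file must not be empty.")
        self.n_samples_per_file = n_samples_per_file
        self.n_samples_iter: Iterator[int] = iter(n_samples_per_file)
        # Number of samples expected for the file currently being filled.
        self.n_samples: int = next(self.n_samples_iter)
        self.sample_cache: list[T] = []

    def add(self, item: T) -> None:
        """Store one predicted sample."""
        self.sample_cache.append(item)

    def has_all_file_samples(self) -> bool:
        """Return True once every sample of the current file is cached."""
        return len(self.sample_cache) >= self.n_samples

    def pop_file_samples(self) -> list[T]:
        """Remove and return the samples that belong to the current file."""
        if not self.has_all_file_samples():
            raise ValueError("Not all samples for the current file are cached.")
        samples = self.sample_cache[: self.n_samples]
        self.sample_cache = self.sample_cache[self.n_samples :]
        # Move on to the next file's count; keep the old one if exhausted.
        self.n_samples = next(self.n_samples_iter, self.n_samples)
        return samples

    def reset(self) -> None:
        """Return to the state right after construction."""
        self.n_samples_iter = iter(self.n_samples_per_file)
        self.n_samples = next(self.n_samples_iter)
        self.sample_cache = []
```

With `n_samples_per_file=[2, 1]`, the first two added samples would be popped together as one file and the third as a second file, which is how the output files can mirror the organisation of the source files.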

## Sequence diagram for writing tiles

```mermaid
sequenceDiagram
participant Trainer
participant PredictionWriterCallback
participant WriteTiles
participant TileCache
participant SampleCache

Trainer->>PredictionWriterCallback: write_on_batch_end(batch, ...)
activate PredictionWriterCallback
activate PredictionWriterCallback
PredictionWriterCallback->>WriteTiles: write_batch(batch, ...)
activate WriteTiles
activate WriteTiles
activate WriteTiles
WriteTiles->>TileCache: add(batch)
activate TileCache
WriteTiles ->> TileCache: has_last_tile()
TileCache -->> WriteTiles: True/False
deactivate TileCache
alt If does not have last tile
WriteTiles -->> PredictionWriterCallback: return
deactivate WriteTiles
PredictionWriterCallback -->> Trainer: return
deactivate PredictionWriterCallback
else If has last tile
WriteTiles ->> TileCache: pop_image_tiles()
activate TileCache
TileCache -->> WriteTiles: tiles, tile_infos
deactivate TileCache
Note right of WriteTiles: Tiles are stitched to create prediction_image.
WriteTiles ->> SampleCache: add(prediction_image)
activate SampleCache
WriteTiles ->> SampleCache: has_all_file_samples()
SampleCache -->> WriteTiles: True/False
deactivate SampleCache
alt If does not have all samples
WriteTiles -->> PredictionWriterCallback: return
deactivate WriteTiles
else If has all samples
WriteTiles ->> SampleCache: pop_file_samples()
activate SampleCache
SampleCache -->> WriteTiles: samples
deactivate SampleCache
Note right of WriteTiles: Concatenated samples are written to disk.
WriteTiles -->> PredictionWriterCallback: return
deactivate WriteTiles
end
PredictionWriterCallback -->> Trainer: return
deactivate PredictionWriterCallback
end

```
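
The control flow in the sequence diagram can also be read as plain Python. The sketch below is a simplified illustration under stated assumptions, not the `WriteTiles` code from this PR: `WriteTilesFlowSketch` is a hypothetical name, each call receives a single tile rather than a batch, strings stand in for arrays, and string concatenation stands in for stitching.

```python
from typing import Callable


class WriteTilesFlowSketch:
    """Hypothetical, simplified mirror of the sequence diagram."""

    def __init__(
        self,
        n_samples_per_file: list[int],
        write_func: Callable[[list[str]], None],
    ) -> None:
        self.n_samples_iter = iter(n_samples_per_file)
        self.n_samples = next(self.n_samples_iter)
        self.tiles: list[str] = []  # stands in for TileCache
        self.samples: list[str] = []  # stands in for SampleCache
        self.write_func = write_func

    def write_batch(self, tile: str, is_last_tile: bool) -> None:
        # 1. Cache the incoming tile.
        self.tiles.append(tile)
        if not is_last_tile:
            return  # image incomplete: wait for more tiles

        # 2. Last tile seen: pop the tiles and "stitch" them into one image.
        prediction_image = "".join(self.tiles)
        self.tiles = []

        # 3. Cache the stitched image as one sample of the current file.
        self.samples.append(prediction_image)
        if len(self.samples) < self.n_samples:
            return  # file incomplete: wait for more samples

        # 4. All samples for this file are present: write them together.
        self.write_func(self.samples)
        self.samples = []
        self.n_samples = next(self.n_samples_iter, self.n_samples)
```

The two early returns correspond to the two `alt` branches in the diagram: nothing is written until both the tile cache and the sample cache report that they are complete.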
@@ -5,14 +5,14 @@
     "create_write_strategy",
     "WriteStrategy",
     "WriteImage",
-    "CacheTiles",
+    "WriteTiles",
     "WriteTilesZarr",
     "select_write_extension",
     "select_write_func",
 ]
 
 from .prediction_writer_callback import PredictionWriterCallback
-from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
+from .write_strategy import WriteImage, WriteStrategy, WriteTiles, WriteTilesZarr
 from .write_strategy_factory import (
     create_write_strategy,
     select_write_extension,

This file was deleted.

@@ -7,7 +7,6 @@
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter
-from torch.utils.data import DataLoader
 
 from careamics.dataset import (
     IterablePredDataset,
@@ -127,7 +126,7 @@ def from_write_func_params(
         )
         return cls(write_strategy=write_strategy, dirpath=dirpath)
 
-    def _init_dirpath(self, dirpath):
+    def _init_dirpath(self, dirpath: Union[Path, str]):
         """
         Initialize directory path. Should only be called from `__init__`.
 
@@ -161,6 +160,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
             Stage of training e.g. 'predict', 'fit', 'validate'.
         """
         super().setup(trainer, pl_module, stage)
+        # TODO: move directory creation to hook on_predict_epoch_start
         if stage == "predict":
             # make prediction output directory
             logger.info("Making prediction output directory.")
@@ -202,25 +202,6 @@ def write_on_batch_end(
         if not self.writing_predictions:
             return
 
-        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
-        dataloader: DataLoader = (
-            dataloaders[dataloader_idx]
-            if isinstance(dataloaders, list)
-            else dataloaders
-        )
-        dataset: ValidPredDatasets = dataloader.dataset
-        if not (
-            isinstance(dataset, IterablePredDataset)
-            or isinstance(dataset, IterableTiledPredDataset)
-        ):
-            # Note: Error will be raised before here from the source type
-            # This is for extra redundancy of errors.
-            raise TypeError(
-                "Prediction dataset has to be `IterableTiledPredDataset` or "
-                "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because "
-                "filenames are taken from the original file."
-            )
-
         self.write_strategy.write_batch(
             trainer=trainer,
             pl_module=pl_module,
@@ -231,3 +212,21 @@ def write_on_batch_end(
             dataloader_idx=dataloader_idx,
             dirpath=self.dirpath,
         )
+
+    def on_predict_epoch_end(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """
+        Lightning hook called at the end of prediction.
+
+        Resets write_strategy.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            PyTorch Lightning trainer.
+        pl_module : LightningModule
+            PyTorch Lightning module.
+        """
+        # reset write_strategy to prevent bugs if trainer.predict is called twice.
+        self.write_strategy.reset()
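
To illustrate why the reset in `on_predict_epoch_end` matters, here is a small sketch of a strategy that hands out filenames from an iterator. It rests on assumptions: `FilenameStrategySketch` and `next_filename` are hypothetical names, and the reset-to-`None` behaviour is inferred from the commit messages ("replace filename indexing with iterator", "write_batch raises if write_filenames is None"), not copied from the PR's code.

```python
from collections.abc import Iterator
from typing import Optional


class FilenameStrategySketch:
    """Hypothetical strategy that hands out one filename per written file."""

    def __init__(self) -> None:
        self._write_filenames: Optional[list[str]] = None
        self._n_samples_per_file: Optional[list[int]] = None
        self._filename_iter: Optional[Iterator[str]] = None

    def set_file_data(
        self, write_filenames: list[str], n_samples_per_file: list[int]
    ) -> None:
        """Provide the file data for the upcoming prediction run."""
        self._write_filenames = write_filenames
        self._n_samples_per_file = n_samples_per_file
        self._filename_iter = iter(write_filenames)

    def next_filename(self) -> str:
        """Return the filename for the next file to be written."""
        if self._filename_iter is None:
            raise ValueError("File data must be set before writing.")
        return next(self._filename_iter)

    def reset(self) -> None:
        """Drop all per-run state so a stale iterator cannot be reused."""
        self._write_filenames = None
        self._n_samples_per_file = None
        self._filename_iter = None
```

Without the reset, a second `trainer.predict` call would continue from a partially or fully consumed iterator. After the reset, forgetting to set new file data fails immediately with a clear error.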