Skip to content

Commit

Permalink
fix: Compatibility with tiled prediction during training
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Nov 12, 2024
1 parent 547fa2b commit d415d46
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
22 changes: 16 additions & 6 deletions src/careamics/lightning/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SupportedOptimizer,
SupportedScheduler,
)
from careamics.config.tile_information import TileInformation
from careamics.losses import loss_factory
from careamics.models.lvae.likelihoods import (
GaussianLikelihood,
Expand Down Expand Up @@ -163,19 +164,28 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
Any
Model output.
"""
# TODO refactor when redoing datasets
# hacky way to determine if it is PredictDataModule, otherwise there is a
# circular import to solve (isinstance)
is_prediction = hasattr(self._trainer.datamodule, "tiled")
# circular import to solve with isinstance
from_prediction = hasattr(self._trainer.datamodule, "tiled")
is_tiled = (
len(batch) > 1
and isinstance(batch[1], list)
and isinstance(batch[1][0], TileInformation)
)

if is_prediction and self._trainer.datamodule.tiled:
if is_tiled:
x, *aux = batch
else:
x = batch
aux = []

# apply test-time augmentation if available
# TODO: probably wont work with batch size > 1
if is_prediction and self._trainer.datamodule.prediction_config.tta_transforms:
if (
from_prediction
and self._trainer.datamodule.prediction_config.tta_transforms
):
tta = ImageRestorationTTA()
augmented_batch = tta.forward(x) # list of augmented tensors
augmented_output = []
Expand All @@ -191,12 +201,12 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
denorm = Denormalize(
image_means=(
self._trainer.datamodule.predict_dataset.image_means
if is_prediction
if from_prediction
else self._trainer.datamodule.train_dataset.image_stats.means
),
image_stds=(
self._trainer.datamodule.predict_dataset.image_stds
if is_prediction
if from_prediction
else self._trainer.datamodule.train_dataset.image_stats.stds
),
)
Expand Down
8 changes: 6 additions & 2 deletions tests/lightning/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels):
assert y.shape == x.shape


def test_prediction_callback_during_training(minimum_configuration):
@pytest.mark.parametrize("tiled", [False, True])
def test_prediction_callback_during_training(minimum_configuration, tiled):
import numpy as np
from pytorch_lightning import Callback, Trainer

Expand Down Expand Up @@ -352,13 +353,16 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module):

self.data = convert_outputs(outputs, self.pred_datamodule.tiled)

array = np.arange(32 * 32).reshape((32, 32))
array = np.arange(64 * 64).reshape((64, 64))
pred_datamodule = create_predict_datamodule(
pred_data=array,
data_type=config.data_config.data_type,
axes=config.data_config.axes,
image_means=[11.8], # random placeholder
image_stds=[3.14],
tile_size=(16, 16) if tiled else None,
tile_overlap=(8, 8) if tiled else None,
batch_size=2,
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
Expand Down

0 comments on commit d415d46

Please sign in to comment.