Skip to content

Commit

Permalink
Fix BSD n2v notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Apr 5, 2024
1 parent 390086e commit 8a4bf65
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
47 changes: 32 additions & 15 deletions examples/2D/n2v/example_BSD68_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
"from pytorch_lightning import Trainer\n",
"\n",
"from careamics import CAREamicsModule\n",
"from careamics.lightning_prediction import CAREamicsFiring\n",
"from careamics.ligthning_datamodule import (\n",
"from careamics.lightning_datamodule import (\n",
" CAREamicsPredictDataModule,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.lightning_prediction import CAREamicsPredictionLoop\n",
"from careamics.utils.metrics import psnr"
]
},
Expand Down Expand Up @@ -121,6 +121,16 @@
"Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# N2V2 requires changes to the UNet model and to the Dataset (augmentations)\n",
"use_n2v2 = False # change to True to use N2V2"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -131,7 +141,7 @@
" algorithm=\"n2v\",\n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
" model_parameters={\"n2v2\": True},\n",
" model_parameters={\"n2v2\": use_n2v2},\n",
" optimizer_parameters={\"lr\": 1e-3},\n",
" lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n",
")"
Expand All @@ -142,7 +152,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the datamodule"
"### Initialize the datamodule\n",
"\n",
"The data module can take a `Path` or `str` to a folder or file, or a `np.ndarray`.\n",
"\n",
"For custom types, you need to pass a read function and an extension_filter."
]
},
{
Expand All @@ -152,13 +166,16 @@
"outputs": [],
"source": [
"train_data_module = CAREamicsTrainDataModule(\n",
" train_path=train_path,\n",
" val_path=val_path,\n",
" data_type=\"tiff\",\n",
" train_data=train_path,\n",
" val_data=val_path,\n",
" data_type=\"tiff\", # to use np.ndarray, set data_type to \"array\"\n",
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" dataloader_params={\"num_workers\": 4},\n",
" use_n2v2=use_n2v2,\n",
" struct_n2v_axis=\"none\", # choice between \"horizontal\", \"vertical\", or \"none\" (no structN2V)\n",
" struct_n2v_span=7,\n",
")"
]
},
Expand Down Expand Up @@ -205,9 +222,10 @@
"outputs": [],
"source": [
"pred_data_module = CAREamicsPredictDataModule(\n",
" pred_path=test_path,\n",
" pred_data=test_path,\n",
" data_type=\"tiff\",\n",
" tile_size=(256, 256),\n",
" tile_overlap=(48, 48),\n",
" axes=\"YX\",\n",
" batch_size=1,\n",
" tta_transforms=True,\n",
Expand All @@ -221,7 +239,7 @@
"source": [
"### Run prediction\n",
"\n",
"We need to specify the path to the data we want to denoise"
"First, we want to use CAREamics prediction loop, which allows tiling:"
]
},
{
Expand All @@ -230,16 +248,15 @@
"metadata": {},
"outputs": [],
"source": [
"tiled_loop = CAREamicsFiring(trainer)"
"tiled_loop = CAREamicsPredictionLoop(trainer)\n",
"trainer.predict_loop = tiled_loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"trainer.predict_loop = tiled_loop"
"Then, we predict using the datamodule."
]
},
{
Expand Down Expand Up @@ -359,7 +376,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.8.17"
},
"vscode": {
"interpreter": {
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .config.support import SupportedAlgorithm
from .lightning_datamodule import CAREamicsClay, CAREamicsWood
from .lightning_module import CAREamicsKiln
from .lightning_prediction import CAREamicsFiring
from .lightning_prediction import CAREamicsPredictionLoop
from .utils import check_path_exists, get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -208,7 +208,7 @@ def __init__(
)

# change the prediction loop, necessary for tiled prediction
self.trainer.predict_loop = CAREamicsFiring(self.trainer)
self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)

def _define_callbacks(self) -> List[Callback]:
"""Define the callbacks for the training loop.
Expand Down
4 changes: 3 additions & 1 deletion src/careamics/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,9 +811,9 @@ def __init__(
pred_data: Union[str, Path, np.ndarray],
data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
tile_size: List[int],
tile_overlap: List[int],
axes: str,
batch_size: int,
tile_overlap: Optional[List[int]] = None,
tta_transforms: bool = True,
mean: Optional[float] = None,
std: Optional[float] = None,
Expand Down Expand Up @@ -858,6 +858,8 @@ def __init__(
dataloader_params : dict, optional
Pytorch dataloader parameters, by default {}.
"""
if tile_overlap is None:
tile_overlap = [48, 48]
if dataloader_params is None:
dataloader_params = {}
prediction_dict = {
Expand Down
2 changes: 1 addition & 1 deletion src/careamics/lightning_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from careamics.utils import denormalize


class CAREamicsFiring(L.loops._PredictionLoop):
class CAREamicsPredictionLoop(L.loops._PredictionLoop):
"""Predict loop for tiles-based prediction."""

# def _predict_step(self, batch, batch_idx, dataloader_idx, dataloader_iter):
Expand Down

0 comments on commit 8a4bf65

Please sign in to comment.