Refactoring Lightning API #106

Merged: 11 commits, Apr 5, 2024
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -32,6 +32,7 @@ repos:
additional_dependencies:
- numpy
- types-PyYAML
- types-setuptools

# check docstrings
- repo: https://github.com/numpy/numpydoc
54 changes: 34 additions & 20 deletions examples/2D/n2v/example_BSD68_lightning.ipynb
@@ -8,20 +8,17 @@
"source": [
"from pathlib import Path\n",
"\n",
"import torch\n",
"import albumentations as Aug\n",
"import matplotlib.pyplot as plt\n",
"import tifffile\n",
"from careamics_portfolio import PortfolioManager\n",
"from pytorch_lightning import Trainer\n",
"\n",
"from careamics import CAREamicsModule\n",
"from careamics.lightning_prediction import CAREamicsFiring\n",
"from careamics.ligthning_datamodule import (\n",
"from careamics.lightning_datamodule import (\n",
" CAREamicsPredictDataModule,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.transforms import N2VManipulate\n",
"from careamics.lightning_prediction import CAREamicsPredictionLoop\n",
"from careamics.utils.metrics import psnr"
]
},
@@ -124,6 +121,16 @@
"Please take as look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# N2V2 requires changes to the UNet model and to the Dataset (augmentations)\n",
"use_n2v2 = False # change to True to use N2V2"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -134,18 +141,22 @@
" algorithm=\"n2v\",\n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
" model_parameters={\"n2v2\": True},\n",
" model_parameters={\"n2v2\": use_n2v2},\n",
" optimizer_parameters={\"lr\": 1e-3},\n",
" lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n",
")\n"
")"
]
},
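As the new use_n2v2 cell notes, N2V2 changes both the UNet model and the dataset augmentations, so the same flag has to reach the Lightning module and the training datamodule. A condensed sketch of how the two constructions in this notebook stay in sync, with parameter names copied from the diff and hypothetical placeholder paths:

from pathlib import Path

from careamics import CAREamicsModule
from careamics.lightning_datamodule import CAREamicsTrainDataModule

# Hypothetical locations; the notebook downloads BSD68 and sets these earlier.
train_path = Path("data/train")
val_path = Path("data/val")

use_n2v2 = False  # single switch for both the model and the dataset

model = CAREamicsModule(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    model_parameters={"n2v2": use_n2v2},   # model side of the switch
    optimizer_parameters={"lr": 1e-3},
    lr_scheduler_parameters={"factor": 0.5, "patience": 10},
)

train_data_module = CAREamicsTrainDataModule(
    train_data=train_path,
    val_data=val_path,
    data_type="tiff",
    patch_size=(64, 64),
    axes="SYX",
    batch_size=128,
    use_n2v2=use_n2v2,                     # dataset side of the switch
)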
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the datamodule"
"### Initialize the datamodule\n",
"\n",
"The data module can take a `Path` or `str` to a folder or file, or a `np.ndarray`.\n",
"\n",
"For custom types, you need to pass a read function and an extension_filter."
]
},
{
@@ -155,13 +166,16 @@
"outputs": [],
"source": [
"train_data_module = CAREamicsTrainDataModule(\n",
" train_path=train_path,\n",
" val_path=val_path,\n",
" data_type=\"tiff\",\n",
" train_data=train_path,\n",
" val_data=val_path,\n",
" data_type=\"tiff\", # to use np.ndarray, set data_type to \"array\"\n",
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" dataloader_params={\"num_workers\": 4},\n",
" use_n2v2=use_n2v2,\n",
" struct_n2v_axis=\"none\", # choice between \"horizontal\", \"vertical\", or \"none\" (no structN2V)\n",
" struct_n2v_span=7,\n",
")"
]
},
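The markdown cell above also says the datamodule accepts a np.ndarray, and the inline comment in the diff gives the corresponding setting (data_type="array"). A hedged sketch of the in-memory variant, reusing only parameter names that appear in this diff; the custom-type path with a read function and an extension_filter is omitted here because its parameter names are not shown:

import numpy as np

from careamics.lightning_datamodule import CAREamicsTrainDataModule

# Toy arrays standing in for the BSD68 tiff stacks (illustrative shapes only).
train_array = np.random.rand(10, 256, 256).astype(np.float32)  # S, Y, X
val_array = np.random.rand(2, 256, 256).astype(np.float32)

array_data_module = CAREamicsTrainDataModule(
    train_data=train_array,   # np.ndarray instead of a Path/str
    val_data=val_array,
    data_type="array",        # per the comment in the diff above
    patch_size=(64, 64),
    axes="SYX",
    batch_size=128,
)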
@@ -208,9 +222,10 @@
"outputs": [],
"source": [
"pred_data_module = CAREamicsPredictDataModule(\n",
" pred_path=test_path,\n",
" pred_data=test_path,\n",
" data_type=\"tiff\",\n",
" tile_size=(256, 256),\n",
" tile_overlap=(48, 48),\n",
" axes=\"YX\",\n",
" batch_size=1,\n",
" tta_transforms=True,\n",
@@ -224,7 +239,7 @@
"source": [
"### Run prediction\n",
"\n",
"We need to specify the path to the data we want to denoise"
"First, we want to use CAREamics prediction loop, which allows tiling:"
]
},
{
@@ -233,16 +248,15 @@
"metadata": {},
"outputs": [],
"source": [
"tiled_loop = CAREamicsFiring(trainer)"
"tiled_loop = CAREamicsPredictionLoop(trainer)\n",
"trainer.predict_loop = tiled_loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"trainer.predict_loop = tiled_loop"
"Then, we predict using the datamodule."
]
},
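Putting the last few cells together, tiled prediction is a two-step affair: swap the trainer's prediction loop for CAREamicsPredictionLoop (as shown in the diff above), then call the standard Lightning Trainer.predict with the prediction datamodule. A sketch of the sequence, assuming model, trainer and pred_data_module were created in the earlier cells:

from careamics.lightning_prediction import CAREamicsPredictionLoop

# Replace the default Lightning prediction loop with CAREamics' tiled loop,
# which enables tiled prediction (see tile_size and tile_overlap above).
trainer.predict_loop = CAREamicsPredictionLoop(trainer)

# Predict through the prediction datamodule defined earlier.
preds = trainer.predict(model, datamodule=pred_data_module)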
{
@@ -270,7 +284,7 @@
"source": [
"# Create a list of ground truth images\n",
"\n",
"gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n"
"gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]"
]
},
{
@@ -362,7 +376,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.8.17"
},
"vscode": {
"interpreter": {
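The ground-truth list gts and the psnr import at the top suggest the notebook ends with a PSNR comparison between predictions and ground truth. Since the exact signature of the careamics psnr helper is not visible in this diff, the sketch below uses scikit-image's peak_signal_noise_ratio as a stand-in for the same computation:

import numpy as np
from skimage.metrics import peak_signal_noise_ratio

# `gts` and `preds` come from the ground-truth and prediction cells above.
scores = [
    peak_signal_noise_ratio(gt, np.squeeze(pred), data_range=gt.max() - gt.min())
    for gt, pred in zip(gts, preds)
]
print(f"mean PSNR: {np.mean(scores):.2f} dB")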
9 changes: 5 additions & 4 deletions examples/DL4MIA_N2V.ipynb
@@ -68,7 +68,9 @@
"from careamics_portfolio import PortfolioManager\n",
"\n",
"from careamics import CAREamist\n",
"from careamics.config import create_n2v_training_configuration, create_n2v_inference_configuration"
"from careamics.config import (\n",
" create_n2v_training_configuration,\n",
")"
]
},
{
@@ -202,7 +204,7 @@
"metadata": {},
"outputs": [],
"source": [
"engine = CAREamist(source=training_config)\n"
"engine = CAREamist(source=training_config)"
]
},
{
@@ -273,8 +275,7 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"orig_nbformat": 4
}
},
"nbformat": 4,
"nbformat_minor": 2
45 changes: 0 additions & 45 deletions examples/careamics_api.py

This file was deleted.

169 changes: 0 additions & 169 deletions examples/careamics_lightning_api.ipynb

This file was deleted.
