fix: Allow singleton channel in convenience functions (#265)
### Description

Following #159, this PR
allows creating a configuration with a singleton channel via the
configuration convenience functions.

- **What**: Allow a singleton channel in the configuration convenience functions.
- **Why**: In some rare cases, a singleton channel may be present in the data.
- **How**: Change the `if` statements and error-raising conditions of the convenience functions (see the sketch below).
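
As an illustration, here is a minimal sketch of a call that this change now permits. It assumes `create_care_configuration` is importable from `careamics.config`; all parameter values are placeholders.

```python
# Sketch only: assumed import path and placeholder values.
from careamics.config import create_care_configuration

# A channel axis with a single channel previously raised a ValueError in the
# convenience function; it is now accepted.
config = create_care_configuration(
    experiment_name="care_singleton_channel",
    data_type="tiff",
    axes="CYX",        # channel axis present
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
    n_channels_in=1,   # singleton channel
)
```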

### Changes Made


- **Modified**: `configuration_factory.py`.

### Related Issues

- Fixes [Allow singleton channel in convenience
functions](#159)

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)
jdeschamps authored Nov 12, 2024
1 parent 6ac24a7 commit 8dcc447
Showing 4 changed files with 93 additions and 54 deletions.
22 changes: 15 additions & 7 deletions src/careamics/cli/conf.py
@@ -3,7 +3,7 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Tuple
+from typing import Optional, Tuple
 
 import click
 import typer
@@ -154,8 +154,12 @@ def care( # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -237,8 +241,12 @@ def n2n( # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -312,8 +320,8 @@ def n2v( # numpydoc ignore=PR01
     ] = True,
     use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
     n_channels: Annotated[
-        int, typer.Option(help="Number of channels (in and out)")
-    ] = 1,
+        Optional[int], typer.Option(help="Number of channels (in and out)")
+    ] = None,
     roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
     masked_pixel_percentage: Annotated[
         float, typer.Option(help="Percentage of pixels masked in each patch.")
94 changes: 49 additions & 45 deletions src/careamics/config/configuration_factory.py
@@ -234,8 +234,8 @@ def _create_supervised_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = 1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -267,10 +267,10 @@
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default 1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -282,19 +282,29 @@
     -------
     Configuration
         Configuration for training CARE or Noise2Noise.
+
+    Raises
+    ------
+    ValueError
+        If the number of channels is not specified when using channels.
+    ValueError
+        If the number of channels is specified but "C" is not in the axes.
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels_in == 1:
-        raise ValueError(
-            f"Number of channels in must be specified when using channels "
-            f"(got {n_channels_in} channel)."
-        )
-    elif "C" not in axes and n_channels_in > 1:
+    if "C" in axes and n_channels_in is None:
+        raise ValueError("Number of channels in must be specified when using channels ")
+    elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels_in} channels)."
         )
 
+    if n_channels_in is None:
+        n_channels_in = 1
+
+    if n_channels_out is None:
+        n_channels_out = n_channels_in
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
 
@@ -327,8 +337,8 @@ def create_care_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -374,16 +384,16 @@
         and XYRandomRotate90 (in XY) to the images.
     independent_channels : bool, optional
         Whether to train all channels independently, by default False.
-    loss : Literal["mae", "mse"], optional
-        Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
-    logger : Literal["wandb", "tensorboard", "none"], optional
-        Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default None.
+    loss : Literal["mae", "mse"], default="mae"
+        Loss function to use.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
+    logger : Literal["wandb", "tensorboard", "none"], default="none"
+        Logger to use.
+    model_params : dict, default=None
+        UNetModel parameters.
     dataloader_params : dict, optional
         Parameters for the dataloader, see PyTorch notes, by default None.
@@ -459,9 +469,6 @@ def create_care_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="care",
         experiment_name=experiment_name,
@@ -491,8 +498,8 @@ def create_n2n_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -540,10 +547,10 @@
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -623,9 +630,6 @@ def create_n2n_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="n2n",
         experiment_name=experiment_name,
@@ -655,7 +659,7 @@ def create_n2v_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
-    n_channels: int = 1,
+    n_channels: Optional[int] = None,
     roi_size: int = 11,
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -727,8 +731,8 @@
         Whether to train all channels together, by default True.
     use_n2v2 : bool, optional
         Whether to use N2V2, by default False.
-    n_channels : int, optional
-        Number of channels (in and out), by default 1.
+    n_channels : int or None, default=None
+        Number of channels (in and out).
     roi_size : int, optional
         N2V pixel manipulation area, by default 11.
     masked_pixel_percentage : float, optional
@@ -837,8 +841,8 @@
     ... )
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels == 1:
-        raise ValueError(
-            f"Number of channels must be specified when using channels "
-            f"(got {n_channels} channel)."
-        )
-    elif "C" not in axes and n_channels > 1:
+    if "C" in axes and n_channels is None:
+        raise ValueError("Number of channels must be specified when using channels.")
+    elif "C" not in axes and (n_channels is not None and n_channels > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels} channel)."
         )
 
+    if n_channels is None:
+        n_channels = 1
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
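
For context, the same relaxation applies to N2V. Below is a hedged sketch, assuming `create_n2v_configuration` is importable from `careamics.config` and accepts the same basic data parameters as the CARE factory; all values are placeholders.

```python
# Sketch only: assumed import path and placeholder values.
from careamics.config import create_n2v_configuration

# A channel axis with a single channel no longer raises a ValueError.
config = create_n2v_configuration(
    experiment_name="n2v_singleton_channel",
    data_type="tiff",
    axes="CYX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
    n_channels=1,
)

# Omitting the channel axis and leaving n_channels unset also remains valid;
# the factory falls back to a single channel internally.
config_no_channel = create_n2v_configuration(
    experiment_name="n2v_no_channel",
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
)
```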
3 changes: 1 addition & 2 deletions tests/cli/test_conf.py
@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 
 import pytest
@@ -33,5 +32,5 @@ def test_conf(tmp_path: Path, algorithm: str):
             "1",
         ],
     )
-    assert os.path.isfile(config_path)
+    assert config_path.is_file()
     assert result.exit_code == 0
28 changes: 28 additions & 0 deletions tests/config/test_configuration_factory.py
@@ -179,6 +179,34 @@ def test_supervised_configuration_error_with_channel_axes():
     )
 
 
+def test_supervised_configuration_singleton_channel():
+    """Test that no error is raised if channels are in axes, and the input channel is
+    1."""
+    _create_supervised_configuration(
+        algorithm="n2n",
+        experiment_name="test",
+        data_type="tiff",
+        axes="CYX",
+        patch_size=[64, 64],
+        batch_size=8,
+        num_epochs=100,
+        n_channels_in=1,
+    )
+
+
+def test_supervised_configuration_no_channel():
+    """Test that no error is raised without channel and number of inputs."""
+    _create_supervised_configuration(
+        algorithm="n2n",
+        experiment_name="test",
+        data_type="tiff",
+        axes="YX",
+        patch_size=[64, 64],
+        batch_size=8,
+        num_epochs=100,
+    )
+
+
 def test_supervised_configuration_error_without_channel_axes():
     """Test that an error is raised if channels are not in axes, but the input channel
     number is specified and greater than 1."""
