From 8dcc447544ca03646430e3a0761da1562f8a9af5 Mon Sep 17 00:00:00 2001
From: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Date: Tue, 12 Nov 2024 13:12:25 +0100
Subject: [PATCH] fix: Allow singleton channel in convenience functions (#265)

### Description

Following https://github.com/CAREamics/careamics/issues/159, this PR allows creating a configuration with a singleton channel via the configuration convenience functions.

- **What**: Allow singleton channel in convenience functions.
- **Why**: In some rare cases, a singleton channel might be present in the data.
- **How**: Change the `if` statements and error raising conditions of the convenience functions.

### Changes Made

- **Modified**: `configuration_factory.py`.

### Related Issues

- Fixes [Allow singleton channel in convenience functions](https://github.com/CAREamics/careamics/issues/159)

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)
---
 src/careamics/cli/conf.py                     | 22 +++--
 src/careamics/config/configuration_factory.py | 94 ++++++++++---------
 tests/cli/test_conf.py                        |  3 +-
 tests/config/test_configuration_factory.py    | 28 ++++++
 4 files changed, 93 insertions(+), 54 deletions(-)

diff --git a/src/careamics/cli/conf.py b/src/careamics/cli/conf.py
index 3c1f5c3e..01290f7f 100644
--- a/src/careamics/cli/conf.py
+++ b/src/careamics/cli/conf.py
@@ -3,7 +3,7 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Tuple
+from typing import Optional, Tuple
 
 import click
 import typer
@@ -154,8 +154,12 @@ def care(  # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -237,8 +241,12 @@ def n2n(  # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -312,8 +320,8 @@ def n2v(  # numpydoc ignore=PR01
     ] = True,
     use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
     n_channels: Annotated[
-        int, typer.Option(help="Number of channels (in and out)")
-    ] = 1,
+        Optional[int], typer.Option(help="Number of channels (in and out)")
+    ] = None,
     roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
     masked_pixel_percentage: Annotated[
         float, typer.Option(help="Percentage of pixels masked in each patch.")
diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py
index 6f3c7db5..59f1955b 100644
--- a/src/careamics/config/configuration_factory.py
+++ b/src/careamics/config/configuration_factory.py
@@ -234,8 +234,8 @@ def _create_supervised_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = 1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -267,10 +267,10 @@ def _create_supervised_configuration(
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default 1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -282,19 +282,29 @@ def _create_supervised_configuration(
     -------
     Configuration
         Configuration for training CARE or Noise2Noise.
+
+    Raises
+    ------
+    ValueError
+        If the number of channels is not specified when using channels.
+    ValueError
+        If the number of channels is specified but "C" is not in the axes.
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels_in == 1:
-        raise ValueError(
-            f"Number of channels in must be specified when using channels "
-            f"(got {n_channels_in} channel)."
-        )
-    elif "C" not in axes and n_channels_in > 1:
+    if "C" in axes and n_channels_in is None:
+        raise ValueError("Number of channels in must be specified when using channels.")
+    elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels_in} channels)."
         )
+
+    if n_channels_in is None:
+        n_channels_in = 1
+
+    if n_channels_out is None:
+        n_channels_out = n_channels_in
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
 
@@ -327,8 +337,8 @@ def create_care_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -374,16 +384,16 @@ def create_care_configuration(
         and XYRandomRotate90 (in XY) to the images.
     independent_channels : bool, optional
         Whether to train all channels independently, by default False.
-    loss : Literal["mae", "mse"], optional
-        Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
-    logger : Literal["wandb", "tensorboard", "none"], optional
-        Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default None.
+    loss : Literal["mae", "mse"], default="mae"
+        Loss function to use.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
+    logger : Literal["wandb", "tensorboard", "none"], default="none"
+        Logger to use.
+    model_params : dict, default=None
+        UNetModel parameters.
     dataloader_params : dict, optional
         Parameters for the dataloader, see PyTorch notes, by default None.
 
@@ -459,9 +469,6 @@ def create_care_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="care",
         experiment_name=experiment_name,
@@ -491,8 +498,8 @@ def create_n2n_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -540,10 +547,10 @@ def create_n2n_configuration(
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -623,9 +630,6 @@ def create_n2n_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="n2n",
         experiment_name=experiment_name,
@@ -655,7 +659,7 @@ def create_n2v_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
-    n_channels: int = 1,
+    n_channels: Optional[int] = None,
     roi_size: int = 11,
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -727,8 +731,8 @@ def create_n2v_configuration(
         Whether to train all channels together, by default True.
     use_n2v2 : bool, optional
         Whether to use N2V2, by default False.
-    n_channels : int, optional
-        Number of channels (in and out), by default 1.
+    n_channels : int or None, default=None
+        Number of channels (in and out).
     roi_size : int, optional
         N2V pixel manipulation area, by default 11.
     masked_pixel_percentage : float, optional
@@ -837,17 +841,17 @@ def create_n2v_configuration(
     ... )
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels == 1:
-        raise ValueError(
-            f"Number of channels must be specified when using channels "
-            f"(got {n_channels} channel)."
-        )
-    elif "C" not in axes and n_channels > 1:
+    if "C" in axes and n_channels is None:
+        raise ValueError("Number of channels must be specified when using channels.")
+    elif "C" not in axes and (n_channels is not None and n_channels > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels} channel)."
         )
 
+    if n_channels is None:
+        n_channels = 1
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
 
diff --git a/tests/cli/test_conf.py b/tests/cli/test_conf.py
index 7c5417ca..534f9dc6 100644
--- a/tests/cli/test_conf.py
+++ b/tests/cli/test_conf.py
@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 
 import pytest
@@ -33,5 +32,5 @@ def test_conf(tmp_path: Path, algorithm: str):
             "1",
         ],
     )
-    assert os.path.isfile(config_path)
+    assert config_path.is_file()
     assert result.exit_code == 0
diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py
index 01098ef4..e45ad85c 100644
--- a/tests/config/test_configuration_factory.py
+++ b/tests/config/test_configuration_factory.py
@@ -179,6 +179,34 @@ def test_supervised_configuration_error_with_channel_axes():
     )
 
 
+def test_supervised_configuration_singleton_channel():
+    """Test that no error is raised if channels are in axes, and the input channel is
+    1."""
+    _create_supervised_configuration(
+        algorithm="n2n",
+        experiment_name="test",
+        data_type="tiff",
+        axes="CYX",
+        patch_size=[64, 64],
+        batch_size=8,
+        num_epochs=100,
+        n_channels_in=1,
+    )
+
+
+def test_supervised_configuration_no_channel():
+    """Test that no error is raised without channel and number of inputs."""
+    _create_supervised_configuration(
+        algorithm="n2n",
+        experiment_name="test",
+        data_type="tiff",
+        axes="YX",
+        patch_size=[64, 64],
+        batch_size=8,
+        num_epochs=100,
+    )
+
+
 def test_supervised_configuration_error_without_channel_axes():
     """Test that an error is raised if channels are not in axes, but the input channel
     number is specified and greater than 1."""
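
The sketch below illustrates the behaviour introduced by this PR; it is not part of the patch. The import path is assumed to be the package-level re-export used in the CAREamics docstring examples, and the argument values are taken from the tests above.

```python
# Hypothetical usage sketch based on the signatures changed in this PR.
from careamics.config import create_care_configuration, create_n2v_configuration

# A singleton channel is now accepted when "C" is present in the axes.
config = create_care_configuration(
    experiment_name="care_singleton_channel",
    data_type="tiff",
    axes="CYX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=100,
    n_channels_in=1,  # previously rejected; n_channels_out falls back to n_channels_in
)

# Omitting the channel count while "C" is in the axes still raises a ValueError.
try:
    create_n2v_configuration(
        experiment_name="n2v_missing_channels",
        data_type="tiff",
        axes="CYX",
        patch_size=[64, 64],
        batch_size=8,
        num_epochs=100,
    )
except ValueError as err:
    print(err)
```

Note that because `n_channels_out` now defaults to `None`, `_create_supervised_configuration` resolves it to `n_channels_in` internally, replacing the old `-1` sentinel previously handled in `create_care_configuration` and `create_n2n_configuration`.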