Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update BMZ README #322

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
rev: v0.7.2
hooks:
- id: ruff
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*"
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^scripts/.*"
args: [--fix, --target-version, py38]

- repo: https://github.com/psf/black
Expand All @@ -31,7 +31,7 @@ repos:
- id: mypy
files: "^src/"
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/config/likelihood_model.py|^src/careamics/losses/loss_factory.py|^src/careamics/losses/lvae/losses.py"
args: ['--config-file', 'mypy.ini']
args: ["--config-file", "mypy.ini"]
additional_dependencies:
- numpy
- types-PyYAML
Expand All @@ -42,7 +42,7 @@ repos:
rev: v1.8.0
hooks:
- id: numpydoc-validation
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*"
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*|^scripts/.*"

# # jupyter linting and formatting
# - repo: https://github.com/nbQA-dev/nbQA
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies = [
'typer==0.12.3',
'scikit-image<=0.23.2',
'zarr<3.0.0',
'pillow<=10.3.0',
]

[project.optional-dependencies]
Expand Down
29 changes: 29 additions & 0 deletions scripts/export_bmz_readme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python
"""Export a README file for the bioimage model zoo."""
from pathlib import Path

from careamics.config import create_n2v_configuration
from careamics.model_io.bioimage._readme_factory import readme_factory


def main():
# create configuration
config = create_n2v_configuration(
experiment_name="export_bmz_readme",
data_type="array",
axes="YX",
patch_size=(64, 64),
batch_size=2,
num_epochs=10,
)
# export README
readme_path = readme_factory(
config=config, careamics_version="0.1.0", data_description="Mydata"
)

# copy file to __file__
readme_path.rename(Path(__file__).parent / "README.md")


if __name__ == "__main__":
main()
Empty file added scripts/export_covers.py
Empty file.
24 changes: 16 additions & 8 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,11 @@ def export_to_bmz(
friendly_model_name: str,
input_array: NDArray,
authors: list[dict],
general_description: str = "",
general_description: str,
data_description: str,
covers: Optional[list[Union[Path, str]]] = None,
channel_names: Optional[list[str]] = None,
data_description: Optional[str] = None,
model_version: str = "0.1.0",
) -> None:
"""Export the model to the BioImage Model Zoo format.

Expand Down Expand Up @@ -898,11 +900,15 @@ def export_to_bmz(
authors : list of dict
List of authors of the model.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : list of str, optional
Channel names, by default None.
data_description : str, optional
Description of the data, by default None.
General description of the model used in the BMZ metadata.
data_description : str
Description of the data the model was trained on.
covers : list of pathlib.Path or str, default=None
Paths to the cover images.
channel_names : list of str, default=None
Channel names.
model_version : str, default="0.1.0"
Version of the model.
"""
# TODO: add in docs that it is expected that input_array dimensions match
# those in data_config
Expand All @@ -921,11 +927,13 @@ def export_to_bmz(
path_to_archive=path_to_archive,
model_name=friendly_model_name,
general_description=general_description,
data_description=data_description,
authors=authors,
input_array=input_array,
output_array=output,
covers=covers,
channel_names=channel_names,
data_description=data_description,
model_version=model_version,
)

def get_losses(self) -> dict[str, list]:
Expand Down
55 changes: 23 additions & 32 deletions src/careamics/model_io/bioimage/_readme_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Functions used to create a README.md file for BMZ export."""

from pathlib import Path
from typing import Optional

import yaml

Expand All @@ -28,7 +27,7 @@ def _yaml_block(yaml_str: str) -> str:
def readme_factory(
config: Configuration,
careamics_version: str,
data_description: Optional[str] = None,
data_description: str,
) -> Path:
"""Create a README file for the model.

Expand All @@ -41,18 +40,14 @@ def readme_factory(
CAREamics configuration.
careamics_version : str
CAREamics version.
data_description : Optional[str], optional
Description of the data, by default None.
data_description : str
Description of the data.

Returns
-------
Path
Path to the README file.
"""
algorithm = config.algorithm_config
training = config.training_config
data = config.data_config

# create file
# TODO use tempfile as in the bmz_io module
with cwd(get_careamics_home()):
Expand All @@ -65,42 +60,38 @@ def readme_factory(

description = [f"# {algorithm_pretty_name}\n\n"]

# data description
description.append("## Data description\n\n")
description.append(data_description)
description.append("\n\n")

# algorithm description
description.append("Algorithm description:\n\n")
description.append("## Algorithm description:\n\n")
description.append(config.get_algorithm_description())
description.append("\n\n")

# algorithm details
# configuration description
description.append("## Configuration\n\n")

description.append(
f"{algorithm_flavour} was trained using CAREamics (version "
f"{careamics_version}) with the following algorithm "
f"parameters:\n\n"
)
description.append(
_yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
f"{careamics_version}) using the following configuration:\n\n"
)
description.append("\n\n")

# data description
description.append("## Data description\n\n")
if data_description is not None:
description.append(data_description)
description.append("\n\n")

description.append("The data was processed using the following parameters:\n\n")

description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True))))
description.append(_yaml_block(yaml.dump(config.model_dump(exclude_none=True))))
description.append("\n\n")

# training description
description.append("## Training description\n\n")

description.append("The model was trained using the following parameters:\n\n")
# validation
description.append("## Validation\n\n")

description.append(
_yaml_block(yaml.dump(training.model_dump(exclude_none=True)))
"In order to validate the model, we encourage users to acquire a "
"test dataset with ground-truth data. Comparing the ground-truth data "
"with the prediction allows unbiased evaluation of the model performances. "
"In the absence of ground-truth, inspecting the residual image (difference "
"between input and predicted image) can be helpful to identify "
"whether real signal is removed from the input image.\n\n"
)
description.append("\n\n")

# references
reference = config.get_algorithm_references()
Expand All @@ -113,7 +104,7 @@ def readme_factory(
description.append(
"## Links\n\n"
"- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
"- [CAREamics documentation](https://careamics.github.io/latest/)\n"
"- [CAREamics documentation](https://careamics.github.io/)\n"
)

readme.write_text("".join(description))
Expand Down
171 changes: 171 additions & 0 deletions src/careamics/model_io/bioimage/cover_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Convenience function to create covers for the BMZ."""

from pathlib import Path

import numpy as np
from numpy.typing import NDArray
from PIL import Image

color_palette = np.array(
[
np.array([255, 195, 0]), # grey
np.array([189, 226, 240]),
np.array([96, 60, 76]),
np.array([193, 225, 193]),
]
)


def _get_norm_slice(array: NDArray) -> NDArray:
"""Get the normalized middle slice of a 4D or 5D array (SC(Z)YX).

Parameters
----------
array : NDArray
Array from which to get the middle slice.

Returns
-------
NDArray
Normalized middle slice of the input array.
"""
if array.ndim not in (4, 5):
raise ValueError("Array must be 4D or 5D.")

channels = array.shape[1] > 1
z_stack = array.ndim == 5

# get slice
if z_stack:
array_slice = array[0, :, array.shape[2] // 2, ...]
else:
array_slice = array[0, ...]

# channels
if channels:
array_slice = np.moveaxis(array_slice, 0, -1)
else:
array_slice = array_slice[0, ...]

# normalize
array_slice = (
255
* (array_slice - array_slice.min())
/ (array_slice.max() - array_slice.min())
)

return array_slice.astype(np.uint8)


def _four_channel_image(array: NDArray) -> Image:
"""Convert 4-channel array to Image.

Parameters
----------
array : NDArray
Normalized array to convert.

Returns
-------
Image
Converted array.
"""
colors = color_palette[np.newaxis, np.newaxis, :, :]
four_c_array = np.sum(array[..., :4, np.newaxis] * colors, axis=-2).astype(np.uint8)

return Image.fromarray(four_c_array).convert("RGB")


def _convert_to_image(original_shape: tuple[int, ...], array: NDArray) -> Image:
"""Convert to Image.

Parameters
----------
original_shape : tuple
Original shape of the array.
array : NDArray
Normalized array to convert.

Returns
-------
Image
Converted array.
"""
n_channels = original_shape[1]

if n_channels > 1:
if n_channels == 3:
return Image.fromarray(array).convert("RGB")
elif n_channels == 2:
# add an empty channel to the numpy array
array = np.concatenate([np.zeros_like(array[..., 0:1]), array], axis=-1)

return Image.fromarray(array).convert("RGB")
else: # more than 4
return _four_channel_image(array[..., :4])
else:
return Image.fromarray(array).convert("L").convert("RGB")


def create_cover(directory: Path, array_in: NDArray, array_out: NDArray) -> Path:
"""Create a cover image from input and output arrays.

Input and output arrays are expected to be SC(Z)YX. For images with a Z
dimension, the middle slice is taken.

Parameters
----------
directory : Path
Directory in which to save the cover.
array_in : numpy.ndarray
Array from which to create the cover image.
array_out : numpy.ndarray
Array from which to create the cover image.

Returns
-------
Path
Path to the saved cover image.
"""
# extract slice and normalize arrays
slice_in = _get_norm_slice(array_in)
slice_out = _get_norm_slice(array_out)

horizontal_split = slice_in.shape[-1] == slice_out.shape[-1]
if not horizontal_split:
if slice_in.shape[-2] != slice_out.shape[-2]:
raise ValueError("Input and output arrays have different shapes.")

# convert to Image
image_in = _convert_to_image(array_in.shape, slice_in)
image_out = _convert_to_image(array_out.shape, slice_out)

# split horizontally or vertically
if horizontal_split:
width = image_in.width // 2

cover = Image.new("RGB", (image_in.width, image_in.height))
cover.paste(image_in.crop((0, 0, width, image_in.height)), (0, 0))
cover.paste(
image_out.crop(
(image_in.width - width, 0, image_in.width, image_in.height)
),
(width, 0),
)
else:
height = image_in.height // 2

cover = Image.new("RGB", (image_in.width, image_in.height))
cover.paste(image_in.crop((0, 0, image_in.width, height)), (0, 0))
cover.paste(
image_out.crop(
(0, image_in.height - height, image_in.width, image_in.height)
),
(0, height),
)

# save
cover_path = directory / "cover.png"
cover.save(cover_path)

return cover_path
Loading
Loading