Skip to content

Commit

Permalink
Minor update medical imaging datasets for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jan 16, 2025
1 parent 4fa4f6d commit ac5892d
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 130 deletions.
112 changes: 84 additions & 28 deletions torch_em/data/datasets/medical/acdc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
"""The ACDC dataset contains annotations for multi-structure segmentation in cardiac MRI.
The labels have the following mapping:
- 0 (background), 1 (right ventricle cavity),2 (myocardium), 3 (left ventricle cavity)
The database is located at
https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb
The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502.
Please cite it if you use this dataset for a publication.
"""

import os
from glob import glob
from natsort import natsorted
from typing import Union, Tuple
from typing import Union, Tuple, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

Expand All @@ -13,26 +27,50 @@
CHECKSUM = "2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e"


def get_acdc_data(path, download):
os.makedirs(path, exist_ok=True)
def get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:
"""Download the ACDC dataset.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downlaoded.
"""
zip_path = os.path.join(path, "ACDC.zip")
trg_dir = os.path.join(path, "ACDC")
if os.path.exists(trg_dir):
return trg_dir

os.makedirs(path, exist_ok=True)

util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
util.unzip(zip_path=zip_path, dst=path, remove=False)

return trg_dir


def _get_acdc_paths(path, split, download):
def get_acdc_paths(
path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False
) -> Tuple[List[str], List[str]]:
"""Get paths to the ACDC data.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data.
List of filepaths for the label data.
"""
root_dir = get_acdc_data(path=path, download=download)

if split == "train":
input_dir = os.path.join(root_dir, "database", "training")
else:
elif split == "test":
input_dir = os.path.join(root_dir, "database", "testing")
else:
raise ValueError(f"'{split}' is not a valid data split.")

all_patient_dirs = natsorted(glob(os.path.join(input_dir, "patient*")))

Expand All @@ -53,26 +91,32 @@ def _get_acdc_paths(path, split, download):

def get_acdc_dataset(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
patch_shape: Tuple[int, ...],
split: Literal["train", "test"],
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
"""Dataset fir multi-structure segmentation in cardiac MRI.
The labels have the following mapping:
- 0 (background), 1 (right ventricle cavity),2 (myocardium), 3 (left ventricle cavity)
The database is located at
https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb
The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502
Please cite it if you use this dataset for a publication.
) -> Dataset:
"""Get the ACDC dataset for cardiac structure segmentation.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
patch_shape: The patch shape to use for training.
split: The choice of data split.
resize_inputs: Whether to resize inputs to the desired patch shape.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
"""
assert split in ["train", "test"], f"{split} is not a valid split."
image_paths, gt_paths = get_acdc_paths(path, split, download)

image_paths, gt_paths = _get_acdc_paths(path=path, split=split, download=download)
if resize_inputs:
resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
)

all_datasets = []
for image_path, gt_path in zip(image_paths, gt_paths):
Expand All @@ -92,15 +136,27 @@ def get_acdc_dataset(

def get_acdc_loader(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
batch_size: int,
patch_shape: Tuple[int, ...],
split: Literal["train", "test"],
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
"""Dataloader for multi-structure segmentation in cardiac MRI, See `get_acdc_dataset` for details.
) -> DataLoader:
"""Get the ACDC dataloader for cardiac structure segmentation.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
split: The choice of data split.
resize_inputs: Whether to resize inputs to the desired patch shape.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_acdc_dataset(path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader
dataset = get_acdc_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
18 changes: 7 additions & 11 deletions torch_em/data/datasets/medical/acouslic_ai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""The Acouslic AI dataset contains annotations for fetal segmentation
in ultrasound images.
"""The Acouslic AI dataset contains annotations for fetal segmentation in ultrasound images.
This dataset is from the challenge: https://acouslic-ai.grand-challenge.org/.
Please cite the challenge if you use this dataset for your publication.
Expand Down Expand Up @@ -31,12 +30,12 @@ def get_acouslic_ai_data(path: Union[os.PathLike, str], download: bool = False)
Returns:
Filepath where the data is downlaoded.
"""
os.makedirs(path, exist_ok=True)

data_dir = os.path.join(path, "data")
if os.path.exists(data_dir):
return data_dir

os.makedirs(path, exist_ok=True)

zip_path = os.path.join(path, "acouslic-ai-train-set.zip")
util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
util.unzip(zip_path=zip_path, dst=data_dir, remove=False)
Expand Down Expand Up @@ -102,8 +101,8 @@ def get_acouslic_ai_dataset(

def get_acouslic_ai_loader(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, ...],
batch_size: int,
patch_shape: Tuple[int, ...],
resize_inputs: bool = False,
download: bool = False,
**kwargs
Expand All @@ -112,8 +111,8 @@ def get_acouslic_ai_loader(
Args:
path: Filepath to a folder where the data is downloaded for further processing.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
resize_inputs: Whether to resize inputs to the desired patch shape.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Expand All @@ -122,8 +121,5 @@ def get_acouslic_ai_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_acouslic_ai_dataset(
path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader
dataset = get_acouslic_ai_dataset(path, patch_shape, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
19 changes: 11 additions & 8 deletions torch_em/data/datasets/medical/amos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""The AMOS dataset contains annotations for abdominal multi-organ segmentation
in CT and MRI scans.
"""The AMOS dataset contains annotations for abdominal multi-organ segmentation in CT and MRI scans.
This dataset is located at https://doi.org/10.5281/zenodo.7155725.
The dataset is from AMOS 2022 Challenge https://doi.org/10.48550/arXiv.2206.08023.
Expand All @@ -11,6 +10,8 @@
from pathlib import Path
from typing import Union, Tuple, Optional, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util
Expand Down Expand Up @@ -111,14 +112,15 @@ def get_amos_dataset(
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
) -> Dataset:
"""Get the AMOS dataset for abdominal multi-organ segmentation in CT and MRI scans.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
patch_shape: The patch shape to use for traiing.
split: The choice of data split.
modality: The choice of imaging modality.
resize_inputs: Whether to resize the inputs.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Expand Down Expand Up @@ -146,14 +148,14 @@ def get_amos_dataset(

def get_amos_loader(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, ...],
batch_size: int,
modality: Optional[str] = None,
patch_shape: Tuple[int, ...],
split: Literal['train', 'val', 'test'],
modality: Optional[Literal['CT', 'MRI']] = None,
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
) -> DataLoader:
"""Get the AMOS dataloader for abdominal multi-organ segmentation in CT and MRI scans.
Args:
Expand All @@ -162,12 +164,13 @@ def get_amos_loader(
patch_shape: The patch shape to use for training.
split: The choice of data split.
modality: The choice of imaging modality.
resize_inputs: Whether to resize the inputs.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_amos_dataset(path, split, patch_shape, modality, resize_inputs, download, **ds_kwargs)
dataset = get_amos_dataset(path, patch_shape, split, modality, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Loading

0 comments on commit ac5892d

Please sign in to comment.