Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Medical Segmentation Decathlon dataset #224

Merged
merged 12 commits into from
Jun 4, 2024
22 changes: 22 additions & 0 deletions scripts/datasets/medical/check_msd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_msd_loader


# Local root folder where the MSD data is (or will be) stored.
MSD_ROOT = "/media/anwai/ANWAI/data/msd"


def check_msd(task_name="braintumour"):
    """Visual sanity check for the MSD loader.

    Builds a 2d loader for one MSD task and displays a few batches via
    torch_em's `check_loader` utility.

    Args:
        task_name: The MSD task to check; defaults to "braintumour"
            (see `get_msd_loader` for the list of valid task names).
    """
    loader = get_msd_loader(
        path=MSD_ROOT,
        patch_shape=(1, 512, 512),
        batch_size=2,
        ndim=2,
        download=True,
        task_names=task_name,
    )
    print(f"Length of the loader: {len(loader)}")
    # Show 8 samples from the loader for visual inspection.
    check_loader(loader, 8)


if __name__ == "__main__":
    check_msd()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .busi import get_busi_dataset, get_busi_loader
from .camus import get_camus_dataset, get_camus_loader
from .drive import get_drive_dataset, get_drive_loader
from .msd import get_msd_dataset, get_msd_loader
from .papila import get_papila_dataset, get_papila_loader
from .plethora import get_plethora_dataset, get_plethora_loader
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
Expand Down
148 changes: 148 additions & 0 deletions torch_em/data/datasets/medical/msd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
from glob import glob
from pathlib import Path
from typing import Tuple, List, Union

import torch_em

from .. import util
from ....data import ConcatDataset


# Base S3 bucket that hosts the MSD task archives.
_BASE_URL = "https://msd-for-monai.s3-us-west-2.amazonaws.com/"

# The ten decathlon tasks in challenge order; the 1-based position is the task id.
_TASK_NAMES = [
    "BrainTumour", "Heart", "Liver", "Hippocampus", "Prostate",
    "Lung", "Pancreas", "HepaticVessel", "Spleen", "Colon",
]

# Maps lower-cased task name -> archive file name, e.g. "heart" -> "Task02_Heart.tar".
FILENAMES = {
    name.lower(): f"Task{task_id:02d}_{name}.tar"
    for task_id, name in enumerate(_TASK_NAMES, start=1)
}

# Maps task name -> download URL for the corresponding tar archive.
URL = {task: _BASE_URL + fname for task, fname in FILENAMES.items()}

# SHA256 checksums for the task archives.
CHECKSUM = {
    "braintumour": "d423911308d2ae5396d9c6bf4fad2b68cfde2dd09044269da9c0d639c22753c4",
    "heart": "4277dc6dfe100142aa8060e895f6ff0f81c5b733703ea250bd294df8f820bcba",
    "liver": "4007d9db1acda850d57a6ceb2b3998b7a0d43f8ad5a3f740dc38bc0cb8b7a2c5",
    "hippocampus": "282d808a3e84e5a52f090d9dd4c0b0057b94a6bd51ad41569aef5ff303287771",
    "prostate": "8cbbd7147691109b880ff8774eb6ab26704b1be0935482e7996a36a4ed31ec79",
    "lung": "f782cd09da9cf7a3128475d4a53650d371db10f0427aa76e166fccfcb2654161",
    "pancreas": "e40181a0229ca85c2588d6ebb90fa6674f84eb1e66f0f968cda088d011769732",
    "hepaticvessel": "ee880799f12e3b6e1ef2f8645f6626c5b39de77a4f1eae6f496c25fbf306ba04",
    "spleen": "dfeba347daae4fb08c38f4d243ab606b28b91b206ffc445ec55c35489fa65e60",
    "colon": "a26bfd23faf2de703f5a51a262cd4e2b9774c47e7fb86f0e0a854f8446ec2325",
}


def get_msd_data(path, task_name, download):
    """Download and unpack the archive for one MSD task.

    Args:
        path: Root folder where the dataset is (or will be) stored.
        task_name: The lower-cased task name, e.g. "heart".
        download: Whether to download the archive if the data is not found.

    Returns:
        The folder containing the unpacked data for this task.
    """
    os.makedirs(path, exist_ok=True)

    data_dir = os.path.join(path, "data", task_name)
    if os.path.exists(data_dir):
        return data_dir

    fpath = os.path.join(path, FILENAMES[task_name])

    # Verify the download against the known checksum of the archive
    # (previously `checksum=None`, which left the CHECKSUM table unused).
    util.download_source(path=fpath, url=URL[task_name], download=download, checksum=CHECKSUM[task_name])
    # Keep the archive around so an interrupted unpacking does not force a re-download.
    util.unzip_tarfile(tar_path=fpath, dst=data_dir, remove=False)

    return data_dir


def get_msd_dataset(
    path: str,
    patch_shape: Tuple[int, ...],
    ndim: int,
    task_names: Union[str, List[str]],
    download: bool = False,
    **kwargs
):
    """Dataset for semantic segmentation in 10 medical imaging datasets.

    This dataset is from the Medical Segmentation Decathlon Challenge:
    - Antonelli et al. - https://doi.org/10.1038/s41467-022-30695-9
    - Link - http://medicaldecathlon.com/

    Please cite it if you use this dataset for a publication.

    Args:
        path: The path to prepare the dataset.
        patch_shape: The patch shape (for 2d or 3d patches)
        ndim: The dimensions of inputs (use `2` for getting `2d` patches, and `3` for getting 3d patches)
        task_names: The names for the 10 different segmentation tasks (see the challenge website for further details):
            1. tasks with 1 modality inputs are: heart, liver, hippocampus, lung, pancreas, hepaticvessel, spleen, colon
            2. tasks with multi-modality inputs are:
                - braintumour: with 4 modality (channel) inputs
                - prostate: with 2 modality (channel) inputs
        download: Downloads the dataset

    Here's an example for how to pass different tasks:
    ```python
    # we want to get datasets for one task, eg. "heart"
    task_names = ["heart"]

    # we want to get datasets for multiple tasks
    # NOTE 1: it's important to note that datasets with similar number of modality (channels) can be paired together.
    # to use different datasets together, you need to use "raw_transform" to update inputs per dataset
    # to pair as desired patch shapes per batch.
    # Example 1: "heart", "liver", "lung" all have one modality inputs
    task_names = ["heart", "lung", "liver"]

    # Example 2: "braintumour" and "prostate" have multi-modal inputs, however the no. of modalities are not equal.
    # hence, you can use only one at a time.
    task_names = ["prostate"]
    ```
    """
    if isinstance(task_names, str):
        task_names = [task_names]

    _datasets = []
    for task_name in task_names:
        data_dir = get_msd_data(path, task_name, download)
        task_dir = os.path.join(data_dir, Path(FILENAMES[task_name]).stem)

        # Sort the paths so that images and labels are paired deterministically
        # (glob returns files in arbitrary, OS-dependent order), and skip macOS
        # resource-fork files ("._*") that are present in the MSD tarballs.
        image_paths = sorted(
            p for p in glob(os.path.join(task_dir, "imagesTr", "*.nii.gz"))
            if not os.path.basename(p).startswith("._")
        )
        label_paths = sorted(
            p for p in glob(os.path.join(task_dir, "labelsTr", "*.nii.gz"))
            if not os.path.basename(p).startswith("._")
        )

        # Use a per-task copy of the kwargs: mutating `kwargs` directly would
        # leak `with_channels=True` into all subsequent (single-channel) tasks.
        task_kwargs = dict(kwargs)
        if task_name in ("braintumour", "prostate"):
            task_kwargs["with_channels"] = True

        this_dataset = torch_em.default_segmentation_dataset(
            raw_paths=image_paths,
            raw_key="data",
            label_paths=label_paths,
            label_key="data",
            patch_shape=patch_shape,
            ndim=ndim,
            **task_kwargs
        )
        _datasets.append(this_dataset)

    return ConcatDataset(*_datasets)


def get_msd_loader(
    path, patch_shape, batch_size, ndim, task_names, download=False, **kwargs
):
    """Dataloader for semantic segmentation from 10 highly variable medical segmentation tasks.
    See `get_msd_dataset` for details.
    """
    # Route the dataset-specific keyword arguments to the dataset, the rest to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_msd_dataset(
        path=path,
        patch_shape=patch_shape,
        ndim=ndim,
        task_names=task_names,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
21 changes: 21 additions & 0 deletions torch_em/data/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,27 @@ def update_kwargs(kwargs, key, value, msg=None):
return kwargs


def unzip_tarfile(tar_path, dst, remove=True):
    """Unpack a tar archive into a destination folder.

    Args:
        tar_path: Path to the tar archive (".tar", ".tar.gz" or ".tgz").
        dst: The folder to extract the archive into.
        remove: Whether to delete the archive after a successful extraction.

    Raises:
        ValueError: If the file does not have a supported tar extension.
    """
    import tarfile

    if tar_path.endswith((".tar.gz", ".tgz")):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        # NOTE: previously this ValueError was constructed with two arguments
        # (a stray comma instead of string concatenation), garbling the message.
        raise ValueError(
            f"The provided file isn't a supported archive to unpack. Please check the file: {tar_path}"
        )

    # Context manager ensures the file handle is closed even if extraction fails.
    # NOTE(review): `extractall` does not guard against path traversal in untrusted
    # archives; consider passing `filter="data"` once the minimum supported Python
    # version provides it (added in 3.12, backported to recent 3.x patch releases).
    with tarfile.open(tar_path, access_mode) as tar:
        tar.extractall(dst)

    if remove:
        os.remove(tar_path)


def unzip(zip_path, dst, remove=True):
with zipfile.ZipFile(zip_path, "r") as f:
f.extractall(dst)
Expand Down