Skip to content

Commit

Permalink
Add Medical Segmentation Decathlon dataset (constantinpape#224)
Browse files Browse the repository at this point in the history
Add Medical Segmentation Decathlon dataset
  • Loading branch information
anwai98 committed Jun 7, 2024
1 parent 0d7ddc3 commit 59b1599
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
22 changes: 22 additions & 0 deletions scripts/datasets/medical/check_msd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_msd_loader


MSD_ROOT = "/media/anwai/ANWAI/data/msd"


def check_msd():
    """Quick visual sanity check for the MSD loader.

    Builds a 2d loader for the brain tumour task (downloading the data if
    needed) and displays a few sampled batches via `check_loader`.
    """
    loader_settings = {
        "path": MSD_ROOT,
        "patch_shape": (1, 512, 512),
        "batch_size": 2,
        "ndim": 2,
        "download": True,
        "task_names": "braintumour",
    }
    msd_loader = get_msd_loader(**loader_settings)
    print(f"Length of the loader: {len(msd_loader)}")
    check_loader(msd_loader, 8)


if __name__ == "__main__":
    check_msd()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .busi import get_busi_dataset, get_busi_loader
from .camus import get_camus_dataset, get_camus_loader
from .drive import get_drive_dataset, get_drive_loader
from .msd import get_msd_dataset, get_msd_loader
from .papila import get_papila_dataset, get_papila_loader
from .plethora import get_plethora_dataset, get_plethora_loader
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
Expand Down
148 changes: 148 additions & 0 deletions torch_em/data/datasets/medical/msd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
from glob import glob
from pathlib import Path
from typing import Tuple, List, Union

import torch_em

from .. import util
from ....data import ConcatDataset


# Base S3 bucket hosting all Medical Segmentation Decathlon archives.
_BASE_URL = "https://msd-for-monai.s3-us-west-2.amazonaws.com"

# Per-task (archive filename, sha256 checksum) for the 10 MSD tasks.
# URL / CHECKSUM / FILENAMES below are all derived from this single table
# so the three mappings cannot drift out of sync.
_TASK_FILES = {
    "braintumour": ("Task01_BrainTumour.tar", "d423911308d2ae5396d9c6bf4fad2b68cfde2dd09044269da9c0d639c22753c4"),
    "heart": ("Task02_Heart.tar", "4277dc6dfe100142aa8060e895f6ff0f81c5b733703ea250bd294df8f820bcba"),
    "liver": ("Task03_Liver.tar", "4007d9db1acda850d57a6ceb2b3998b7a0d43f8ad5a3f740dc38bc0cb8b7a2c5"),
    "hippocampus": ("Task04_Hippocampus.tar", "282d808a3e84e5a52f090d9dd4c0b0057b94a6bd51ad41569aef5ff303287771"),
    "prostate": ("Task05_Prostate.tar", "8cbbd7147691109b880ff8774eb6ab26704b1be0935482e7996a36a4ed31ec79"),
    "lung": ("Task06_Lung.tar", "f782cd09da9cf7a3128475d4a53650d371db10f0427aa76e166fccfcb2654161"),
    "pancreas": ("Task07_Pancreas.tar", "e40181a0229ca85c2588d6ebb90fa6674f84eb1e66f0f968cda088d011769732"),
    "hepaticvessel": ("Task08_HepaticVessel.tar", "ee880799f12e3b6e1ef2f8645f6626c5b39de77a4f1eae6f496c25fbf306ba04"),
    "spleen": ("Task09_Spleen.tar", "dfeba347daae4fb08c38f4d243ab606b28b91b206ffc445ec55c35489fa65e60"),
    "colon": ("Task10_Colon.tar", "a26bfd23faf2de703f5a51a262cd4e2b9774c47e7fb86f0e0a854f8446ec2325"),
}

# task name -> full download URL for the archive
URL = {task: f"{_BASE_URL}/{fname}" for task, (fname, _) in _TASK_FILES.items()}

# task name -> sha256 checksum of the archive
CHECKSUM = {task: checksum for task, (_, checksum) in _TASK_FILES.items()}

# task name -> local archive filename
FILENAMES = {task: fname for task, (fname, _) in _TASK_FILES.items()}


def get_msd_data(path, task_name, download):
    """Download and unpack the archive for one MSD task.

    Args:
        path: Root folder where the data is stored.
        task_name: Key into `URL` / `CHECKSUM` / `FILENAMES`, e.g. "heart".
        download: Whether to download the archive if the data is not present.

    Returns:
        The folder containing the unpacked data for this task.
    """
    os.makedirs(path, exist_ok=True)

    data_dir = os.path.join(path, "data", task_name)
    if os.path.exists(data_dir):
        return data_dir

    fpath = os.path.join(path, FILENAMES[task_name])

    # Verify the download against the known sha256 checksum for this task.
    # (The original passed checksum=None, leaving the CHECKSUM table unused.)
    util.download_source(path=fpath, url=URL[task_name], download=download, checksum=CHECKSUM[task_name])
    util.unzip_tarfile(tar_path=fpath, dst=data_dir, remove=False)

    return data_dir


def get_msd_dataset(
    path: str,
    patch_shape: Tuple[int, ...],
    ndim: int,
    task_names: Union[str, List[str]],
    download: bool = False,
    **kwargs
):
    """Dataset for semantic segmentation in 10 medical imaging datasets.

    This dataset is from the Medical Segmentation Decathlon Challenge:
    - Antonelli et al. - https://doi.org/10.1038/s41467-022-30695-9
    - Link - http://medicaldecathlon.com/
    Please cite it if you use this dataset for a publication.

    Args:
        path: The path to prepare the dataset.
        patch_shape: The patch shape (for 2d or 3d patches).
        ndim: The dimensions of inputs (use `2` for getting 2d patches, and `3` for getting 3d patches).
        task_names: The names for the 10 different segmentation tasks (see the challenge website for details):
            1. tasks with 1 modality inputs: heart, liver, hippocampus, lung, pancreas, hepaticvessel,
               spleen, colon
            2. tasks with multi-modality inputs:
               - braintumour: with 4 modality (channel) inputs
               - prostate: with 2 modality (channel) inputs
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        A `ConcatDataset` over the per-task segmentation datasets.

    Here's an example for how to pass different tasks:
    ```python
    # we want to get datasets for one task, e.g. "heart"
    task_names = ["heart"]

    # we want to get datasets for multiple tasks
    # NOTE 1: only datasets with the same number of modalities (channels) can be paired together.
    # To combine different datasets, use "raw_transform" to adapt the inputs per dataset
    # so patches in a batch have matching shapes.
    # Example 1: "heart", "liver", "lung" all have one-modality inputs
    task_names = ["heart", "lung", "liver"]

    # Example 2: "braintumour" and "prostate" have multi-modal inputs, but the number of
    # modalities differs, hence you can use only one of them at a time.
    task_names = ["prostate"]
    ```
    """
    if isinstance(task_names, str):
        task_names = [task_names]

    _datasets = []
    for task_name in task_names:
        data_dir = get_msd_data(path, task_name, download)
        # Sort the glob results so images and labels are paired deterministically
        # (glob returns results in arbitrary, OS-dependent order).
        image_paths = sorted(glob(os.path.join(data_dir, Path(FILENAMES[task_name]).stem, "imagesTr", "*.nii.gz")))
        label_paths = sorted(glob(os.path.join(data_dir, Path(FILENAMES[task_name]).stem, "labelsTr", "*.nii.gz")))

        # Use a per-task copy: mutating the shared `kwargs` would leak
        # `with_channels=True` into all subsequent (single-channel) tasks.
        task_kwargs = dict(kwargs)
        if task_name in ("braintumour", "prostate"):
            task_kwargs["with_channels"] = True

        this_dataset = torch_em.default_segmentation_dataset(
            raw_paths=image_paths,
            raw_key="data",
            label_paths=label_paths,
            label_key="data",
            patch_shape=patch_shape,
            ndim=ndim,
            **task_kwargs
        )
        _datasets.append(this_dataset)

    return ConcatDataset(*_datasets)


def get_msd_loader(
    path, patch_shape, batch_size, ndim, task_names, download=False, **kwargs
):
    """Dataloader for semantic segmentation from 10 highly variable medical segmentation tasks.
    See `get_msd_dataset` for details.

    Args:
        path: The path to prepare the dataset.
        patch_shape: The patch shape (for 2d or 3d patches).
        batch_size: The batch size for the data loader.
        ndim: The dimensions of the inputs.
        task_names: The names of the segmentation tasks to load.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments, split between the dataset and the loader.
    """
    # Separate the kwargs meant for the dataset from those meant for the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_msd_dataset(
        path=path,
        patch_shape=patch_shape,
        ndim=ndim,
        task_names=task_names,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
21 changes: 21 additions & 0 deletions torch_em/data/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,27 @@ def update_kwargs(kwargs, key, value, msg=None):
return kwargs


def unzip_tarfile(tar_path, dst, remove=True):
    """Extract a tar archive (plain `.tar` or gzipped `.tar.gz`) into `dst`.

    Args:
        tar_path: Path to the tar archive.
        dst: Folder to extract into (created by `extractall` if missing).
        remove: Whether to delete the archive after successful extraction.

    Raises:
        ValueError: If `tar_path` does not end in a supported tar extension.
    """
    import tarfile

    if tar_path.endswith(".tar.gz"):
        access_mode = "r:gz"
    elif tar_path.endswith(".tar"):
        access_mode = "r:"
    else:
        # Note: the original passed two arguments to ValueError (stray comma),
        # which produced a tuple-style message; concatenate into one string.
        raise ValueError(
            "The provided file isn't a supported archive to unpack. "
            f"Please check the file: {tar_path}"
        )

    # Context manager ensures the archive handle is closed even if
    # extraction raises part-way through.
    with tarfile.open(tar_path, access_mode) as tar:
        tar.extractall(dst)

    if remove:
        os.remove(tar_path)


def unzip(zip_path, dst, remove=True):
with zipfile.ZipFile(zip_path, "r") as f:
f.extractall(dst)
Expand Down

0 comments on commit 59b1599

Please sign in to comment.