Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Toothfairy dataset #313

Merged
merged 2 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions scripts/datasets/medical/check_toothfairy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_toothfairy_loader


ROOT = "/scratch/share/cidas/cca/data/toothfairy/"


def check_toothfairy():
loader = get_toothfairy_loader(
path=ROOT,
patch_shape=(1, 512, 512),
ndim=2,
batch_size=2,
sampler=MinInstanceSampler()
)

check_loader(loader, 8, plt=True, save_path="./toothfairy.png")


check_toothfairy()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
from .sega import get_sega_dataset, get_sega_loader
from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
from .toothfairy import get_toothfairy_dataset, get_toothfairy_loader
from .uwaterloo_skin import get_uwaterloo_skin_dataset, get_uwaterloo_skin_loader
88 changes: 88 additions & 0 deletions torch_em/data/datasets/medical/toothfairy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted

import numpy as np
import nibabel as nib

import torch_em

from .. import util


def get_toothfairy_data(path, download):
"""Automatic download is not possible.
"""
if download:
raise NotImplementedError

data_dir = os.path.join(path, "ToothFairy_Dataset", "Dataset")
return data_dir


def _get_toothfairy_paths(path, download):
data_dir = get_toothfairy_data(path, download)

images_dir = os.path.join(path, "data", "images")
gt_dir = os.path.join(path, "data", "dense_labels")
if os.path.exists(images_dir) and os.path.exists(gt_dir):
return natsorted(glob(os.path.join(images_dir, "*.nii.gz"))), natsorted(glob(os.path.join(gt_dir, "*.nii.gz")))

os.makedirs(images_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

image_paths, gt_paths = [], []
for patient_dir in tqdm(glob(os.path.join(data_dir, "P*"))):
patient_id = os.path.split(patient_dir)[-1]

dense_anns_path = os.path.join(patient_dir, "gt_alpha.npy")
if not os.path.exists(dense_anns_path):
continue

image_path = os.path.join(patient_dir, "data.npy")

image = np.load(image_path)
gt = np.load(dense_anns_path)

image_nifti = nib.Nifti2Image(image, np.eye(4))
gt_nifti = nib.Nifti2Image(gt, np.eye(4))

trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz")
trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz")

nib.save(image_nifti, trg_image_path)
nib.save(gt_nifti, trg_gt_path)

image_paths.append(trg_image_path)
gt_paths.append(trg_gt_path)

return image_paths, gt_paths


def get_toothfairy_dataset(path, patch_shape, download=False, **kwargs):
"""Canal segmentation in CBCT
https://toothfairy.grand-challenge.org/
"""
image_paths, gt_paths = _get_toothfairy_paths(path, download)

dataset = torch_em.default_segmentation_dataset(
raw_paths=image_paths,
raw_key="data",
label_paths=gt_paths,
label_key="data",
is_seg_dataset=True,
patch_shape=patch_shape,
**kwargs
)

return dataset


def get_toothfairy_loader(path, patch_shape, batch_size, download=False, **kwargs):
"""
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_toothfairy_dataset(path, patch_shape, download, **ds_kwargs)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader