Skip to content

Commit

Permalink
Add custom splits for CryoNuSeg dataset (#472)
Browse files Browse the repository at this point in the history
Add custom data splits
  • Loading branch information
anwai98 authored Jan 9, 2025
1 parent f412cfc commit 4fa4f6d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
1 change: 1 addition & 0 deletions scripts/datasets/histopathology/check_cryonuseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def check_cryonuseg():
path=os.path.join(ROOT, "cryonuseg"),
patch_shape=(1, 512, 512),
batch_size=1,
split="train",
rater="b1",
download=True,
)
Expand Down
66 changes: 57 additions & 9 deletions torch_em/data/datasets/histopathology/cryonuseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,56 @@
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import json
import pandas as pd
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


def get_cryonuseg_data(path: Union[os.PathLike, str], download: bool = False):
def _create_split_csv(path, data_dir, split):
csv_path = os.path.join(path, 'cryonuseg_split.csv')
if os.path.exists(csv_path):
df = pd.read_csv(csv_path)
df[split] = df[split].apply(lambda x: json.loads(x.replace("'", '"'))) # ensures all items from column in list.
split_list = df.iloc[0][split]

else:
print(f"Creating a new split file at '{csv_path}'.")
image_names = [
os.path.basename(image).split(".")[0] for image in glob(os.path.join(path, data_dir, '*.tif'))
]

# Create random splits per dataset.
train_ids, test_ids = train_test_split(image_names, test_size=0.2) # 20% for test split.
train_ids, val_ids = train_test_split(train_ids, test_size=0.15) # 15% for val split.
split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}

df = pd.DataFrame.from_dict([split_ids])
df.to_csv(csv_path, index=False)

split_list = split_ids[split]

return split_list


def get_cryonuseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
"""Download the CryoNuSeg dataset for nucleus segmentation.
Args:
path: Filepath to a folder where the downloaded data will be saved.
download: Whether to download the data if it is not present.
Returns:
The folder where the data is downloaded and preprocessed.
"""
data_dir = os.path.join(path, r"tissue images")
if os.path.exists(os.path.join(path, r"tissue images")):
return
return data_dir

os.makedirs(path, exist_ok=True)
util.download_source_kaggle(
Expand All @@ -35,22 +69,28 @@ def get_cryonuseg_data(path: Union[os.PathLike, str], download: bool = False):
zip_path = os.path.join(path, "segmentation-of-nuclei-in-cryosectioned-he-images.zip")
util.unzip(zip_path=zip_path, dst=path)

return data_dir


def get_cryonuseg_paths(
path: Union[os.PathLike, str], rater_choice: Literal["b1", "b2", "b3"] = "b1", download: bool = False
path: Union[os.PathLike, str],
split: Literal["train", "val", "test"],
rater_choice: Literal["b1", "b2", "b3"] = "b1",
download: bool = False,
) -> Tuple[List[str], List[str]]:
"""Get paths to the CryoNuSeg data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
split: The choice of data split.
rater: The choice of annotator.
download: Whether to download the data if it is not present.
Returns:
List of filepaths to the image data.
List of filepaths to the label data.
"""
get_cryonuseg_data(path, download)
data_dir = get_cryonuseg_data(path, download)

if rater_choice == "b1":
label_dir = r"Annotator 1 (biologist)/"
Expand All @@ -63,16 +103,21 @@ def get_cryonuseg_paths(

# Point to the instance labels folder
label_dir += r"label masks modify"
split_list = _create_split_csv(path, label_dir, split)

# Get the raw and label paths
label_paths = natsorted([os.path.join(path, label_dir, f'{fname}.tif') for fname in split_list])
raw_paths = natsorted([os.path.join(data_dir, f'{fname}.tif') for fname in split_list])

label_paths = natsorted(glob(os.path.join(path, label_dir, "*.tif")))
raw_paths = natsorted(glob(os.path.join(path, r"tissue images", "*.tif")))
assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0

return raw_paths, label_paths


def get_cryonuseg_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
split: Literal["train", "val", "test"],
rater: Literal["b1", "b2", "b3"] = "b1",
resize_inputs: bool = False,
download: bool = False,
Expand All @@ -83,6 +128,7 @@ def get_cryonuseg_dataset(
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
split: The choice of data split.
rater: The choice of annotator.
resize_inputs: Whether to resize the inputs.
download: Whether to download the data if it is not present.
Expand All @@ -91,7 +137,7 @@ def get_cryonuseg_dataset(
Returns:
The segmentation dataset.
"""
raw_paths, label_paths = get_cryonuseg_paths(path, rater, download)
raw_paths, label_paths = get_cryonuseg_paths(path, split, rater, download)

if resize_inputs:
resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
Expand All @@ -114,6 +160,7 @@ def get_cryonuseg_loader(
path: Union[os.PathLike, str],
batch_size: int,
patch_shape: Tuple[int, int],
split: Literal["train", "val", "test"],
rater: Literal["b1", "b2", "b3"] = "b1",
resize_inputs: bool = False,
download: bool = False,
Expand All @@ -125,6 +172,7 @@ def get_cryonuseg_loader(
path: Filepath to a folder where the downloaded data will be saved.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
split: The choice of data split.
rater: The choice of annotator.
resize_inputs: Whether to resize the inputs.
download: Whether to download the data if it is not present.
Expand All @@ -134,5 +182,5 @@ def get_cryonuseg_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cryonuseg_dataset(path, patch_shape, rater, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
dataset = get_cryonuseg_dataset(path, patch_shape, split, rater, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

0 comments on commit 4fa4f6d

Please sign in to comment.