Add explicit splits for histopathology datasets (#444)
Ensure explicit data splits for all datasets.

---------

Co-authored-by: titusgriebel <[email protected]>
anwai98 and titusgriebel authored Dec 18, 2024
1 parent 274e948 commit 275f75b
Showing 8 changed files with 235 additions and 51 deletions.
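Both dataset modules changed below persist their random splits with the same pattern: the train/val/test id lists are generated once with scikit-learn, written to a one-row CSV, and parsed back from that CSV on every later call. The following is a minimal sketch of that round-trip; the file name and image stems are hypothetical stand-ins, not part of the commit:

import json
import os

import pandas as pd
from sklearn.model_selection import train_test_split

names = [f"image_{i:02d}" for i in range(20)]  # stand-in image stems
train_ids, test_ids = train_test_split(names, test_size=0.25)  # 25% of all ids for test
train_ids, val_ids = train_test_split(train_ids, test_size=0.20)  # 20% of the rest for val

csv_path = "demo_split.csv"  # hypothetical path
split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}
pd.DataFrame([split_ids]).to_csv(csv_path)  # one row; each cell stores a list's repr

# On reload the cells are strings like "['image_03', 'image_07']" (single-quoted),
# hence the replace + json.loads dance in the helpers below.
df = pd.read_csv(csv_path)
val_list = json.loads(df.iloc[0]["val"].replace("'", '"'))
assert sorted(val_list) == sorted(val_ids)
os.remove(csv_path)

ast.literal_eval would be a slightly more robust way to parse the cells, but the json.loads variant mirrors what the helpers in this commit do.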
3 changes: 2 additions & 1 deletion scripts/datasets/histopathology/check_cpm.py
@@ -15,7 +15,8 @@ def check_cpm():
path=os.path.join(ROOT, "cpm"),
patch_shape=(512, 512),
batch_size=2,
data_choice="cpm15",
data_choice="cpm17",
split="train",
)
check_loader(loader, 8, rgb=True, instance_labels=True)

3 changes: 2 additions & 1 deletion scripts/datasets/histopathology/check_janowczyk.py
@@ -15,11 +15,12 @@ def check_janowczyk():
path=os.path.join(ROOT, "janowczyk"),
patch_shape=(512, 512),
annotation="nuclei",
split="train",
batch_size=2,
download=True,
)

check_loader(loader, 8, instance_labels=True)
check_loader(loader, 8, instance_labels=True, rgb=True)


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion scripts/datasets/histopathology/check_puma.py
@@ -15,11 +15,12 @@ def check_puma():
path=os.path.join(ROOT, "puma"),
patch_shape=(512, 512),
batch_size=2,
split="test",
annotations="nuclei",
download=True,
)

check_loader(loader, 8)
check_loader(loader, 8, instance_labels=True, rgb=True)


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion scripts/datasets/histopathology/check_tnbc.py
@@ -16,9 +16,10 @@ def check_tnbc():
batch_size=1,
ndim=2,
download=True,
split="train",
)

check_loader(loader, 8, instance_labels=True)
check_loader(loader, 8, instance_labels=True, rgb=True)


if __name__ == "__main__":
63 changes: 50 additions & 13 deletions torch_em/data/datasets/histopathology/cpm.py
@@ -3,20 +3,23 @@
NOTE: You must download the files manually.
1. The dataset is located at https://drive.google.com/drive/folders/1l55cv3DuY-f7-JotDN7N5nbNnjbLWchK.
2. The restructuring details are mentioned by the authors here: https://github.com/vqdang/hover_net/issues/5#issuecomment-508431862. # noqa
2. The restructuring details are mentioned by the authors here: https://github.com/vqdang/hover_net/issues/5#issuecomment-508431862.
This dataset is from the publication https://doi.org/10.3389/fbioe.2019.00053.
Please cite it if you use this dataset for your research.
"""
""" # noqa

import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Literal, Optional, Tuple, List

import json
import pandas as pd
from scipy.io import loadmat
import imageio.v3 as imageio
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader

@@ -31,9 +34,32 @@
}


def get_cpm_data(
path: Union[os.PathLike, str], data_choice: Literal['cpm15', 'cpm17'], download: bool = False
) -> str:
def _create_split_csv(path, split):
csv_path = os.path.join(path, 'cpm15_split.csv')
if os.path.exists(csv_path):
df = pd.read_csv(csv_path)
df[split] = df[split].apply(lambda x: json.loads(x.replace("'", '"')))  # parse the stringified list in the column back into a list.
split_list = df.iloc[0][split]

else:
image_names = [
os.path.basename(image).split(".")[0] for image in glob(os.path.join(path, 'cpm15', 'Images', '*.png'))
]

train_ids, test_ids = train_test_split(image_names, test_size=0.25)  # 25% of the data for the test split.
train_ids, val_ids = train_test_split(train_ids, test_size=0.20)  # 20% of the remaining data for the val split.
split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}


df = pd.DataFrame.from_dict([split_ids])
df.to_csv(csv_path)
split_list = split_ids[split]

return split_list


def get_cpm_data(path: Union[os.PathLike, str], data_choice: Literal['cpm15', 'cpm17'], download: bool = False) -> str:
"""Obtain the CPM data.
NOTE: The dataset is located at https://drive.google.com/drive/folders/1l55cv3DuY-f7-JotDN7N5nbNnjbLWchK.
@@ -74,13 +100,17 @@ def get_cpm_data(


def get_cpm_paths(
path: Union[os.PathLike, str], data_choice: Literal['cpm15', 'cpm17'], download: bool = False
path: Union[os.PathLike, str],
data_choice: Literal['cpm15', 'cpm17'],
split: Literal["train", "val", "test"],
download: bool = False
) -> Tuple[List[str], List[str]]:
"""Get paths to the CPM data.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
data_choice: The choice of data.
split: The choice of data split.
download: Whether to download the data if it is not present.
Returns:
@@ -91,11 +121,16 @@ def get_cpm_paths(

if data_choice == "cpm15":
raw_dir, label_dir = "Images", "Labels"
else:
raw_dir, label_dir = "*/Images", "*/Labels"
split_list = _create_split_csv(path, split)

raw_paths = [os.path.join(data_dir, raw_dir, f"{fname}.png") for fname in split_list]
label_mat_paths = [os.path.join(data_dir, label_dir, f"{fname}.mat") for fname in split_list]

raw_paths = [p for p in natsorted(glob(os.path.join(data_dir, raw_dir, "*.png")))]
label_mat_paths = [p for p in natsorted(glob(os.path.join(data_dir, label_dir, "*.mat")))]
else:
assert split in ['train', 'test'], 'Explicit val split does not exist for cpm17.'
raw_dir, label_dir = f"{split}/Images", f"{split}/Labels"
raw_paths = [p for p in natsorted(glob(os.path.join(data_dir, raw_dir, "*.png")))]
label_mat_paths = [p for p in natsorted(glob(os.path.join(data_dir, label_dir, "*.mat")))]

label_paths = []
for mpath in tqdm(label_mat_paths, desc="Preprocessing labels"):
@@ -107,7 +142,7 @@ def get_cpm_paths(
label = loadmat(mpath)["inst_map"]
imageio.imwrite(label_path, label, compression="zlib")

assert len(raw_paths) == len(label_paths)
assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0

return raw_paths, label_paths

@@ -116,6 +151,7 @@ def get_cpm_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
data_choice: Optional[Literal['cpm15', 'cpm17']] = None,
split: Optional[Literal["train", "val", "test"]] = None,
download: bool = False,
**kwargs
) -> Dataset:
@@ -131,7 +167,7 @@ def get_cpm_dataset(
Returns:
The segmentation dataset.
"""
raw_paths, label_paths = get_cpm_paths(path, data_choice, download)
raw_paths, label_paths = get_cpm_paths(path, data_choice, split, download)

return torch_em.default_segmentation_dataset(
raw_paths=raw_paths,
@@ -151,6 +187,7 @@ def get_cpm_loader(
batch_size: int,
patch_shape: Tuple[int, int],
data_choice: Optional[Literal['cpm15', 'cpm17']] = None,
split: Optional[Literal["train", "val", "test"]] = None,
download: bool = False,
**kwargs
) -> DataLoader:
@@ -168,5 +205,5 @@ def get_cpm_loader(
The DataLoader
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cpm_dataset(path, patch_shape, data_choice, download, **ds_kwargs)
dataset = get_cpm_dataset(path, patch_shape, data_choice, split, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
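For illustration, a usage sketch of the updated CPM API; "./data/cpm" is a hypothetical local folder, and the archives still have to be downloaded manually as the module docstring notes:

from torch_em.data.datasets.histopathology.cpm import get_cpm_loader

# cpm17 ships explicit train/test folders, so only those two splits are valid there;
# for cpm15 the requested split is looked up in the generated cpm15_split.csv.
loader = get_cpm_loader(
    path="./data/cpm",
    batch_size=2,
    patch_shape=(512, 512),
    data_choice="cpm17",
    split="train",
)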
77 changes: 62 additions & 15 deletions torch_em/data/datasets/histopathology/janowczyk.py
@@ -14,9 +14,12 @@
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Tuple, Literal, List
from typing import Union, Tuple, Literal, List, Optional

import json
import pandas as pd
import imageio.v3 as imageio
from sklearn.model_selection import train_test_split
from skimage.measure import label as connected_components

from torch.utils.data import Dataset, DataLoader
@@ -39,6 +42,31 @@
}


def _create_split_csv(path, split):
"Create splits on patient level data."
csv_path = os.path.join(path, 'janowczyk_split.csv')
if os.path.exists(csv_path):
df = pd.read_csv(csv_path)
df[split] = df[split].apply(lambda x: json.loads(x.replace("'", '"')))  # parse the stringified list in the column back into a list.
split_list = df.iloc[0][split]

else:
patient_ids = [
os.path.basename(image).split("_original")[0]
for image in glob(os.path.join(path, 'data', 'nuclei', '*original.tif'))
]

train_ids, test_ids = train_test_split(patient_ids, test_size=0.2)  # 20% of the data for the test split.
train_ids, val_ids = train_test_split(train_ids, test_size=0.15)  # 15% of the remaining data for the val split.

split_ids = {"train": train_ids, "val": val_ids, "test": test_ids}
df = pd.DataFrame.from_dict([split_ids])
df.to_csv(csv_path)
split_list = split_ids[split]

return split_list


def get_janowczyk_data(
path: Union[os.PathLike, str],
annotation: Literal['nuclei', 'epithelium', 'tubule'] = "nuclei",
@@ -74,13 +102,15 @@ def get_janowczyk_data(

def get_janowczyk_paths(
path: Union[os.PathLike, str],
split: Optional[Literal["train", "val", "test"]] = None,
annotation: Literal['nuclei', 'epithelium', 'tubule'] = "nuclei",
download: bool = False
) -> Tuple[List[str], List[str]]:
"""Get paths to the Janowczyk data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
split: The choice of data split.
annotation: The choice of annotated labels.
download: Whether to download the data if it is not present.
@@ -90,32 +120,46 @@ def get_janowczyk_paths(
"""
data_dir = get_janowczyk_data(path, annotation, download)

if annotation == "epithelium":
label_paths = natsorted(glob(os.path.join(data_dir, "masks", "*_mask.png")))
raw_paths = [p.replace("masks/", "").replace("_mask.png", ".tif") for p in label_paths]
elif annotation == "tubule":
label_paths = natsorted(glob(os.path.join(data_dir, "*_anno.bmp")))
raw_paths = [p.replace("_anno", "") for p in label_paths]
else: # nuclei
raw_paths = natsorted(glob(os.path.join(data_dir, "*_original.tif")))
label_paths = []
for lpath in tqdm(glob(os.path.join(data_dir, "*_mask.png")), desc="Preprocessing 'nuclei' labels"):
if annotation == "nuclei":
split_list = _create_split_csv(path, split)

raw_paths = [os.path.join(data_dir, f"{name}_original.tif") for name in split_list]
label_paths = [os.path.join(data_dir, f"{name}_mask.png") for name in split_list]

neu_label_paths = []
for lpath in tqdm(label_paths, desc="Preprocessing 'nuclei' labels"):
neu_label_path = lpath.replace("_mask.png", "_preprocessed_labels.tif")
label_paths.append(neu_label_path)
neu_label_paths.append(neu_label_path)
if os.path.exists(neu_label_path):
continue

label = imageio.imread(lpath)
label = connected_components(label)
label = connected_components(label)  # run connected components on all nuclei instances.
imageio.imwrite(neu_label_path, label, compression="zlib")

label_paths = natsorted(neu_label_paths)
raw_paths = natsorted(raw_paths)

else:
assert split is None, "Only the 'nuclei' annotation has explicit splits at the moment."

if annotation == "epithelium":
label_paths = natsorted(glob(os.path.join(data_dir, "masks", "*_mask.png")))
raw_paths = [p.replace("masks/", "").replace("_mask.png", ".tif") for p in label_paths]

else: # tubule
label_paths = natsorted(glob(os.path.join(data_dir, "*_anno.bmp")))
raw_paths = [p.replace("_anno", "") for p in label_paths]

assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0

return raw_paths, label_paths


def get_janowczyk_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
split: Optional[Literal["train", "val", "test"]] = None,
annotation: Literal['nuclei', 'epithelium', 'tubule'] = "nuclei",
download: bool = False,
**kwargs
@@ -125,14 +169,15 @@ def get_janowczyk_dataset(
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
split: The choice of data split.
annotation: The choice of annotated labels.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
"""
raw_paths, label_paths = get_janowczyk_paths(path, annotation, download)
raw_paths, label_paths = get_janowczyk_paths(path, split, annotation, download)

return torch_em.default_segmentation_dataset(
raw_paths=raw_paths,
@@ -151,6 +196,7 @@ def get_janowczyk_loader(
path: Union[os.PathLike, str],
batch_size: int,
patch_shape: Tuple[int, int],
split: Optional[Literal["train", "val", "test"]] = None,
annotation: Literal['nuclei', 'epithelium', 'tubule'] = "nuclei",
download: bool = False,
**kwargs
@@ -161,6 +207,7 @@ def get_janowczyk_loader(
path: Filepath to a folder where the downloaded data will be saved.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
split: The choice of data split.
annotation: The choice of annotated labels.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Expand All @@ -169,5 +216,5 @@ def get_janowczyk_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_janowczyk_dataset(path, patch_shape, annotation, download, **ds_kwargs)
dataset = get_janowczyk_dataset(path, patch_shape, split, annotation, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
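And a matching sketch for the Janowczyk loader ("./data/janowczyk" is again a hypothetical path). Note that split is only meaningful together with annotation="nuclei"; for "epithelium" and "tubule" it must stay None, which the path helper asserts:

from torch_em.data.datasets.histopathology.janowczyk import get_janowczyk_loader

loader = get_janowczyk_loader(
    path="./data/janowczyk",
    batch_size=2,
    patch_shape=(512, 512),
    split="train",        # only valid for the "nuclei" annotation
    annotation="nuclei",
    download=True,
)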