Skip to content

Commit

Permalink
Minor fix to split logic and syntax issues
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jan 13, 2025
1 parent 6ae5eb6 commit 4aa672f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 40 deletions.
27 changes: 27 additions & 0 deletions scripts/datasets/histopathology/check_glas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import sys

from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_glas_loader


sys.path.append("..")


def check_glas():
from util import ROOT

loader = get_glas_loader(
path=os.path.join(ROOT, "glas"),
batch_size=2,
patch_shape=(512, 512),
split="train",
resize_inputs=True,
download=True,
)

check_loader(loader, 8, rgb=True)


if __name__ == "__main__":
check_glas()
1 change: 1 addition & 0 deletions torch_em/data/datasets/histopathology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .consep import get_consep_loader, get_consep_dataset
from .cpm import get_cpm_loader, get_cpm_dataset
from .cryonuseg import get_cryonuseg_loader, get_cryonuseg_dataset
from .glas import get_glas_loader, get_glas_dataset
from .janowczyk import get_janowczyk_loader, get_janowczyk_dataset
from .lizard import get_lizard_loader, get_lizard_dataset
from .lynsec import get_lynsec_loader, get_lynsec_dataset
Expand Down
98 changes: 58 additions & 40 deletions torch_em/data/datasets/histopathology/glas.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,104 @@
"""
"""

import os
import shutil
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Tuple, List, Literal

import imageio.v3 as imageio
import numpy as np

import torch_em

from torch.utils.data import Dataset, DataLoader

from .. import util

def _extract_images(split: Literal["train", "test"], data_folder, output_dir):

def _extract_images(split, path):
import h5py
label_paths = natsorted(glob(os.path.join(data_folder, "*anno.bmp")))
image_paths = [image_path for image_path in natsorted(glob(os.path.join(data_folder, "*.bmp")))
if image_path not in label_paths]
os.makedirs(os.path.join(output_dir, split), exist_ok=True)

for image_file in tqdm(image_paths, desc=f"Extract images from {os.path.abspath(data_folder)}"):
fname = os.path.basename(image_file).split(".")[0]
if split not in fname:
continue
label_file = os.path.join(data_folder, f"{fname}_anno.bmp")
assert os.path.exists(label_file), label_file
data_folder = os.path.join(path, "Warwick_QU_Dataset")

image = imageio.imread(image_file)
assert image.ndim == 3 and image.shape[-1] == 3
label_paths = natsorted(glob(os.path.join(data_folder, f"{split}*anno.bmp")))
image_paths = [
image_path for image_path in natsorted(glob(os.path.join(data_folder, f"{split}*.bmp")))
if image_path not in label_paths
]
assert image_paths and len(image_paths) == len(label_paths)

segmentation = imageio.imread(label_file)
assert image.shape[:-1] == segmentation.shape
os.makedirs(os.path.join(path, split), exist_ok=True)

image = image.transpose((2, 0, 1))
assert image.shape[1:] == segmentation.shape
for image_path, label_path in tqdm(
zip(image_paths, label_paths), total=len(image_paths),
desc=f"Extract images from {os.path.abspath(data_folder)}"
):
fname = os.path.basename(image_path).split(".")[0]

output_file = os.path.join(output_dir, split, f"{fname}.h5")
with h5py.File(output_file, "a") as f:
f.create_dataset("image", data=image, compression="gzip")
f.create_dataset("labels/segmentation", data=segmentation, compression="gzip")
image = imageio.imread(image_path)
segmentation = imageio.imread(label_path)
image = image.transpose(2, 0, 1)

def get_glas_data(path: Union[os.PathLike, str], download: bool = False) -> str:
with h5py.File(os.path.join(path, split, f"{fname}.h5"), "a") as f:
f.create_dataset("raw", data=image, compression="gzip")
f.create_dataset("labels", data=segmentation, compression="gzip")


def get_glas_data(
path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
) -> str:
"""Download the GlaS dataset.
Args:
path: Filepath to a folder where the data is downloaded for further processing.
split: The choice of data split.
download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded and preprocessed.
"""
data_dir = os.path.join(path, "data", "Warwick_QU_Dataset")
data_dir = os.path.join(path, split)
if os.path.exists(data_dir):
return data_dir

os.makedirs(path, exist_ok=True)

# Download the files.
util.download_source_kaggle(path=path, dataset_name="sani84/glasmiccai2015-gland-segmentation", download=download)
util.unzip(zip_path=os.path.join(
path, "glasmiccai2015-gland-segmentation.zip"), dst=os.path.join(path, "data"), remove=False
)
os.remove(os.path.join(path, "glasmiccai2015-gland-segmentation.zip"))
util.unzip(zip_path=os.path.join(path, "glasmiccai2015-gland-segmentation.zip"), dst=path)

# Preprocess inputs per split.
splits = ["train", "test"]
if split not in splits:
raise ValueError(f"'{split}' is not a valid split choice.")

for _split in splits:
_extract_images(_split, path)

# Remove original data
shutil.rmtree(os.path.join(path, "Warwick_QU_Dataset"))

return data_dir

def get_glas_paths(
path: Union[os.PathLike], split: Literal["train", "val", "test"], download: bool = False
) -> List[str]:

def get_glas_paths(path: Union[os.PathLike], split: Literal["train", "test"], download: bool = False) -> List[str]:
"""Get paths to the GlaS data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
split: The choice of data splits.
split: The choice of data split.
download: Whether to download the data if it is not present.
Returns:
List of filepaths for the stored data.
"""
data_dir = get_glas_data(path, download)
if not os.path.exists(os.path.join(path, split)):
_extract_images(split, data_dir, path)
data_paths = natsorted(glob(os.path.join(path, split, "*.h5")))
data_dir = get_glas_data(path, split, download)
data_paths = natsorted(glob(os.path.join(data_dir, "*.h5")))
return data_paths


def get_glas_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, int],
Expand Down Expand Up @@ -112,9 +130,9 @@ def get_glas_dataset(

return torch_em.default_segmentation_dataset(
raw_paths=data_paths,
raw_key="image",
raw_key="raw",
label_paths=data_paths,
label_key="labels/segmentation",
label_key="labels",
patch_shape=patch_shape,
ndim=2,
with_channels=True,
Expand Down Expand Up @@ -146,5 +164,5 @@ def get_glas_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
ds = get_glas_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
dataset = get_glas_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

0 comments on commit 4aa672f

Please sign in to comment.