Skip to content

Commit

Permalink
Fix AxonDeepSeg (#196)
Browse files Browse the repository at this point in the history
Fix issue in axon deep seg loader
  • Loading branch information
anwai98 authored Jan 9, 2024
1 parent 1dcbd49 commit 61f2883
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
15 changes: 9 additions & 6 deletions scripts/datasets/check_axondeepseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
from torch_em.util.debug import check_loader


ROOT = "/scratch/usr/nimanwai/data/axondeepseg"


def check_axondeepseg():
loader = get_axondeepseg_loader(
"./data/axondeepseg", name="sem", patch_shape=(1024, 1024), batch_size=1,
one_hot_encoding=True, shuffle=True, download=True
ROOT, name="sem", patch_shape=(1024, 1024), batch_size=1, split="train",
one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1
)
check_loader(loader, 5)
check_loader(loader, 5, True, True, False, "sem_loader.png")

loader = get_axondeepseg_loader(
"./data/axondeepseg", name="tem", patch_shape=(1024, 1024), batch_size=1,
one_hot_encoding=True, shuffle=True, download=True
ROOT, name="tem", patch_shape=(1024, 1024), batch_size=1, split="train",
one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1
)
check_loader(loader, 5)
check_loader(loader, 5, True, True, False, "tem_loader.png")


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions torch_em/data/datasets/axondeepseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _require_axondeepseg_data(path, name, download):


def get_axondeepseg_dataset(
path, name, patch_shape, download=False, one_hot_encoding=False, data_fraction=None, split=None, **kwargs
path, name, patch_shape, download=False, one_hot_encoding=False, val_fraction=None, split=None, **kwargs
):
"""Dataset for the segmentation of myelinated axons in EM.
Expand All @@ -141,10 +141,10 @@ def get_axondeepseg_dataset(
data_root = _require_axondeepseg_data(path, nn, download)
paths = glob(os.path.join(data_root, "*.h5"))
paths.sort()
if data_fraction is not None:
if val_fraction is not None:
assert split is not None
n_samples = int(len(paths) * data_fraction)
paths = paths[:n_samples] if split == "train" else paths[:-n_samples]
n_samples = int(len(paths) * (1 - val_fraction))
paths = paths[:n_samples] if split == "train" else paths[n_samples:]
all_paths.extend(paths)

if one_hot_encoding:
Expand All @@ -171,13 +171,13 @@ def get_axondeepseg_dataset(
def get_axondeepseg_loader(
path, name, patch_shape, batch_size,
download=False, one_hot_encoding=False,
data_fraction=None, split=None, **kwargs
val_fraction=None, split=None, **kwargs
):
"""Dataloader for the segmentation of myelinated axons. See 'get_axondeepseg_dataset' for details.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_axondeepseg_dataset(
path, name, patch_shape, download=download, one_hot_encoding=one_hot_encoding,
data_fraction=data_fraction, split=split, **ds_kwargs
val_fraction=val_fraction, split=split, **ds_kwargs
)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

0 comments on commit 61f2883

Please sign in to comment.