Skip to content

Commit

Permalink
Add data related classes inside shimmer (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Jun 4, 2024
1 parent 46a075e commit 1024766
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 48 deletions.
7 changes: 6 additions & 1 deletion shimmer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from shimmer.dataset import RepeatedDataset
from shimmer.data.dataset import DomainType, RepeatedDataset, ShimmerDataset
from shimmer.data.domain import DataDomain
from shimmer.modules.contrastive_loss import (
ContrastiveLoss,
ContrastiveLossType,
Expand Down Expand Up @@ -99,4 +100,8 @@
"RandomSelection",
"SelectionBase",
"SingleDomainSelection",
"DomainType",
"RepeatedDataset",
"ShimmerDataset",
"DataDomain",
]
9 changes: 9 additions & 0 deletions shimmer/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from shimmer.data.dataset import DomainType, RepeatedDataset, ShimmerDataset
from shimmer.data.domain import DataDomain

__all__ = [
"DomainType",
"RepeatedDataset",
"ShimmerDataset",
"DataDomain",
]
132 changes: 132 additions & 0 deletions shimmer/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from collections.abc import Callable, Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Protocol

from torch.utils.data import Dataset

from shimmer.data.domain import DataDomain


class _SizedDataset(Protocol):
def __getitem__(self, k: int) -> Any: ...

def __len__(self) -> int: ...


class DomainType(Enum):
def __init__(self, base: str, kind: str) -> None:
self.base = base
self.kind = kind


class RepeatedDataset(Dataset):
"""
Dataset that cycles through its items to have a size of at least min size.
If drop_last is True, the size will be exaclty min_size. If drop_last is False,
the min_size ≤ size < min_size + len(dataset).
"""

def __init__(self, dataset: _SizedDataset, min_size: int, drop_last: bool = False):
"""
Args:
dataset (SizedDataset): dataset to repeat. The dataset should have a size
(where `__len__` is defined).
min_size (int): minimum size of the final dataset
drop_last (bool): whether to remove overflow when repeating the
dataset.
"""
self.dataset = dataset
assert min_size >= len(self.dataset)
self.dataset_size = len(self.dataset)
if drop_last:
self.total_size = min_size
else:
self.total_size = (
min_size // self.dataset_size + int(min_size % self.dataset_size > 0)
) * self.dataset_size

def __len__(self) -> int:
"""
Size of the dataset. Will be min_size if drop_last is True.
Otherwise, min_size ≤ size < min_size + len(dataset).
"""
return self.total_size

def __getitem__(self, index: int) -> Any:
return self.dataset[index % self.dataset_size]


class ShimmerDataset(_SizedDataset):
"""
Dataset class to obtain a ShimmerDataset.
"""

def __init__(
self,
dataset_path: str | Path,
split: str,
domain_classes: Mapping[DomainType, type[DataDomain]],
max_size: int = -1,
transforms: Mapping[str, Callable[[Any], Any]] | None = None,
domain_args: Mapping[str, Any] | None = None,
):
"""
Params:
dataset_path (str | pathlib.Path): Path to the dataset.
split (str): Split to use. One of 'train', 'val', 'test'.
domain_classes (Mapping[str, type[SimpleShapesDomain]]): Classes of
domain loaders to include in the dataset.
max_size (int): Max size of the dataset.
transforms (Mapping[str, (Any) -> Any]): Optional transforms to apply
to the domains. The keys are the domain names,
the values are the transforms.
domain_args (Mapping[str, Any]): Optional additional arguments to pass
to the domains.
"""
self.dataset_path = Path(dataset_path)
self.split = split
self.max_size = max_size

self.domains: dict[str, DataDomain] = {}
self.domain_args = domain_args or {}

for domain, domain_cls in domain_classes.items():
transform = None
if transforms is not None and domain.kind in transforms:
transform = transforms[domain.kind]

self.domains[domain.kind] = domain_cls(
dataset_path,
split,
transform,
self.domain_args.get(domain.kind, None),
)

lengths = {len(domain) for domain in self.domains.values()}
assert len(lengths) == 1, "Domains have different lengths"
self.dataset_size = next(iter(lengths))
if self.max_size != -1:
assert (
self.max_size <= self.dataset_size
), "Max sizes can only be lower than actual size."
self.dataset_size = self.max_size

def __len__(self) -> int:
"""
All domains should be the same length.
"""
return self.dataset_size

def __getitem__(self, index: int) -> dict[str, Any]:
"""
Params:
index (int): Index of the item to get.
Returns:
dict[str, Any]: Dictionary containing the domains. The keys are the
domain names, the values are the domains as given by the domain model at
the given index.
"""
return {
domain_name: domain[index] for domain_name, domain in self.domains.items()
}
43 changes: 43 additions & 0 deletions shimmer/data/domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from pathlib import Path
from typing import Any, Generic, TypeVar

# TODO: Consider handling CPU usage
# with a workaround in:
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662


_T = TypeVar("_T")


class DataDomain(ABC, Generic[_T]):
"""
Base class for a domain of the SimpleShapesDataset.
All domains extend this base class and implement the
__getitem__ and __len__ methods.
"""

@abstractmethod
def __init__(
self,
dataset_path: str | Path,
split: str,
transform: Callable[[Any], _T] | None = None,
additional_args: dict[str, Any] | None = None,
) -> None:
"""
Params:
dataset_path (str | pathlib.Path): Path to the dataset.
split (str): The split of the dataset to use. One of "train", "val", "test".
transform (Any -> Any): Optional transform to apply to the data.
additional_args (dict[str, Any]): Optional additional arguments to pass
to the domain.
"""
...

@abstractmethod
def __len__(self) -> int: ...

@abstractmethod
def __getitem__(self, index: int) -> _T: ...
46 changes: 0 additions & 46 deletions shimmer/dataset.py

This file was deleted.

2 changes: 1 addition & 1 deletion shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from shimmer.dataset import RepeatedDataset
from shimmer.data.dataset import RepeatedDataset
from shimmer.modules.contrastive_loss import (
ContrastiveLoss,
ContrastiveLossBayesianType,
Expand Down

0 comments on commit 1024766

Please sign in to comment.