diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 3a2872a6388..58ac084339a 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes
+    DTD
diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py
new file mode 100644
--- /dev/null
+++ b/torchvision/datasets/dtd.py
+import os
+import pathlib
+from typing import Callable, Optional
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class DTD(VisionDataset):
+    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
+
+            .. note::
+
+                The partition only changes which split each image belongs to. Thus, regardless of the selected
+                partition, combining all splits will result in all images.
+
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
+    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        partition: int = 1,
+        download: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        if not isinstance(partition, int) and not (1 <= partition <= 10):
+            raise ValueError(
+                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
+                f"but got {partition} instead"
+            )
+        self._partition = partition
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
+        self._data_folder = self._base_folder / "dtd"
+        self._meta_folder = self._data_folder / "labels"
+        self._images_folder = self._data_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._image_files = []
+        classes = []
+        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
+            for line in file:
+                cls, name = line.strip().split("/")
+                self._image_files.append(self._images_folder.joinpath(cls, name))
+                classes.append(cls)
+
+        self.classes = sorted(set(classes))
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._labels = [self.class_to_idx[cls] for cls in classes]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx):
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}, partition={self._partition}"
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 62abc3119f6..7e5fd788466 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -2,6 +2,7 @@
 from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
 from .coco import Coco
+from .dtd import DTD
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
diff --git a/torchvision/prototype/datasets/_builtin/dtd.categories b/torchvision/prototype/datasets/_builtin/dtd.categories
new file mode 100644
index 00000000000..7f3df8a2b00
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/dtd.categories
@@ -0,0 +1,47 @@
+banded
+blotchy
+braided
+bubbly
+bumpy
+chequered
+cobwebbed
+cracked
+crosshatched
+crystalline
+dotted
+fibrous
+flecked
+freckled
+frilly
+gauzy
+grid
+grooved
+honeycombed
+interlaced
+knitted
+lacelike
+lined
+marbled
+matted
+meshed
+paisley
+perforated
+pitted
+pleated
+polka-dotted
+porous
+potholed
+scaly
+smeared
+spiralled
+sprinkled
+stained
+stratified
+striped
+studded
+swirly
+veined
+waffled
+woven
+wrinkled
+zigzagged
diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py
new file mode 100644
index 00000000000..e78ab88da27
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/dtd.py
@@ -0,0 +1,130 @@
+import io
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    Mapper,
+    Shuffler,
+    Filter,
+    IterKeyZipper,
+    Demultiplexer,
+    LineReader,
+    CSVParser,
+)
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    INFINITE_BUFFER_SIZE,
+    hint_sharding,
+    path_comparator,
+    getitem,
+)
+from torchvision.prototype.features import Label
+
+
+class DTD(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "dtd",
+            type=DatasetType.IMAGE,
+            homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
+            valid_options=dict(
+                split=("train", "test", "val"),
+                fold=tuple(str(fold) for fold in range(1, 11)),
+            ),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        archive = HttpResource(
+            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
+            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
+            decompress=True,
+        )
+        return [archive]
+
+    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
+        path = pathlib.Path(data[0])
+        if path.parent.name == "labels":
+            if path.name == "labels_joint_anno.txt":
+                return 1
+
+            return 0
+        elif path.parents[1].name == "images":
+            return 2
+        else:
+            return None
+
+    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
+        path = pathlib.Path(data[0])
+        return str(path.relative_to(path.parents[1]))
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        (_, joint_categories_data), image_data = data
+        _, *joint_categories = joint_categories_data
+        path, buffer = image_data
+
+        category = pathlib.Path(path).parent.name
+
+        return dict(
+            joint_categories={category for category in joint_categories if category},
+            label=Label(self.info.categories.index(category), category=category),
+            path=path,
+            image=decoder(buffer) if decoder else buffer,
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        archive_dp = resource_dps[0]
+
+        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
+            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
+        )
+
+        splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt"))
+        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
+        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        splits_dp = hint_sharding(splits_dp)
+
+        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")
+
+        dp = IterKeyZipper(
+            splits_dp,
+            joint_categories_dp,
+            key_fn=getitem(),
+            ref_key_fn=getitem(0),
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+        dp = IterKeyZipper(
+            dp,
+            images_dp,
+            key_fn=getitem(0),
+            ref_key_fn=self._image_key_fn,
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+
+    def _filter_images(self, data: Tuple[str, Any]) -> bool:
+        return self._classify_archive(data) == 2
+
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
+        dp = Filter(dp, self._filter_images)
+        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
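
Usage note (not part of the diff): a minimal sketch of how the stable-API class added in torchvision/datasets/dtd.py can be exercised. It assumes DTD is re-exported from torchvision.datasets like the other entries listed in datasets.rst, and the transform shown is only an illustration.

    import torchvision.transforms as T
    from torchvision.datasets import DTD  # assumes the usual re-export in torchvision/datasets/__init__.py

    # partition=1 selects the first of the ten official train/val/test splits;
    # download=True fetches dtd-r1.0.1.tar.gz and extracts it under <root>/dtd/.
    dataset = DTD(
        root="data",
        split="train",
        partition=1,
        download=True,
        transform=T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]),
    )

    image, label = dataset[0]  # transformed image tensor, integer class index
    print(len(dataset), dataset.classes[label])

Since the partition only changes which split each image belongs to, iterating over the train, val, and test splits of the same partition covers every image.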
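
For the prototype counterpart, the sketch below assumes the torchvision.prototype.datasets.load() entry point used by the other builtins in this directory; the prototype API is unstable, so the exact call may differ.

    from torchvision.prototype import datasets

    # "fold" is the prototype counterpart of the stable API's "partition";
    # valid_options declares it as a string, so pass "1" through "10".
    dtd = datasets.load("dtd", split="train", fold="1")

    # Each element is the dict built in _collate_and_decode_sample.
    sample = next(iter(dtd))
    print(sample["path"], sample["label"], sample["joint_categories"])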