From a2c033c804538ce3837d0cfc1df69dd68adb585e Mon Sep 17 00:00:00 2001
From: Philip Meier <github.pmeier@posteo.de>
Date: Mon, 6 Mar 2023 12:06:09 +0100
Subject: [PATCH] Remove prototype code / tests / CI on release branch (#7390)

---
 .github/workflows/prototype-tests-linux-gpu.yml | 94 -
 test/builtin_dataset_mocks.py | 1582 -----------------
 test/prototype_common_utils.py | 82 -
 test/test_prototype_datapoints.py | 133 --
 test/test_prototype_datasets_builtin.py | 282 ---
 test/test_prototype_datasets_utils.py | 302 ----
 test/test_prototype_models.py | 84 -
 test/test_prototype_transforms.py | 535 ------
 torchvision/prototype/__init__.py | 1 -
 torchvision/prototype/datapoints/__init__.py | 1 -
 torchvision/prototype/datapoints/_label.py | 78 -
 torchvision/prototype/datasets/__init__.py | 15 -
 torchvision/prototype/datasets/_api.py | 65 -
 .../prototype/datasets/_builtin/README.md | 340 ----
 .../prototype/datasets/_builtin/__init__.py | 22 -
 .../prototype/datasets/_builtin/caltech.py | 212 ---
 .../datasets/_builtin/caltech101.categories | 101 --
 .../datasets/_builtin/caltech256.categories | 257 ---
 .../prototype/datasets/_builtin/celeba.py | 200 ---
 .../prototype/datasets/_builtin/cifar.py | 142 --
 .../datasets/_builtin/cifar10.categories | 10 -
 .../datasets/_builtin/cifar100.categories | 100 --
 .../prototype/datasets/_builtin/clevr.py | 107 --
 .../datasets/_builtin/coco.categories | 91 -
 .../prototype/datasets/_builtin/coco.py | 274 ---
 .../datasets/_builtin/country211.categories | 211 ---
 .../prototype/datasets/_builtin/country211.py | 81 -
 .../datasets/_builtin/cub200.categories | 200 ---
 .../prototype/datasets/_builtin/cub200.py | 265 ---
 .../datasets/_builtin/dtd.categories | 47 -
 .../prototype/datasets/_builtin/dtd.py | 139 --
 .../prototype/datasets/_builtin/eurosat.py | 66 -
 .../prototype/datasets/_builtin/fer2013.py | 64 -
 .../datasets/_builtin/food101.categories | 101 --
 .../prototype/datasets/_builtin/food101.py | 97 -
 .../prototype/datasets/_builtin/gtsrb.py | 112 --
 .../datasets/_builtin/imagenet.categories | 1000 -----------
 .../prototype/datasets/_builtin/imagenet.py | 223 ---
 .../prototype/datasets/_builtin/mnist.py | 419 -----
 .../_builtin/oxford-iiit-pet.categories | 37 -
 .../datasets/_builtin/oxford_iiit_pet.py | 146 --
 .../prototype/datasets/_builtin/pcam.py | 129 --
 .../datasets/_builtin/sbd.categories | 20 -
 .../prototype/datasets/_builtin/sbd.py | 165 --
 .../prototype/datasets/_builtin/semeion.py | 55 -
 .../_builtin/stanford-cars.categories | 196 --
 .../datasets/_builtin/stanford_cars.py | 117 --
 .../prototype/datasets/_builtin/svhn.py | 84 -
 .../prototype/datasets/_builtin/usps.py | 70 -
 .../datasets/_builtin/voc.categories | 21 -
 .../prototype/datasets/_builtin/voc.py | 222 ---
 torchvision/prototype/datasets/_folder.py | 66 -
 torchvision/prototype/datasets/_home.py | 28 -
 torchvision/prototype/datasets/benchmark.py | 661 -------
 .../datasets/generate_category_files.py | 61 -
 .../prototype/datasets/utils/__init__.py | 4 -
 .../prototype/datasets/utils/_dataset.py | 57 -
 .../prototype/datasets/utils/_encoded.py | 57 -
 .../prototype/datasets/utils/_internal.py | 194 --
 .../prototype/datasets/utils/_resource.py | 235 ---
 torchvision/prototype/models/__init__.py | 1 -
 .../prototype/models/depth/__init__.py | 1 -
 .../prototype/models/depth/stereo/__init__.py | 2 -
 .../models/depth/stereo/crestereo.py | 1463 ---------------
 .../models/depth/stereo/raft_stereo.py | 838 ---------
 torchvision/prototype/transforms/__init__.py | 6 -
 torchvision/prototype/transforms/_augment.py | 300 ----
 torchvision/prototype/transforms/_geometry.py | 125 --
 torchvision/prototype/transforms/_misc.py | 58 -
 torchvision/prototype/transforms/_presets.py | 80 -
 .../prototype/transforms/_type_conversion.py | 29 -
 torchvision/prototype/utils/__init__.py | 1 -
 torchvision/prototype/utils/_internal.py | 126 --
 73 files changed, 13790 deletions(-)
 delete mode 100644 .github/workflows/prototype-tests-linux-gpu.yml
 delete mode 100644 test/builtin_dataset_mocks.py
 delete mode 100644 test/prototype_common_utils.py
 delete mode 100644 test/test_prototype_datapoints.py
 delete mode 100644 test/test_prototype_datasets_builtin.py
 delete mode 100644 test/test_prototype_datasets_utils.py
 delete mode 100644 test/test_prototype_models.py
 delete mode 100644 test/test_prototype_transforms.py
 delete mode 100644 torchvision/prototype/__init__.py
 delete mode 100644 torchvision/prototype/datapoints/__init__.py
 delete mode 100644 torchvision/prototype/datapoints/_label.py
 delete mode 100644 torchvision/prototype/datasets/__init__.py
 delete mode 100644 torchvision/prototype/datasets/_api.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/README.md
 delete mode 100644 torchvision/prototype/datasets/_builtin/__init__.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/caltech.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/caltech101.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/caltech256.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/celeba.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/cifar.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/cifar10.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/cifar100.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/clevr.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/coco.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/coco.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/country211.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/country211.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/cub200.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/cub200.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/dtd.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/dtd.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/eurosat.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/fer2013.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/food101.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/food101.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/gtsrb.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/imagenet.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/imagenet.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/mnist.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/pcam.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/sbd.categories
 delete mode 100644 torchvision/prototype/datasets/_builtin/sbd.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/semeion.py
 delete mode 100644 torchvision/prototype/datasets/_builtin/stanford-cars.categories
 delete mode 100644
torchvision/prototype/datasets/_builtin/stanford_cars.py delete mode 100644 torchvision/prototype/datasets/_builtin/svhn.py delete mode 100644 torchvision/prototype/datasets/_builtin/usps.py delete mode 100644 torchvision/prototype/datasets/_builtin/voc.categories delete mode 100644 torchvision/prototype/datasets/_builtin/voc.py delete mode 100644 torchvision/prototype/datasets/_folder.py delete mode 100644 torchvision/prototype/datasets/_home.py delete mode 100644 torchvision/prototype/datasets/benchmark.py delete mode 100644 torchvision/prototype/datasets/generate_category_files.py delete mode 100644 torchvision/prototype/datasets/utils/__init__.py delete mode 100644 torchvision/prototype/datasets/utils/_dataset.py delete mode 100644 torchvision/prototype/datasets/utils/_encoded.py delete mode 100644 torchvision/prototype/datasets/utils/_internal.py delete mode 100644 torchvision/prototype/datasets/utils/_resource.py delete mode 100644 torchvision/prototype/models/__init__.py delete mode 100644 torchvision/prototype/models/depth/__init__.py delete mode 100644 torchvision/prototype/models/depth/stereo/__init__.py delete mode 100644 torchvision/prototype/models/depth/stereo/crestereo.py delete mode 100644 torchvision/prototype/models/depth/stereo/raft_stereo.py delete mode 100644 torchvision/prototype/transforms/__init__.py delete mode 100644 torchvision/prototype/transforms/_augment.py delete mode 100644 torchvision/prototype/transforms/_geometry.py delete mode 100644 torchvision/prototype/transforms/_misc.py delete mode 100644 torchvision/prototype/transforms/_presets.py delete mode 100644 torchvision/prototype/transforms/_type_conversion.py delete mode 100644 torchvision/prototype/utils/__init__.py delete mode 100644 torchvision/prototype/utils/_internal.py diff --git a/.github/workflows/prototype-tests-linux-gpu.yml b/.github/workflows/prototype-tests-linux-gpu.yml deleted file mode 100644 index 76e6b71b7b9..00000000000 --- a/.github/workflows/prototype-tests-linux-gpu.yml +++ /dev/null @@ -1,94 +0,0 @@ -name: Prototype tests on Linux - -on: - pull_request: - -jobs: - tests: - strategy: - matrix: - python-version: - - "3.8" - - "3.9" - - "3.10" - gpu-arch-type: ["cpu"] - gpu-arch-version: [""] - runner: ["linux.2xlarge"] - include: - - python-version: "3.8" - gpu-arch-type: cuda - gpu-arch-version: "11.7" - runner: linux.4xlarge.nvidia.gpu - fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - with: - job-name: Python ${{ matrix.python-version }}, ${{ matrix.gpu-arch-type }} - repository: pytorch/vision - gpu-arch-type: ${{ matrix.gpu-arch-type }} - gpu-arch-version: ${{ matrix.gpu-arch-version }} - runner: ${{ matrix.runner }} - timeout: 45 - script: | - # Mark Build Directory Safe - - echo '::group::Set PyTorch conda channel' - if [[ (${GITHUB_BASE_REF} = 'release'*) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then - POSTFIX=test - else - POSTFIX=nightly - fi - PYTORCH_CHANNEL=pytorch-"${POSTFIX}" - echo "${PYTORCH_CHANNEL}" - echo '::endgroup::' - - echo '::group::Set PyTorch conda mutex' - if [[ ${{ matrix.gpu-arch-type }} = 'cuda' ]]; then - PYTORCH_MUTEX="pytorch-cuda=${{ matrix.gpu-arch-version }}" - else - PYTORCH_MUTEX=cpuonly - fi - echo "${PYTORCH_MUTEX}" - echo '::endgroup::' - - echo '::group::Create conda environment' - conda create --prefix $PWD/ci \ - --quiet --yes \ - python=${{ matrix.python-version }} \ - numpy libpng jpeg scipy - conda activate $PWD/ci - echo '::endgroup::' - - echo '::group::Install PyTorch' - conda install \ - 
--quiet --yes \ - -c "${PYTORCH_CHANNEL}" \ - -c nvidia \ - pytorch \ - "${PYTORCH_MUTEX}" - if [[ ${{ matrix.gpu-arch-type }} = 'cuda' ]]; then - python3 -c "import torch; exit(not torch.cuda.is_available())" - fi - echo '::endgroup::' - - echo '::group::Install TorchVision' - python setup.py develop - echo '::endgroup::' - - echo '::group::Collect PyTorch environment information' - python -m torch.utils.collect_env - echo '::endgroup::' - - echo '::group::Install testing utilities' - pip install --progress-bar=off pytest pytest-mock pytest-cov - echo '::endgroup::' - - echo '::group::Run prototype tests' - # We don't want to run the prototype datasets tests. Since the positional glob into `pytest`, i.e. - # `test/test_prototype*.py` takes the highest priority, neither `--ignore` nor `--ignore-glob` can help us here. - rm test/test_prototype_datasets*.py - pytest \ - --durations=25 \ - --cov=torchvision/prototype \ - --cov-report=term-missing \ - test/test_prototype*.py - echo '::endgroup::' diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py deleted file mode 100644 index ef5d5e1ec96..00000000000 --- a/test/builtin_dataset_mocks.py +++ /dev/null @@ -1,1582 +0,0 @@ -import bz2 -import collections.abc -import csv -import functools -import gzip -import io -import itertools -import json -import lzma -import pathlib -import pickle -import random -import shutil -import unittest.mock -import xml.etree.ElementTree as ET -from collections import Counter, defaultdict - -import numpy as np -import pytest -import torch -from common_utils import combinations_grid -from datasets_utils import create_image_file, create_image_folder, make_tar, make_zip -from torch.nn.functional import one_hot -from torch.testing import make_tensor as _make_tensor -from torchvision.prototype import datasets - -make_tensor = functools.partial(_make_tensor, device="cpu") -make_scalar = functools.partial(make_tensor, ()) - - -__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"] - - -class DatasetMock: - def __init__(self, name, *, mock_data_fn, configs): - # FIXME: error handling for unknown names - self.name = name - self.mock_data_fn = mock_data_fn - self.configs = configs - - def _parse_mock_info(self, mock_info): - if mock_info is None: - raise pytest.UsageError( - f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an " - f"integer indicating the number of samples for the current `config`." - ) - elif isinstance(mock_info, int): - mock_info = dict(num_samples=mock_info) - elif not isinstance(mock_info, dict): - raise pytest.UsageError( - f"The mock data function for dataset '{self.name}' returned a {type(mock_info)}. The returned object " - f"should be a dictionary containing at least the number of samples for the key `'num_samples'`. If no " - f"additional information is required for specific tests, the number of samples can also be returned as " - f"an integer." - ) - elif "num_samples" not in mock_info: - raise pytest.UsageError( - f"The dictionary returned by the mock data function for dataset '{self.name}' has to contain a " - f"`'num_samples'` entry indicating the number of samples." - ) - - return mock_info - - def load(self, config): - # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in - # test/test_prototype_builtin_datasets.py - root = pathlib.Path(datasets.home()) / self.name - # We cannot place the mock data upfront in `root`. Loading a dataset calls `OnlineResource.load`. 
In turn, - # this will only download **and** preprocess if the file is not present. In other words, if we already place - # the file in `root` before the resource is loaded, we are effectively skipping the preprocessing. - # To avoid that we first place the mock data in a temporary directory and patch the download logic to move it to - # `root` only when it is requested. - tmp_mock_data_folder = root / "__mock__" - tmp_mock_data_folder.mkdir(parents=True) - - mock_info = self._parse_mock_info(self.mock_data_fn(tmp_mock_data_folder, config)) - - def patched_download(resource, root, **kwargs): - src = tmp_mock_data_folder / resource.file_name - if not src.exists(): - raise pytest.UsageError( - f"Dataset '{self.name}' requires the file {resource.file_name} for {config}" - f"but it was not created by the mock data function." - ) - - dst = root / resource.file_name - shutil.move(str(src), str(root)) - - return dst - - with unittest.mock.patch( - "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=patched_download - ): - dataset = datasets.load(self.name, **config) - - extra_files = list(tmp_mock_data_folder.glob("**/*")) - if extra_files: - raise pytest.UsageError( - ( - f"Dataset '{self.name}' created the following files for {config} in the mock data function, " - f"but they were not loaded:\n\n" - ) - + "\n".join(str(file.relative_to(tmp_mock_data_folder)) for file in extra_files) - ) - - tmp_mock_data_folder.rmdir() - - return dataset, mock_info - - -def config_id(name, config): - parts = [name] - for name, value in config.items(): - if isinstance(value, bool): - part = ("" if value else "no_") + name - else: - part = str(value) - parts.append(part) - return "-".join(parts) - - -def parametrize_dataset_mocks(*dataset_mocks, marks=None): - mocks = {} - for mock in dataset_mocks: - if isinstance(mock, DatasetMock): - mocks[mock.name] = mock - elif isinstance(mock, collections.abc.Mapping): - mocks.update(mock) - else: - raise pytest.UsageError( - f"The positional arguments passed to `parametrize_dataset_mocks` can either be a `DatasetMock`, " - f"a sequence of `DatasetMock`'s, or a mapping of names to `DatasetMock`'s, " - f"but got {mock} instead." 
- ) - dataset_mocks = mocks - - if marks is None: - marks = {} - elif not isinstance(marks, collections.abc.Mapping): - raise pytest.UsageError() - - return pytest.mark.parametrize( - ("dataset_mock", "config"), - [ - pytest.param(dataset_mock, config, id=config_id(name, config), marks=marks.get(name, ())) - for name, dataset_mock in dataset_mocks.items() - for config in dataset_mock.configs - ], - ) - - -DATASET_MOCKS = {} - - -def register_mock(name=None, *, configs): - def wrapper(mock_data_fn): - nonlocal name - if name is None: - name = mock_data_fn.__name__ - DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs) - - return mock_data_fn - - return wrapper - - -class MNISTMockData: - _DTYPES_ID = { - torch.uint8: 8, - torch.int8: 9, - torch.int16: 11, - torch.int32: 12, - torch.float32: 13, - torch.float64: 14, - } - - @classmethod - def _magic(cls, dtype, ndim): - return cls._DTYPES_ID[dtype] * 256 + ndim + 1 - - @staticmethod - def _encode(t): - return torch.tensor(t, dtype=torch.int32).numpy().tobytes()[::-1] - - @staticmethod - def _big_endian_dtype(dtype): - np_dtype = getattr(np, str(dtype).replace("torch.", ""))().dtype - return np.dtype(f">{np_dtype.kind}{np_dtype.itemsize}") - - @classmethod - def _create_binary_file(cls, root, filename, *, num_samples, shape, dtype, compressor, low=0, high): - with compressor(root / filename, "wb") as fh: - for meta in (cls._magic(dtype, len(shape)), num_samples, *shape): - fh.write(cls._encode(meta)) - - data = make_tensor((num_samples, *shape), dtype=dtype, low=low, high=high) - - fh.write(data.numpy().astype(cls._big_endian_dtype(dtype)).tobytes()) - - @classmethod - def generate( - cls, - root, - *, - num_categories, - num_samples=None, - images_file, - labels_file, - image_size=(28, 28), - image_dtype=torch.uint8, - label_size=(), - label_dtype=torch.uint8, - compressor=None, - ): - if num_samples is None: - num_samples = num_categories - if compressor is None: - compressor = gzip.open - - cls._create_binary_file( - root, - images_file, - num_samples=num_samples, - shape=image_size, - dtype=image_dtype, - compressor=compressor, - high=float("inf"), - ) - cls._create_binary_file( - root, - labels_file, - num_samples=num_samples, - shape=label_size, - dtype=label_dtype, - compressor=compressor, - high=num_categories, - ) - - return num_samples - - -def mnist(root, config): - prefix = "train" if config["split"] == "train" else "t10k" - return MNISTMockData.generate( - root, - num_categories=10, - images_file=f"{prefix}-images-idx3-ubyte.gz", - labels_file=f"{prefix}-labels-idx1-ubyte.gz", - ) - - -DATASET_MOCKS.update( - { - name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test"))) - for name in ["mnist", "fashionmnist", "kmnist"] - } -) - - -@register_mock( - configs=combinations_grid( - split=("train", "test"), - image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), - ) -) -def emnist(root, config): - num_samples_map = {} - file_names = set() - for split, image_set in itertools.product( - ("train", "test"), - ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), - ): - prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - file_names.update({images_file, labels_file}) - num_samples_map[(split, image_set)] = MNISTMockData.generate( - root, - # The image sets that merge some lower case letters in their respective upper case 
variant, still use dense - # labels in the data files. Thus, num_categories != len(categories) there. - num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62, - images_file=images_file, - labels_file=labels_file, - ) - - make_zip(root, "emnist-gzip.zip", *file_names) - - return num_samples_map[(config["split"], config["image_set"])] - - -@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist"))) -def qmnist(root, config): - num_categories = 10 - if config["split"] == "train": - num_samples = num_samples_gen = num_categories + 2 - prefix = "qmnist-train" - suffix = ".gz" - compressor = gzip.open - elif config["split"].startswith("test"): - # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create - # more than 10000 images for the dataset to not be empty. - num_samples_gen = 10001 - num_samples = { - "test": num_samples_gen, - "test10k": min(num_samples_gen, 10_000), - "test50k": num_samples_gen - 10_000, - }[config["split"]] - prefix = "qmnist-test" - suffix = ".gz" - compressor = gzip.open - else: # config["split"] == "nist" - num_samples = num_samples_gen = num_categories + 3 - prefix = "xnist" - suffix = ".xz" - compressor = lzma.open - - MNISTMockData.generate( - root, - num_categories=num_categories, - num_samples=num_samples_gen, - images_file=f"{prefix}-images-idx3-ubyte{suffix}", - labels_file=f"{prefix}-labels-idx2-int{suffix}", - label_size=(8,), - label_dtype=torch.int32, - compressor=compressor, - ) - return num_samples - - -class CIFARMockData: - NUM_PIXELS = 32 * 32 * 3 - - @classmethod - def _create_batch_file(cls, root, name, *, num_categories, labels_key, num_samples=1): - content = { - "data": make_tensor((num_samples, cls.NUM_PIXELS), dtype=torch.uint8).numpy(), - labels_key: torch.randint(0, num_categories, size=(num_samples,)).tolist(), - } - with open(pathlib.Path(root) / name, "wb") as fh: - pickle.dump(content, fh) - - @classmethod - def generate( - cls, - root, - name, - *, - folder, - train_files, - test_files, - num_categories, - labels_key, - ): - folder = root / folder - folder.mkdir() - files = (*train_files, *test_files) - for file in files: - cls._create_batch_file( - folder, - file, - num_categories=num_categories, - labels_key=labels_key, - ) - - make_tar(root, name, folder, compression="gz") - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def cifar10(root, config): - train_files = [f"data_batch_{idx}" for idx in range(1, 6)] - test_files = ["test_batch"] - - CIFARMockData.generate( - root=root, - name="cifar-10-python.tar.gz", - folder=pathlib.Path("cifar-10-batches-py"), - train_files=train_files, - test_files=test_files, - num_categories=10, - labels_key="labels", - ) - - return len(train_files if config["split"] == "train" else test_files) - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def cifar100(root, config): - train_files = ["train"] - test_files = ["test"] - - CIFARMockData.generate( - root=root, - name="cifar-100-python.tar.gz", - folder=pathlib.Path("cifar-100-python"), - train_files=train_files, - test_files=test_files, - num_categories=100, - labels_key="fine_labels", - ) - - return len(train_files if config["split"] == "train" else test_files) - - -@register_mock(configs=[dict()]) -def caltech101(root, config): - def create_ann_file(root, name): - import scipy.io - - box_coord = make_tensor((1, 4), dtype=torch.int32, low=0).numpy().astype(np.uint16) - obj_contour = make_tensor((2, 
int(torch.randint(3, 6, size=()))), dtype=torch.float64, low=0).numpy() - - scipy.io.savemat(str(pathlib.Path(root) / name), dict(box_coord=box_coord, obj_contour=obj_contour)) - - def create_ann_folder(root, name, file_name_fn, num_examples): - root = pathlib.Path(root) / name - root.mkdir(parents=True) - - for idx in range(num_examples): - create_ann_file(root, file_name_fn(idx)) - - images_root = root / "101_ObjectCategories" - anns_root = root / "Annotations" - - image_category_map = { - "Faces": "Faces_2", - "Faces_easy": "Faces_3", - "Motorbikes": "Motorbikes_16", - "airplanes": "Airplanes_Side_2", - } - - categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"] - - num_images_per_category = 2 - for category in categories: - create_image_folder( - root=images_root, - name=category, - file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg", - num_examples=num_images_per_category, - ) - create_ann_folder( - root=anns_root, - name=image_category_map.get(category, category), - file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat", - num_examples=num_images_per_category, - ) - - (images_root / "BACKGROUND_Goodle").mkdir() - make_tar(root, f"{images_root.name}.tar.gz", images_root, compression="gz") - - make_tar(root, f"{anns_root.name}.tar", anns_root) - - return num_images_per_category * len(categories) - - -@register_mock(configs=[dict()]) -def caltech256(root, config): - dir = root / "256_ObjectCategories" - num_images_per_category = 2 - - categories = [ - (1, "ak47"), - (127, "laptop-101"), - (198, "spider"), - (257, "clutter"), - ] - - for category_idx, category in categories: - files = create_image_folder( - dir, - name=f"{category_idx:03d}.{category}", - file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg", - num_examples=num_images_per_category, - ) - if category == "spider": - open(files[0].parent / "RENAME2", "w").close() - - make_tar(root, f"{dir.name}.tar", dir) - - return num_images_per_category * len(categories) - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def imagenet(root, config): - from scipy.io import savemat - - info = datasets.info("imagenet") - - if config["split"] == "train": - num_samples = len(info["wnids"]) - archive_name = "ILSVRC2012_img_train.tar" - - files = [] - for wnid in info["wnids"]: - create_image_folder( - root=root, - name=wnid, - file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG", - num_examples=1, - ) - files.append(make_tar(root, f"{wnid}.tar")) - elif config["split"] == "val": - num_samples = 3 - archive_name = "ILSVRC2012_img_val.tar" - files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - - devkit_root = root / "ILSVRC2012_devkit_t12" - data_root = devkit_root / "data" - data_root.mkdir(parents=True) - - with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist(): - file.write(f"{label}\n") - - num_children = 0 - synsets = [ - (idx, wnid, category, "", num_children, [], 0, 0) - for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1) - ] - num_children = 1 - synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) - synsets = np.array( - synsets, - dtype=np.dtype( - [ - ("ILSVRC2012_ID", "O"), - ("WNID", "O"), - ("words", "O"), - ("gloss", "O"), - ("num_children", "O"), - ("children", "O"), - ("wordnet_height", "O"), - ("num_train_images", "O"), - ] - ), - ) - 
savemat(data_root / "meta.mat", dict(synsets=synsets)) - - make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # config["split"] == "test" - num_samples = 5 - archive_name = "ILSVRC2012_img_test_v10102019.tar" - files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - - make_tar(root, archive_name, *files) - - return num_samples - - -class CocoMockData: - @classmethod - def _make_annotations_json( - cls, - root, - name, - *, - images_meta, - fn, - ): - num_anns_per_image = torch.randint(1, 5, (len(images_meta),)) - num_anns_total = int(num_anns_per_image.sum()) - ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)]) - - anns_meta = [] - for image_meta, num_anns in zip(images_meta, num_anns_per_image): - for _ in range(num_anns): - ann_id = int(next(ann_ids_iter)) - anns_meta.append(dict(fn(ann_id, image_meta), id=ann_id, image_id=image_meta["id"])) - anns_meta.sort(key=lambda ann: ann["id"]) - - with open(root / name, "w") as file: - json.dump(dict(images=images_meta, annotations=anns_meta), file) - - return num_anns_per_image - - @staticmethod - def _make_instances_data(ann_id, image_meta): - def make_rle_segmentation(): - height, width = image_meta["height"], image_meta["width"] - numel = height * width - counts = [] - while sum(counts) <= numel: - counts.append(int(torch.randint(5, 8, ()))) - if sum(counts) > numel: - counts[-1] -= sum(counts) - numel - return dict(counts=counts, size=[height, width]) - - return dict( - segmentation=make_rle_segmentation(), - bbox=make_tensor((4,), dtype=torch.float32, low=0).tolist(), - iscrowd=True, - area=float(make_scalar(dtype=torch.float32)), - category_id=int(make_scalar(dtype=torch.int64)), - ) - - @staticmethod - def _make_captions_data(ann_id, image_meta): - return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.") - - @classmethod - def _make_annotations(cls, root, name, *, images_meta): - num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64) - for annotations, fn in ( - ("instances", cls._make_instances_data), - ("captions", cls._make_captions_data), - ): - num_anns_per_image += cls._make_annotations_json( - root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn - ) - - return int(num_anns_per_image.sum()) - - @classmethod - def generate( - cls, - root, - *, - split, - year, - num_samples, - ): - annotations_dir = root / "annotations" - annotations_dir.mkdir() - - for split_ in ("train", "val"): - config_name = f"{split_}{year}" - - images_meta = [ - dict( - file_name=f"{idx:012d}.jpg", - id=idx, - width=width, - height=height, - ) - for idx, (height, width) in enumerate( - torch.randint(3, 11, size=(num_samples, 2), dtype=torch.int).tolist() - ) - ] - - if split_ == split: - create_image_folder( - root, - config_name, - file_name_fn=lambda idx: images_meta[idx]["file_name"], - num_examples=num_samples, - size=lambda idx: (3, images_meta[idx]["height"], images_meta[idx]["width"]), - ) - make_zip(root, f"{config_name}.zip") - - cls._make_annotations( - annotations_dir, - config_name, - images_meta=images_meta, - ) - - make_zip(root, f"annotations_trainval{year}.zip", annotations_dir) - - return num_samples - - -@register_mock( - configs=combinations_grid( - split=("train", "val"), - year=("2017", "2014"), - annotations=("instances", "captions", None), - ) -) -def coco(root, config): - return CocoMockData.generate(root, split=config["split"], year=config["year"], num_samples=5) - - -class 
SBDMockData: - _NUM_CATEGORIES = 20 - - @classmethod - def _make_split_files(cls, root_map, *, split): - splits_and_idcs = [ - ("train", [0, 1, 2]), - ("val", [3]), - ] - if split == "train_noval": - splits_and_idcs.append(("train_noval", [0, 2])) - - ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs} - - for split, ids in ids_map.items(): - with open(root_map[split] / f"{split}.txt", "w") as fh: - fh.writelines(f"{id}\n" for id in ids) - - return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - - @classmethod - def _make_anns_folder(cls, root, name, ids): - from scipy.io import savemat - - anns_folder = root / name - anns_folder.mkdir() - - sizes = torch.randint(1, 9, size=(len(ids), 2)).tolist() - for id, size in zip(ids, sizes): - savemat( - anns_folder / f"{id}.mat", - { - "GTcls": { - "Boundaries": cls._make_boundaries(size), - "Segmentation": cls._make_segmentation(size), - } - }, - ) - return sizes - - @classmethod - def _make_boundaries(cls, size): - from scipy.sparse import csc_matrix - - return [ - [csc_matrix(torch.randint(0, 2, size=size, dtype=torch.uint8).numpy())] for _ in range(cls._NUM_CATEGORIES) - ] - - @classmethod - def _make_segmentation(cls, size): - return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy() - - @classmethod - def generate(cls, root, *, split): - archive_folder = root / "benchmark_RELEASE" - dataset_folder = archive_folder / "dataset" - dataset_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = cls._make_split_files( - defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split - ) - sizes = cls._make_anns_folder(dataset_folder, "cls", ids) - create_image_folder( - dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx] - ) - - make_tar(root, "benchmark.tgz", archive_folder, compression="gz") - - return num_samples_map[split] - - -@register_mock(configs=combinations_grid(split=("train", "val", "train_noval"))) -def sbd(root, config): - return SBDMockData.generate(root, split=config["split"]) - - -@register_mock(configs=[dict()]) -def semeion(root, config): - num_samples = 3 - num_categories = 10 - - images = torch.rand(num_samples, 256) - labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) - with open(root / "semeion.data", "w") as fh: - for image, one_hot_label in zip(images, labels): - image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image]) - labels_columns = " ".join([str(label.item()) for label in one_hot_label]) - fh.write(f"{image_columns} {labels_columns} \n") - - return num_samples - - -class VOCMockData: - _TRAIN_VAL_FILE_NAMES = { - "2007": "VOCtrainval_06-Nov-2007.tar", - "2008": "VOCtrainval_14-Jul-2008.tar", - "2009": "VOCtrainval_11-May-2009.tar", - "2010": "VOCtrainval_03-May-2010.tar", - "2011": "VOCtrainval_25-May-2011.tar", - "2012": "VOCtrainval_11-May-2012.tar", - } - _TEST_FILE_NAMES = { - "2007": "VOCtest_06-Nov-2007.tar", - } - - @classmethod - def _make_split_files(cls, root, *, year, trainval): - split_folder = root / "ImageSets" - - if trainval: - idcs_map = { - "train": [0, 1, 2], - "val": [3, 4], - } - idcs_map["trainval"] = [*idcs_map["train"], *idcs_map["val"]] - else: - idcs_map = { - "test": [5], - } - ids_map = {split: [f"{year}_{idx:06d}" for idx in idcs] for split, idcs in idcs_map.items()} - - for task_sub_folder in ("Main", "Segmentation"): - task_folder = 
split_folder / task_sub_folder - task_folder.mkdir(parents=True, exist_ok=True) - for split, ids in ids_map.items(): - with open(task_folder / f"{split}.txt", "w") as fh: - fh.writelines(f"{id}\n" for id in ids) - - return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - - @classmethod - def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples): - folder = root / name - folder.mkdir(parents=True, exist_ok=True) - - for idx in range(num_examples): - cls._make_detection_ann_file(folder, file_name_fn(idx)) - - @classmethod - def _make_detection_ann_file(cls, root, name): - def add_child(parent, name, text=None): - child = ET.SubElement(parent, name) - child.text = str(text) - return child - - def add_name(obj, name="dog"): - add_child(obj, "name", name) - - def add_size(obj): - obj = add_child(obj, "size") - size = {"width": 0, "height": 0, "depth": 3} - for name, text in size.items(): - add_child(obj, name, text) - - def add_bndbox(obj): - obj = add_child(obj, "bndbox") - bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4} - for name, text in bndbox.items(): - add_child(obj, name, text) - - annotation = ET.Element("annotation") - add_size(annotation) - obj = add_child(annotation, "object") - add_name(obj) - add_bndbox(obj) - - with open(root / name, "wb") as fh: - fh.write(ET.tostring(annotation)) - - @classmethod - def generate(cls, root, *, year, trainval): - archive_folder = root - if year == "2011": - archive_folder = root / "TrainVal" - data_folder = archive_folder / "VOCdevkit" - else: - archive_folder = data_folder = root / "VOCdevkit" - data_folder = data_folder / f"VOC{year}" - data_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval) - for make_folder_fn, name, suffix in [ - (create_image_folder, "JPEGImages", ".jpg"), - (create_image_folder, "SegmentationClass", ".png"), - (cls._make_detection_anns_folder, "Annotations", ".xml"), - ]: - make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) - make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], archive_folder) - - return num_samples_map - - -@register_mock( - configs=[ - *combinations_grid( - split=("train", "val", "trainval"), - year=("2007", "2008", "2009", "2010", "2011", "2012"), - task=("detection", "segmentation"), - ), - *combinations_grid( - split=("test",), - year=("2007",), - task=("detection", "segmentation"), - ), - ], -) -def voc(root, config): - trainval = config["split"] != "test" - return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]] - - -class CelebAMockData: - @classmethod - def _make_ann_file(cls, root, name, data, *, field_names=None): - with open(root / name, "w") as file: - if field_names: - file.write(f"{len(data)}\r\n") - file.write(" ".join(field_names) + "\r\n") - file.writelines(" ".join(str(item) for item in row) + "\r\n" for row in data) - - _SPLIT_TO_IDX = { - "train": 0, - "val": 1, - "test": 2, - } - - @classmethod - def _make_split_file(cls, root): - num_samples_map = {"train": 4, "val": 3, "test": 2} - - data = [ - (f"{idx:06d}.jpg", cls._SPLIT_TO_IDX[split]) - for split, num_samples in num_samples_map.items() - for idx in range(num_samples) - ] - cls._make_ann_file(root, "list_eval_partition.txt", data) - - image_file_names, _ = zip(*data) - return image_file_names, num_samples_map - - @classmethod - def 
_make_identity_file(cls, root, image_file_names): - cls._make_ann_file( - root, "identity_CelebA.txt", [(name, int(make_scalar(low=1, dtype=torch.int))) for name in image_file_names] - ) - - @classmethod - def _make_attributes_file(cls, root, image_file_names): - field_names = ("5_o_Clock_Shadow", "Young") - data = [ - [name, *[" 1" if attr else "-1" for attr in make_tensor((len(field_names),), dtype=torch.bool)]] - for name in image_file_names - ] - cls._make_ann_file(root, "list_attr_celeba.txt", data, field_names=(*field_names, "")) - - @classmethod - def _make_bounding_boxes_file(cls, root, image_file_names): - field_names = ("image_id", "x_1", "y_1", "width", "height") - data = [ - [f"{name} ", *[f"{coord:3d}" for coord in make_tensor((4,), low=0, dtype=torch.int).tolist()]] - for name in image_file_names - ] - cls._make_ann_file(root, "list_bbox_celeba.txt", data, field_names=field_names) - - @classmethod - def _make_landmarks_file(cls, root, image_file_names): - field_names = ("lefteye_x", "lefteye_y", "rightmouth_x", "rightmouth_y") - data = [ - [ - name, - *[ - f"{coord:4d}" if idx else coord - for idx, coord in enumerate(make_tensor((len(field_names),), low=0, dtype=torch.int).tolist()) - ], - ] - for name in image_file_names - ] - cls._make_ann_file(root, "list_landmarks_align_celeba.txt", data, field_names=field_names) - - @classmethod - def generate(cls, root): - image_file_names, num_samples_map = cls._make_split_file(root) - - image_files = create_image_folder( - root, "img_align_celeba", file_name_fn=lambda idx: image_file_names[idx], num_examples=len(image_file_names) - ) - make_zip(root, image_files[0].parent.with_suffix(".zip").name) - - for make_ann_file_fn in ( - cls._make_identity_file, - cls._make_attributes_file, - cls._make_bounding_boxes_file, - cls._make_landmarks_file, - ): - make_ann_file_fn(root, image_file_names) - - return num_samples_map - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def celeba(root, config): - return CelebAMockData.generate(root)[config["split"]] - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def country211(root, config): - split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"]) - split_folder.mkdir(parents=True, exist_ok=True) - - num_examples = { - "train": 3, - "val": 4, - "test": 5, - }[config["split"]] - - classes = ("AD", "BS", "GR") - for cls in classes: - create_image_folder( - split_folder, - name=cls, - file_name_fn=lambda idx: f"{idx}.jpg", - num_examples=num_examples, - ) - make_tar(root, f"{split_folder.parent.name}.tgz", split_folder.parent, compression="gz") - return num_examples * len(classes) - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def food101(root, config): - data_folder = root / "food-101" - - num_images_per_class = 3 - image_folder = data_folder / "images" - categories = ["apple_pie", "baby_back_ribs", "waffles"] - image_ids = [] - for category in categories: - image_files = create_image_folder( - image_folder, - category, - file_name_fn=lambda idx: f"{idx:04d}.jpg", - num_examples=num_images_per_class, - ) - image_ids.extend(path.relative_to(path.parents[1]).with_suffix("").as_posix() for path in image_files) - - meta_folder = data_folder / "meta" - meta_folder.mkdir() - - with open(meta_folder / "classes.txt", "w") as file: - for category in categories: - file.write(f"{category}\n") - - splits = ["train", "test"] - num_samples_map = {} - for offset, split in 
enumerate(splits): - image_ids_in_split = image_ids[offset :: len(splits)] - num_samples_map[split] = len(image_ids_in_split) - with open(meta_folder / f"{split}.txt", "w") as file: - for image_id in image_ids_in_split: - file.write(f"{image_id}\n") - - make_tar(root, f"{data_folder.name}.tar.gz", compression="gz") - - return num_samples_map[config["split"]] - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) -def dtd(root, config): - data_folder = root / "dtd" - - num_images_per_class = 3 - image_folder = data_folder / "images" - categories = {"banded", "marbled", "zigzagged"} - image_ids_per_category = { - category: [ - str(path.relative_to(path.parents[1]).as_posix()) - for path in create_image_folder( - image_folder, - category, - file_name_fn=lambda idx: f"{category}_{idx:04d}.jpg", - num_examples=num_images_per_class, - ) - ] - for category in categories - } - - meta_folder = data_folder / "labels" - meta_folder.mkdir() - - with open(meta_folder / "labels_joint_anno.txt", "w") as file: - for cls, image_ids in image_ids_per_category.items(): - for image_id in image_ids: - joint_categories = random.choices( - list(categories - {cls}), k=int(torch.randint(len(categories) - 1, ())) - ) - file.write(" ".join([image_id, *sorted([cls, *joint_categories])]) + "\n") - - image_ids = list(itertools.chain(*image_ids_per_category.values())) - splits = ("train", "val", "test") - num_samples_map = {} - for fold in range(1, 11): - random.shuffle(image_ids) - for offset, split in enumerate(splits): - image_ids_in_config = image_ids[offset :: len(splits)] - with open(meta_folder / f"{split}{fold}.txt", "w") as file: - file.write("\n".join(image_ids_in_config) + "\n") - - num_samples_map[(split, fold)] = len(image_ids_in_config) - - make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz") - - return num_samples_map[config["split"], config["fold"]] - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def fer2013(root, config): - split = config["split"] - num_samples = 5 if split == "train" else 3 - - path = root / f"{split}.csv" - with open(path, "w", newline="") as file: - field_names = ["emotion"] if split == "train" else [] - field_names.append("pixels") - - file.write(",".join(field_names) + "\n") - - writer = csv.DictWriter(file, fieldnames=field_names, quotechar='"', quoting=csv.QUOTE_NONNUMERIC) - for _ in range(num_samples): - rowdict = { - "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)]) - } - if split == "train": - rowdict["emotion"] = int(torch.randint(7, ())) - writer.writerow(rowdict) - - make_zip(root, f"{path.name}.zip", path) - - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def gtsrb(root, config): - num_examples_per_class = 5 if config["split"] == "train" else 3 - classes = ("00000", "00042", "00012") - num_examples = num_examples_per_class * len(classes) - - csv_columns = ["Filename", "Width", "Height", "Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2", "ClassId"] - - def _make_ann_file(path, num_examples, class_idx): - if class_idx == "random": - class_idx = torch.randint(1, len(classes) + 1, size=(1,)).item() - - with open(path, "w") as csv_file: - writer = csv.DictWriter(csv_file, fieldnames=csv_columns, delimiter=";") - writer.writeheader() - for image_idx in range(num_examples): - writer.writerow( - { - "Filename": f"{image_idx:05d}.ppm", - "Width": torch.randint(1, 100, size=()).item(), - "Height": torch.randint(1, 100, 
size=()).item(), - "Roi.X1": torch.randint(1, 100, size=()).item(), - "Roi.Y1": torch.randint(1, 100, size=()).item(), - "Roi.X2": torch.randint(1, 100, size=()).item(), - "Roi.Y2": torch.randint(1, 100, size=()).item(), - "ClassId": class_idx, - } - ) - - archive_folder = root / "GTSRB" - - if config["split"] == "train": - train_folder = archive_folder / "Training" - train_folder.mkdir(parents=True) - - for class_idx in classes: - create_image_folder( - train_folder, - name=class_idx, - file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm", - num_examples=num_examples_per_class, - ) - _make_ann_file( - path=train_folder / class_idx / f"GT-{class_idx}.csv", - num_examples=num_examples_per_class, - class_idx=int(class_idx), - ) - make_zip(root, "GTSRB-Training_fixed.zip", archive_folder) - else: - test_folder = archive_folder / "Final_Test" - test_folder.mkdir(parents=True) - - create_image_folder( - test_folder, - name="Images", - file_name_fn=lambda image_idx: f"{image_idx:05d}.ppm", - num_examples=num_examples, - ) - - make_zip(root, "GTSRB_Final_Test_Images.zip", archive_folder) - - _make_ann_file( - path=root / "GT-final_test.csv", - num_examples=num_examples, - class_idx="random", - ) - - make_zip(root, "GTSRB_Final_Test_GT.zip", "GT-final_test.csv") - - return num_examples - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def clevr(root, config): - data_folder = root / "CLEVR_v1.0" - - num_samples_map = { - "train": 3, - "val": 2, - "test": 1, - } - - images_folder = data_folder / "images" - image_files = { - split: create_image_folder( - images_folder, - split, - file_name_fn=lambda idx: f"CLEVR_{split}_{idx:06d}.jpg", - num_examples=num_samples, - ) - for split, num_samples in num_samples_map.items() - } - - scenes_folder = data_folder / "scenes" - scenes_folder.mkdir() - for split in ["train", "val"]: - with open(scenes_folder / f"CLEVR_{split}_scenes.json", "w") as file: - json.dump( - { - "scenes": [ - { - "image_filename": image_file.name, - # We currently only return the number of objects in a scene. - # Thus, it is sufficient for now to only mock the number of elements. 
- "objects": [None] * int(torch.randint(1, 5, ())), - } - for image_file in image_files[split] - ] - }, - file, - ) - - make_zip(root, f"{data_folder.name}.zip", data_folder) - - return num_samples_map[config["split"]] - - -class OxfordIIITPetMockData: - @classmethod - def _meta_to_split_and_classification_ann(cls, meta, idx): - image_id = "_".join( - [ - *[(str.title if meta["species"] == "cat" else str.lower)(part) for part in meta["cls"].split()], - str(idx), - ] - ) - class_id = str(meta["label"] + 1) - species = "1" if meta["species"] == "cat" else "2" - breed_id = "-1" - return (image_id, class_id, species, breed_id) - - @classmethod - def generate(self, root): - classification_anns_meta = ( - dict(cls="Abyssinian", label=0, species="cat"), - dict(cls="Keeshond", label=18, species="dog"), - dict(cls="Yorkshire Terrier", label=36, species="dog"), - ) - split_and_classification_anns = [ - self._meta_to_split_and_classification_ann(meta, idx) - for meta, idx in itertools.product(classification_anns_meta, (1, 2, 10)) - ] - image_ids, *_ = zip(*split_and_classification_anns) - - image_files = create_image_folder( - root, "images", file_name_fn=lambda idx: f"{image_ids[idx]}.jpg", num_examples=len(image_ids) - ) - - anns_folder = root / "annotations" - anns_folder.mkdir() - random.shuffle(split_and_classification_anns) - splits = ("trainval", "test") - num_samples_map = {} - for offset, split in enumerate(splits): - split_and_classification_anns_in_split = split_and_classification_anns[offset :: len(splits)] - with open(anns_folder / f"{split}.txt", "w") as file: - writer = csv.writer(file, delimiter=" ") - for split_and_classification_ann in split_and_classification_anns_in_split: - writer.writerow(split_and_classification_ann) - - num_samples_map[split] = len(split_and_classification_anns_in_split) - - segmentation_files = create_image_folder( - anns_folder, "trimaps", file_name_fn=lambda idx: f"{image_ids[idx]}.png", num_examples=len(image_ids) - ) - - # The dataset has some rogue files - for path in image_files[:3]: - path.with_suffix(".mat").touch() - for path in segmentation_files: - path.with_name(f".{path.name}").touch() - - make_tar(root, "images.tar.gz", compression="gz") - make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz") - - return num_samples_map - - -@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test"))) -def oxford_iiit_pet(root, config): - return OxfordIIITPetMockData.generate(root)[config["split"]] - - -class _CUB200MockData: - @classmethod - def _category_folder(cls, category, idx): - return f"{idx:03d}.{category}" - - @classmethod - def _file_stem(cls, category, idx): - return f"{category}_{idx:04d}" - - @classmethod - def _make_images(cls, images_folder): - image_files = [] - for category_idx, category in [ - (1, "Black_footed_Albatross"), - (100, "Brown_Pelican"), - (200, "Common_Yellowthroat"), - ]: - image_files.extend( - create_image_folder( - images_folder, - cls._category_folder(category, category_idx), - lambda image_idx: f"{cls._file_stem(category, image_idx)}.jpg", - num_examples=5, - ) - ) - - return image_files - - -class CUB2002011MockData(_CUB200MockData): - @classmethod - def _make_archive(cls, root): - archive_folder = root / "CUB_200_2011" - - images_folder = archive_folder / "images" - image_files = cls._make_images(images_folder) - image_ids = list(range(1, len(image_files) + 1)) - - with open(archive_folder / "images.txt", "w") as file: - file.write( - "\n".join( - f"{id} 
{path.relative_to(images_folder).as_posix()}" for id, path in zip(image_ids, image_files) - ) - ) - - split_ids = torch.randint(2, (len(image_ids),)).tolist() - counts = Counter(split_ids) - num_samples_map = {"train": counts[1], "test": counts[0]} - with open(archive_folder / "train_test_split.txt", "w") as file: - file.write("\n".join(f"{image_id} {split_id}" for image_id, split_id in zip(image_ids, split_ids))) - - with open(archive_folder / "bounding_boxes.txt", "w") as file: - file.write( - "\n".join( - " ".join( - str(item) - for item in [image_id, *make_tensor((4,), dtype=torch.int, low=0).to(torch.float).tolist()] - ) - for image_id in image_ids - ) - ) - - make_tar(root, archive_folder.with_suffix(".tgz").name, compression="gz") - - return image_files, num_samples_map - - @classmethod - def _make_segmentations(cls, root, image_files): - segmentations_folder = root / "segmentations" - for image_file in image_files: - folder = segmentations_folder.joinpath(image_file.relative_to(image_file.parents[1])) - folder.mkdir(exist_ok=True, parents=True) - create_image_file( - folder, - image_file.with_suffix(".png").name, - size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()], - ) - - make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz") - - @classmethod - def generate(cls, root): - image_files, num_samples_map = cls._make_archive(root) - cls._make_segmentations(root, image_files) - return num_samples_map - - -class CUB2002010MockData(_CUB200MockData): - @classmethod - def _make_hidden_rouge_file(cls, *files): - for file in files: - (file.parent / f"._{file.name}").touch() - - @classmethod - def _make_splits(cls, root, image_files): - split_folder = root / "lists" - split_folder.mkdir() - random.shuffle(image_files) - splits = ("train", "test") - num_samples_map = {} - for offset, split in enumerate(splits): - image_files_in_split = image_files[offset :: len(splits)] - - split_file = split_folder / f"{split}.txt" - with open(split_file, "w") as file: - file.write( - "\n".join( - sorted( - str(image_file.relative_to(image_file.parents[1]).as_posix()) - for image_file in image_files_in_split - ) - ) - ) - - cls._make_hidden_rouge_file(split_file) - num_samples_map[split] = len(image_files_in_split) - - make_tar(root, split_folder.with_suffix(".tgz").name, compression="gz") - - return num_samples_map - - @classmethod - def _make_anns(cls, root, image_files): - from scipy.io import savemat - - anns_folder = root / "annotations-mat" - for image_file in image_files: - ann_file = anns_folder / image_file.with_suffix(".mat").relative_to(image_file.parents[1]) - ann_file.parent.mkdir(parents=True, exist_ok=True) - - savemat( - ann_file, - { - "seg": torch.randint( - 256, make_tensor((2,), low=3, dtype=torch.int).tolist(), dtype=torch.uint8 - ).numpy(), - "bbox": dict( - zip(("left", "top", "right", "bottom"), make_tensor((4,), dtype=torch.uint8).tolist()) - ), - }, - ) - - readme_file = anns_folder / "README.txt" - readme_file.touch() - cls._make_hidden_rouge_file(readme_file) - - make_tar(root, "annotations.tgz", anns_folder, compression="gz") - - @classmethod - def generate(cls, root): - images_folder = root / "images" - image_files = cls._make_images(images_folder) - cls._make_hidden_rouge_file(*image_files) - make_tar(root, images_folder.with_suffix(".tgz").name, compression="gz") - - num_samples_map = cls._make_splits(root, image_files) - cls._make_anns(root, image_files) - - return num_samples_map - - -@register_mock(configs=combinations_grid(split=("train", 
"test"), year=("2010", "2011"))) -def cub200(root, config): - num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root) - return num_samples_map[config["split"]] - - -@register_mock(configs=[dict()]) -def eurosat(root, config): - data_folder = root / "2750" - data_folder.mkdir(parents=True) - - num_examples_per_class = 3 - categories = ["AnnualCrop", "Forest"] - for category in categories: - create_image_folder( - root=data_folder, - name=category, - file_name_fn=lambda idx: f"{category}_{idx + 1}.jpg", - num_examples=num_examples_per_class, - ) - make_zip(root, "EuroSAT.zip", data_folder) - return len(categories) * num_examples_per_class - - -@register_mock(configs=combinations_grid(split=("train", "test", "extra"))) -def svhn(root, config): - import scipy.io as sio - - num_samples = { - "train": 2, - "test": 3, - "extra": 4, - }[config["split"]] - - sio.savemat( - root / f"{config['split']}_32x32.mat", - { - "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8), - "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8), - }, - ) - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def pcam(root, config): - import h5py - - num_images = {"train": 2, "test": 3, "val": 4}[config["split"]] - - split = "valid" if config["split"] == "val" else config["split"] - - images_io = io.BytesIO() - with h5py.File(images_io, "w") as f: - f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8) - - targets_io = io.BytesIO() - with h5py.File(targets_io, "w") as f: - f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8) - - # Create .gz compressed files - images_file = root / f"camelyonpatch_level_2_split_{split}_x.h5.gz" - targets_file = root / f"camelyonpatch_level_2_split_{split}_y.h5.gz" - for compressed_file_name, uncompressed_file_io in ((images_file, images_io), (targets_file, targets_io)): - compressed_data = gzip.compress(uncompressed_file_io.getbuffer()) - with open(compressed_file_name, "wb") as compressed_file: - compressed_file.write(compressed_data) - - return num_images - - -@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test"))) -def stanford_cars(root, config): - import scipy.io as io - from numpy.core.records import fromarrays - - split = config["split"] - num_samples = {"train": 5, "test": 7}[split] - num_categories = 3 - - if split == "train": - images_folder_name = "cars_train" - devkit = root / "devkit" - devkit.mkdir() - annotations_mat_path = devkit / "cars_train_annos.mat" - else: - images_folder_name = "cars_test" - annotations_mat_path = root / "cars_test_annos_withlabels.mat" - - create_image_folder( - root=root, - name=images_folder_name, - file_name_fn=lambda image_index: f"{image_index:5d}.jpg", - num_examples=num_samples, - ) - - make_tar(root, f"cars_{split}.tgz", images_folder_name) - bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8) - classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8) - fnames = [f"{i:5d}.jpg" for i in range(num_samples)] - rec_array = fromarrays( - [bbox, bbox, bbox, bbox, classes, fnames], - names=["bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "class", "fname"], - ) - - io.savemat(annotations_mat_path, {"annotations": rec_array}) - if split == "train": - make_tar(root, "car_devkit.tgz", devkit, compression="gz") - - return num_samples - - -@register_mock(configs=combinations_grid(split=("train", "test"))) -def 
usps(root, config): - num_samples = {"train": 15, "test": 7}[config["split"]] - - with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh: - lines = [] - for _ in range(num_samples): - label = make_tensor(1, low=1, high=11, dtype=torch.int) - values = make_tensor(256, low=-1, high=1, dtype=torch.float) - lines.append( - " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))]) - ) - - fh.write("\n".join(lines).encode()) - - return num_samples diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py deleted file mode 100644 index 8259246c0cb..00000000000 --- a/test/prototype_common_utils.py +++ /dev/null @@ -1,82 +0,0 @@ -import collections.abc -import dataclasses -from typing import Optional, Sequence - -import pytest -import torch - -from common_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader -from torch.nn.functional import one_hot - -from torchvision.prototype import datapoints - - -@dataclasses.dataclass -class LabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def _parse_categories(categories): - if categories is None: - num_categories = int(torch.randint(1, 11, ())) - elif isinstance(categories, int): - num_categories = categories - categories = [f"category{idx}" for idx in range(num_categories)] - elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories): - categories = list(categories) - num_categories = len(categories) - else: - raise pytest.UsageError( - f"`categories` can either be `None` (default), an integer, or a sequence of strings, " - f"but got '{categories}' instead." - ) - return categories, num_categories - - -def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, - # regardless of the requested dtype, e.g. 
0 or 0.0 rather than 0 or 0.123 - data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) - return datapoints.Label(data, categories=categories) - - return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) - - -make_label = from_loader(make_label_loader) - - -@dataclasses.dataclass -class OneHotLabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - if num_categories == 0: - data = torch.empty(shape, dtype=dtype, device=device) - else: - # The idiom `make_label_loader(..., dtype=torch.int64); ...; one_hot(...).to(dtype)` is intentional - # since `one_hot` only supports int64 - label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) - data = one_hot(label, num_classes=num_categories).to(dtype) - return datapoints.OneHotLabel(data, categories=categories) - - return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) - - -def make_one_hot_label_loaders( - *, - categories=(1, 0, None), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.int64, torch.float32), -): - for params in combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes): - yield make_one_hot_label_loader(**params) - - -make_one_hot_labels = from_loaders(make_one_hot_label_loaders) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py deleted file mode 100644 index 04e3cd67f96..00000000000 --- a/test/test_prototype_datapoints.py +++ /dev/null @@ -1,133 +0,0 @@ -import pytest -import torch - -from torchvision.prototype import datapoints as proto_datapoints - - -@pytest.mark.parametrize( - ("data", "input_requires_grad", "expected_requires_grad"), - [ - ([0.0], None, False), - ([0.0], False, False), - ([0.0], True, True), - (torch.tensor([0.0], requires_grad=False), None, False), - (torch.tensor([0.0], requires_grad=False), False, False), - (torch.tensor([0.0], requires_grad=False), True, True), - (torch.tensor([0.0], requires_grad=True), None, True), - (torch.tensor([0.0], requires_grad=True), False, False), - (torch.tensor([0.0], requires_grad=True), True, True), - ], -) -def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): - datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad) - assert datapoint.requires_grad is expected_requires_grad - - -def test_isinstance(): - assert isinstance( - proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]), - torch.Tensor, - ) - - -def test_wrapping_no_copy(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - assert label.data_ptr() == tensor.data_ptr() - - -def test_to_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - label_to = label.to(torch.int32) - - assert type(label_to) is proto_datapoints.Label - assert label_to.dtype is torch.int32 - assert label_to.categories is label.categories - - -def test_to_datapoint_reference(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32) - - tensor_to = tensor.to(label) - - assert type(tensor_to) is torch.Tensor - assert tensor_to.dtype is 
torch.int32 - - -def test_clone_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - label_clone = label.clone() - - assert type(label_clone) is proto_datapoints.Label - assert label_clone.data_ptr() != label.data_ptr() - assert label_clone.categories is label.categories - - -def test_requires_grad__wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.float32) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - assert not label.requires_grad - - label_requires_grad = label.requires_grad_(True) - - assert type(label_requires_grad) is proto_datapoints.Label - assert label.requires_grad - assert label_requires_grad.requires_grad - - -def test_other_op_no_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - # any operation besides .to() and .clone() will do here - output = label * 2 - - assert type(output) is torch.Tensor - - -@pytest.mark.parametrize( - "op", - [ - lambda t: t.numpy(), - lambda t: t.tolist(), - lambda t: t.max(dim=-1), - ], -) -def test_no_tensor_output_op_no_wrapping(op): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - output = op(label) - - assert type(output) is not proto_datapoints.Label - - -def test_inplace_op_no_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - output = label.add_(0) - - assert type(output) is torch.Tensor - assert type(label) is proto_datapoints.Label - - -def test_wrap_like(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = proto_datapoints.Label(tensor, categories=["foo", "bar"]) - - # any operation besides .to() and .clone() will do here - output = label * 2 - - label_new = proto_datapoints.Label.wrap_like(label, output) - - assert type(label_new) is proto_datapoints.Label - assert label_new.data_ptr() == output.data_ptr() - assert label_new.categories is label.categories diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py deleted file mode 100644 index 4848e799f04..00000000000 --- a/test/test_prototype_datasets_builtin.py +++ /dev/null @@ -1,282 +0,0 @@ -import io -import pickle -from collections import deque -from pathlib import Path - -import pytest -import torch -import torchvision.transforms.v2 as transforms - -from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks -from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair - -# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish -from torch.utils.data import DataLoader - -# TODO: replace with torchdata equivalent as soon as it is available -from torch.utils.data.graph_settings import get_all_graph_pipes - -from torchdata.dataloader2.graph.utils import traverse_dps -from torchdata.datapipes.iter import ShardingFilter, Shuffler -from torchdata.datapipes.utils import StreamWrapper -from torchvision import datapoints -from torchvision._utils import sequence_to_str -from torchvision.prototype import datasets -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import EncodedImage -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE -from torchvision.transforms.v2.utils import is_simple_tensor - - -def assert_samples_equal(*args, msg=None, 
**kwargs): - error_metas = not_close_error_metas( - *args, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True, **kwargs - ) - if error_metas: - raise error_metas[0].to_error(msg) - - -def extract_datapipes(dp): - return get_all_graph_pipes(traverse_dps(dp)) - - -def consume(iterator): - # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes - deque(iterator, maxlen=0) - - -def next_consume(iterator): - item = next(iterator) - consume(iterator) - return item - - -@pytest.fixture(autouse=True) -def test_home(mocker, tmp_path): - mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path)) - mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path)) - yield tmp_path - - -def test_coverage(): - untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys() - if untested_datasets: - raise AssertionError( - f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} " - f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. " - f"Please add mock data to `test/builtin_dataset_mocks.py`." - ) - - -@pytest.mark.filterwarnings("error") -class TestCommon: - @pytest.mark.parametrize("name", datasets.list_datasets()) - def test_info(self, name): - try: - info = datasets.info(name) - except ValueError: - raise AssertionError("No info available.") from None - - if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())): - raise AssertionError("Info should be a dictionary with string keys.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_smoke(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - if not isinstance(dataset, datasets.utils.Dataset): - raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_sample(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - try: - sample = next_consume(iter(dataset)) - except StopIteration: - raise AssertionError("Unable to draw any sample.") from None - except Exception as error: - raise AssertionError("Drawing a sample raised the error above.") from error - - if not isinstance(sample, dict): - raise AssertionError(f"Samples should be dictionaries, but got {type(sample)} instead.") - - if not sample: - raise AssertionError("Sample dictionary is empty.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_num_samples(self, dataset_mock, config): - dataset, mock_info = dataset_mock.load(config) - - assert len(list(dataset)) == mock_info["num_samples"] - - @pytest.fixture - def log_session_streams(self): - debug_unclosed_streams = StreamWrapper.debug_unclosed_streams - try: - StreamWrapper.debug_unclosed_streams = True - yield - finally: - StreamWrapper.debug_unclosed_streams = debug_unclosed_streams - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_stream_closing(self, log_session_streams, dataset_mock, config): - def make_msg_and_close(head): - unclosed_streams = [] - for stream in StreamWrapper.session_streams.keys(): - unclosed_streams.append(repr(stream.file_obj)) - stream.close() - unclosed_streams = "\n".join(unclosed_streams) - return f"{head}\n\n{unclosed_streams}" - - if StreamWrapper.session_streams: - raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:")) - - dataset, _ = dataset_mock.load(config) - - consume(iter(dataset)) - - if 
StreamWrapper.session_streams: - raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:")) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_no_unaccompanied_simple_tensors(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - sample = next_consume(iter(dataset)) - - simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)} - - if simple_tensors and not any( - isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values() - ): - raise AssertionError( - f"The values of key(s) " - f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, " - f"but didn't find any (encoded) image or video." - ) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_transformable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - dataset = dataset.map(transforms.Identity()) - - consume(iter(dataset)) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_traversable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - traverse_dps(dataset) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_serializable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - pickle.dumps(dataset) - - # This has to be a proper function, since lambda's or local functions - # cannot be pickled, but this is a requirement for the DataLoader with - # multiprocessing, i.e. num_workers > 0 - def _collate_fn(self, batch): - return batch - - @pytest.mark.parametrize("num_workers", [0, 1]) - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_data_loader(self, dataset_mock, config, num_workers): - dataset, _ = dataset_mock.load(config) - - dl = DataLoader( - dataset, - batch_size=2, - num_workers=num_workers, - collate_fn=self._collate_fn, - ) - - consume(dl) - - # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also - # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 - # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. 
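For context, the ordering requirement described in the TODO comment above can be sketched roughly as below. This is an illustrative sketch only, not the custom test referenced from pytorch/vision#5680: it reuses the `extract_datapipes` helper defined earlier in this deleted test module and relies on the fact that traversing a datapipe graph only visits the datapipes feeding into it, so a `Shuffler` found upstream of a `ShardingFilter` necessarily comes before it.

```py
def assert_shuffler_before_sharding_filter(dataset):
    # Illustrative sketch; assumes the `extract_datapipes`, `Shuffler` and `ShardingFilter`
    # names from the deleted test_prototype_datasets_builtin.py are in scope.
    sharding_filters = [dp for dp in extract_datapipes(dataset) if isinstance(dp, ShardingFilter)]
    assert sharding_filters, "The dataset doesn't contain a ShardingFilter() datapipe."

    for sharding_filter in sharding_filters:
        # Traversing from the ShardingFilter only reaches datapipes that feed into it,
        # so a Shuffler in that subgraph is guaranteed to be placed before it.
        upstream = extract_datapipes(sharding_filter)
        assert any(
            isinstance(dp, Shuffler) for dp in upstream
        ), "The Shuffler() datapipe does not come before the ShardingFilter() datapipe."
```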
- @parametrize_dataset_mocks(DATASET_MOCKS) - @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) - def test_has_annotations(self, dataset_mock, config, annotation_dp_type): - dataset, _ = dataset_mock.load(config) - - if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): - raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_save_load(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - sample = next_consume(iter(dataset)) - - with io.BytesIO() as buffer: - torch.save(sample, buffer) - buffer.seek(0) - assert_samples_equal(torch.load(buffer), sample) - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_infinite_buffer_size(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - for dp in extract_datapipes(dataset): - if hasattr(dp, "buffer_size"): - # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is - # resolved - assert dp.buffer_size == INFINITE_BUFFER_SIZE - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_has_length(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - assert len(dataset) > 0 - - -@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) -class TestQMNIST: - def test_extra_label(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - sample = next_consume(iter(dataset)) - for key, type in ( - ("nist_hsf_series", int), - ("nist_writer_id", int), - ("digit_index", int), - ("nist_label", int), - ("global_digit_index", int), - ("duplicate", bool), - ("unused", bool), - ): - assert key in sample and isinstance(sample[key], type) - - -@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) -class TestGTSRB: - def test_label_matches_path(self, dataset_mock, config): - # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
- # This test makes sure that they're both the same - if config["split"] != "train": - return - - dataset, _ = dataset_mock.load(config) - - for sample in dataset: - label_from_path = int(Path(sample["path"]).parent.name) - assert sample["label"] == label_from_path - - -@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) -class TestUSPS: - def test_sample_content(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - - for sample in dataset: - assert "image" in sample - assert "label" in sample - - assert isinstance(sample["image"], datapoints.Image) - assert isinstance(sample["label"], Label) - - assert sample["image"].shape == (1, 16, 16) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py deleted file mode 100644 index 2098ac736ac..00000000000 --- a/test/test_prototype_datasets_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -import gzip -import pathlib -import sys - -import numpy as np -import pytest -import torch -from datasets_utils import make_fake_flo_file, make_tar -from torchdata.datapipes.iter import FileOpener, TarArchiveLoader -from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.datasets.utils import _decompress -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import fromfile, read_flo - - -@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning") -@pytest.mark.parametrize( - ("np_dtype", "torch_dtype", "byte_order"), - [ - (">f4", torch.float32, "big"), - ("<f8", torch.float64, "little"), - ("<i4", torch.int32, "little"), - (">i8", torch.int64, "big"), - ("|u1", torch.uint8, sys.byteorder), - ], -) -@pytest.mark.parametrize("count", (-1, 2)) -@pytest.mark.parametrize("mode", ("rb", "r+b")) -def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode): - path = tmpdir / "data.bin" - rng = np.random.RandomState(0) - rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path) - - for count_ in (-1, count // 2): - expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:])) - - with open(path, mode) as file: - actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_) - - torch.testing.assert_close(actual, expected) - - -def test_read_flo(tmpdir): - path = tmpdir / "test.flo" - make_fake_flo_file(3, 4, path) - - with open(path, "rb") as file: - actual = read_flo(file) - - expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False)) - - torch.testing.assert_close(actual, expected) - - -class TestOnlineResource: - class DummyResource(OnlineResource): - def __init__(self, download_fn=None, **kwargs): - super().__init__(**kwargs) - self._download_fn = download_fn - - def _download(self, root): - if self._download_fn is None: - raise pytest.UsageError( - "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`." 
- ) - - return self._download_fn(self, root) - - def _make_file(self, root, *, content, name="file.txt"): - file = root / name - with open(file, "w") as fh: - fh.write(content) - - return file - - def _make_folder(self, root, *, name="folder"): - folder = root / name - subfolder = folder / "subfolder" - subfolder.mkdir(parents=True) - - files = {} - for idx, root in enumerate([folder, folder, subfolder]): - content = f"sentinel{idx}" - file = self._make_file(root, name=f"file{idx}.txt", content=content) - files[str(file)] = content - - return folder, files - - def _make_tar(self, root, *, name="archive.tar", remove=True): - folder, files = self._make_folder(root, name=name.split(".")[0]) - archive = make_tar(root, name, folder, remove=remove) - files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()} - return archive, files - - def test_load_file(self, tmp_path): - content = "sentinel" - file = self._make_file(tmp_path, content=content) - - resource = self.DummyResource(file_name=file.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, FileOpener) - - data = list(dp) - assert len(data) == 1 - - path, buffer = data[0] - assert path == str(file) - assert buffer.read().decode() == content - - def test_load_folder(self, tmp_path): - folder, files = self._make_folder(tmp_path) - - resource = self.DummyResource(file_name=folder.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, FileOpener) - assert {path: buffer.read().decode() for path, buffer in dp} == files - - def test_load_archive(self, tmp_path): - archive, files = self._make_tar(tmp_path) - - resource = self.DummyResource(file_name=archive.name) - - dp = resource.load(tmp_path) - assert isinstance(dp, TarArchiveLoader) - assert {path: buffer.read().decode() for path, buffer in dp} == files - - def test_priority_decompressed_gt_raw(self, tmp_path): - # We don't need to actually compress here. 
Adding the suffix is sufficient - self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz") - file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt") - - resource = self.DummyResource(file_name=file.name) - - dp = resource.load(tmp_path) - path, buffer = next(iter(dp)) - - assert path == str(file) - assert buffer.read().decode() == "decompressed_sentinel" - - def test_priority_extracted_gt_decompressed(self, tmp_path): - archive, _ = self._make_tar(tmp_path, remove=False) - - resource = self.DummyResource(file_name=archive.name) - - dp = resource.load(tmp_path) - # If the archive had been selected, this would be a `TarArchiveReader` - assert isinstance(dp, FileOpener) - - def test_download(self, tmp_path): - download_fn_was_called = False - - def download_fn(resource, root): - nonlocal download_fn_was_called - download_fn_was_called = True - - return self._make_file(root, content="_", name=resource.file_name) - - resource = self.DummyResource( - file_name="file.txt", - download_fn=download_fn, - ) - - resource.load(tmp_path) - - assert download_fn_was_called, "`download_fn()` was never called" - - # This tests the `"decompress"` literal as well as a custom callable - @pytest.mark.parametrize( - "preprocess", - [ - "decompress", - lambda path: _decompress(str(path), remove_finished=True), - ], - ) - def test_preprocess_decompress(self, tmp_path, preprocess): - file_name = "file.txt.gz" - content = "sentinel" - - def download_fn(resource, root): - file = root / resource.file_name - with gzip.open(file, "wb") as fh: - fh.write(content.encode()) - return file - - resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn) - - dp = resource.load(tmp_path) - data = list(dp) - assert len(data) == 1 - - path, buffer = data[0] - assert path == str(tmp_path / file_name).replace(".gz", "") - assert buffer.read().decode() == content - - def test_preprocess_extract(self, tmp_path): - files = None - - def download_fn(resource, root): - nonlocal files - archive, files = self._make_tar(root, name=resource.file_name) - return archive - - resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn) - - dp = resource.load(tmp_path) - assert files is not None, "`download_fn()` was never called" - assert isinstance(dp, FileOpener) - - actual = {path: buffer.read().decode() for path, buffer in dp} - expected = { - path.replace(resource.file_name, resource.file_name.split(".")[0]): content - for path, content in files.items() - } - assert actual == expected - - def test_preprocess_only_after_download(self, tmp_path): - file = self._make_file(tmp_path, content="_") - - def preprocess(path): - raise AssertionError("`preprocess` was called although the file was already present.") - - resource = self.DummyResource( - file_name=file.name, - preprocess=preprocess, - ) - - resource.load(tmp_path) - - -class TestHttpResource: - def test_resolve_to_http(self, mocker): - file_name = "data.tar" - original_url = f"http://downloads.pytorch.org/{file_name}" - - redirected_url = original_url.replace("http", "https") - - sha256_sentinel = "sha256_sentinel" - - def preprocess_sentinel(path): - return path - - original_resource = HttpResource( - original_url, - sha256=sha256_sentinel, - preprocess=preprocess_sentinel, - ) - - mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) - redirected_resource = original_resource.resolve() - - assert 
isinstance(redirected_resource, HttpResource) - assert redirected_resource.url == redirected_url - assert redirected_resource.file_name == file_name - assert redirected_resource.sha256 == sha256_sentinel - assert redirected_resource._preprocess is preprocess_sentinel - - def test_resolve_to_gdrive(self, mocker): - file_name = "data.tar" - original_url = f"http://downloads.pytorch.org/{file_name}" - - id_sentinel = "id-sentinel" - redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view" - - sha256_sentinel = "sha256_sentinel" - - def preprocess_sentinel(path): - return path - - original_resource = HttpResource( - original_url, - sha256=sha256_sentinel, - preprocess=preprocess_sentinel, - ) - - mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) - redirected_resource = original_resource.resolve() - - assert isinstance(redirected_resource, GDriveResource) - assert redirected_resource.id == id_sentinel - assert redirected_resource.file_name == file_name - assert redirected_resource.sha256 == sha256_sentinel - assert redirected_resource._preprocess is preprocess_sentinel - - -def test_missing_dependency_error(): - class DummyDataset(Dataset): - def __init__(self): - super().__init__(root="root", dependencies=("fake_dependency",)) - - def _resources(self): - pass - - def _datapipe(self, resource_dps): - pass - - def __len__(self): - pass - - with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): - DummyDataset() diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py deleted file mode 100644 index 6d9f22c1543..00000000000 --- a/test/test_prototype_models.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest -import test_models as TM -import torch -from common_utils import cpu_and_gpu, set_rng_seed -from torchvision.prototype import models - - -@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,)) -@pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_raft_stereo(model_fn, model_mode, dev): - # A simple test to make sure the model can do forward pass and jit scriptable - set_rng_seed(0) - - # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output - # get the idea from test_models.test_raft - corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2) - corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2) - model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev) - - if model_mode == "scripted": - model = torch.jit.script(model) - - img1 = torch.rand(1, 3, 64, 64).to(dev) - img2 = torch.rand(1, 3, 64, 64).to(dev) - num_iters = 3 - - preds = model(img1, img2, num_iters=num_iters) - depth_pred = preds[-1] - - assert len(preds) == num_iters, "Number of predictions should be the same as model.num_iters" - - assert depth_pred.shape == torch.Size( - [1, 1, 64, 64] - ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}" - - # Test against expected file output - TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,)) -@pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_crestereo(model_fn, model_mode, dev): - set_rng_seed(0) - - model = model_fn().eval().to(dev) - - if 
model_mode == "scripted": - model = torch.jit.script(model) - - img1 = torch.rand(1, 3, 64, 64).to(dev) - img2 = torch.rand(1, 3, 64, 64).to(dev) - iterations = 3 - - preds = model(img1, img2, flow_init=None, num_iters=iterations) - disparity_pred = preds[-1] - - # all the pyramid levels except the highest res make only half the number of iterations - expected_iterations = (iterations // 2) * (len(model.resolutions) - 1) - expected_iterations += iterations - assert ( - len(preds) == expected_iterations - ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels" - - assert disparity_pred.shape == torch.Size( - [1, 2, 64, 64] - ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}" - - assert all( - d.shape == torch.Size([1, 2, 64, 64]) for d in preds - ), "All predicted disparities are expected to have the same shape" - - # test a backward pass with a dummy loss as well - preds = torch.stack(preds, dim=0) - targets = torch.ones_like(preds, requires_grad=False) - loss = torch.nn.functional.mse_loss(preds, targets) - - try: - loss.backward() - except Exception as e: - assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}" - - TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py deleted file mode 100644 index 255c3b5c32f..00000000000 --- a/test/test_prototype_transforms.py +++ /dev/null @@ -1,535 +0,0 @@ -import itertools - -import re - -import PIL.Image -import pytest -import torch - -from common_utils import ( - assert_equal, - DEFAULT_EXTRA_DIMS, - make_bounding_box, - make_detection_mask, - make_image, - make_images, - make_segmentation_mask, - make_video, - make_videos, -) - -from prototype_common_utils import make_label, make_one_hot_labels - -from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image, Mask, Video -from torchvision.prototype import datapoints, transforms -from torchvision.transforms.v2._utils import _convert_fill_arg -from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil -from torchvision.transforms.v2.utils import check_type, is_simple_tensor - -BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] - - -def parametrize(transforms_with_inputs): - return pytest.mark.parametrize( - ("transform", "input"), - [ - pytest.param( - transform, - input, - id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", - ) - for transform, inputs in transforms_with_inputs - for idx, input in enumerate(inputs) - ], - ) - - -@parametrize( - [ - ( - transform, - [ - dict(inpt=inpt, one_hot_label=one_hot_label) - for inpt, one_hot_label in itertools.product( - itertools.chain( - make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - ), - make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - ) - ], - ) - for transform in [ - transforms.RandomMixup(alpha=1.0), - transforms.RandomCutmix(alpha=1.0), - ] - ] -) -def test_mixup_cutmix(transform, input): - transform(input) - - input_copy = dict(input) - input_copy["path"] = "/path/to/somewhere" - input_copy["num"] = 1234 - transform(input_copy) - - # Check if we raise an error if sample contains bbox or mask or label - err_msg = "does not support PIL 
images, bounding boxes, masks and plain labels" - input_copy = dict(input) - for unsup_data in [ - make_label(), - make_bounding_box(format="XYXY"), - make_detection_mask(), - make_segmentation_mask(), - ]: - input_copy["unsupported"] = unsup_data - with pytest.raises(TypeError, match=err_msg): - transform(input_copy) - - -class TestSimpleCopyPaste: - def create_fake_image(self, mocker, image_type): - if image_type == PIL.Image.Image: - return PIL.Image.new("RGB", (32, 32), 123) - return mocker.MagicMock(spec=image_type) - - def test__extract_image_targets_assertion(self, mocker): - transform = transforms.SimpleCopyPaste() - - flat_sample = [ - # images, batch size = 2 - self.create_fake_image(mocker, Image), - # labels, bboxes, masks - mocker.MagicMock(spec=datapoints.Label), - mocker.MagicMock(spec=BoundingBox), - mocker.MagicMock(spec=Mask), - # labels, bboxes, masks - mocker.MagicMock(spec=BoundingBox), - mocker.MagicMock(spec=Mask), - ] - - with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"): - transform._extract_image_targets(flat_sample) - - @pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor]) - @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) - def test__extract_image_targets(self, image_type, label_type, mocker): - transform = transforms.SimpleCopyPaste() - - flat_sample = [ - # images, batch size = 2 - self.create_fake_image(mocker, image_type), - self.create_fake_image(mocker, image_type), - # labels, bboxes, masks - mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=BoundingBox), - mocker.MagicMock(spec=Mask), - # labels, bboxes, masks - mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=BoundingBox), - mocker.MagicMock(spec=Mask), - ] - - images, targets = transform._extract_image_targets(flat_sample) - - assert len(images) == len(targets) == 2 - if image_type == PIL.Image.Image: - torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0])) - torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1])) - else: - assert images[0] == flat_sample[0] - assert images[1] == flat_sample[1] - - for target in targets: - for key, type_ in [ - ("boxes", BoundingBox), - ("masks", Mask), - ("labels", label_type), - ]: - assert key in target - assert isinstance(target[key], type_) - assert target[key] in flat_sample - - @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) - def test__copy_paste(self, label_type): - image = 2 * torch.ones(3, 32, 32) - masks = torch.zeros(2, 32, 32) - masks[0, 3:9, 2:8] = 1 - masks[1, 20:30, 20:30] = 1 - labels = torch.tensor([1, 2]) - blending = True - resize_interpolation = InterpolationMode.BILINEAR - antialias = None - if label_type == datapoints.OneHotLabel: - labels = torch.nn.functional.one_hot(labels, num_classes=5) - target = { - "boxes": BoundingBox( - torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32) - ), - "masks": Mask(masks), - "labels": label_type(labels), - } - - paste_image = 10 * torch.ones(3, 32, 32) - paste_masks = torch.zeros(2, 32, 32) - paste_masks[0, 13:19, 12:18] = 1 - paste_masks[1, 15:19, 1:8] = 1 - paste_labels = torch.tensor([3, 4]) - if label_type == datapoints.OneHotLabel: - paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) - paste_target = { - "boxes": BoundingBox( - torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32) - ), - "masks": 
Mask(paste_masks), - "labels": label_type(paste_labels), - } - - transform = transforms.SimpleCopyPaste() - random_selection = torch.tensor([0, 1]) - output_image, output_target = transform._copy_paste( - image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias - ) - - assert output_image.unique().tolist() == [2, 10] - assert output_target["boxes"].shape == (4, 4) - torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"]) - torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) - - expected_labels = torch.tensor([1, 2, 3, 4]) - if label_type == datapoints.OneHotLabel: - expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5) - torch.testing.assert_close(output_target["labels"], label_type(expected_labels)) - - assert output_target["masks"].shape == (4, 32, 32) - torch.testing.assert_close(output_target["masks"][:2, :], target["masks"]) - torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"]) - - -class TestFixedSizeCrop: - def test__get_params(self, mocker): - crop_size = (7, 7) - batch_shape = (10,) - spatial_size = (11, 5) - - transform = transforms.FixedSizeCrop(size=crop_size) - - flat_inputs = [ - make_image(size=spatial_size, color_space="RGB"), - make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape), - ] - params = transform._get_params(flat_inputs) - - assert params["needs_crop"] - assert params["height"] <= crop_size[0] - assert params["width"] <= crop_size[1] - - assert ( - isinstance(params["is_valid"], torch.Tensor) - and params["is_valid"].dtype is torch.bool - and params["is_valid"].shape == batch_shape - ) - - assert params["needs_pad"] - assert any(pad > 0 for pad in params["padding"]) - - @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) - def test__transform(self, mocker, needs): - fill_sentinel = 12 - padding_mode_sentinel = mocker.MagicMock() - - transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) - transform._transformed_types = (mocker.MagicMock,) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - needs_crop, needs_pad = needs - top_sentinel = mocker.MagicMock() - left_sentinel = mocker.MagicMock() - height_sentinel = mocker.MagicMock() - width_sentinel = mocker.MagicMock() - is_valid = mocker.MagicMock() if needs_crop else None - padding_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=needs_crop, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - is_valid=is_valid, - padding=padding_sentinel, - needs_pad=needs_pad, - ), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop") - mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad") - transform(inpt_sentinel) - - if needs_crop: - mock_crop.assert_called_once_with( - inpt_sentinel, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - ) - else: - mock_crop.assert_not_called() - - if needs_pad: - # If we cropped before, the input to F.pad is no longer inpt_sentinel. 
Thus, we can't use - # `MagicMock.assert_called_once_with` and have to perform the checks manually - mock_pad.assert_called_once() - args, kwargs = mock_pad.call_args - if not needs_crop: - assert args[0] is inpt_sentinel - assert args[1] is padding_sentinel - fill_sentinel = _convert_fill_arg(fill_sentinel) - assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) - else: - mock_pad.assert_not_called() - - def test__transform_culling(self, mocker): - batch_size = 10 - spatial_size = (10, 10) - - is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=True, - top=0, - left=0, - height=spatial_size[0], - width=spatial_size[1], - is_valid=is_valid, - needs_pad=False, - ), - ) - - bounding_boxes = make_bounding_box( - format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) - ) - masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,)) - labels = make_label(extra_dims=(batch_size,)) - - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - output = transform( - dict( - bounding_boxes=bounding_boxes, - masks=masks, - labels=labels, - ) - ) - - assert_equal(output["bounding_boxes"], bounding_boxes[is_valid]) - assert_equal(output["masks"], masks[is_valid]) - assert_equal(output["labels"], labels[is_valid]) - - def test__transform_bounding_box_clamping(self, mocker): - batch_size = 3 - spatial_size = (10, 10) - - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=True, - top=0, - left=0, - height=spatial_size[0], - width=spatial_size[1], - is_valid=torch.full((batch_size,), fill_value=True), - needs_pad=False, - ), - ) - - bounding_box = make_bounding_box( - format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) - ) - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") - - transform = transforms.FixedSizeCrop((-1, -1)) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - transform(bounding_box) - - mock.assert_called_once() - - -class TestLabelToOneHot: - def test__transform(self): - categories = ["apple", "pear", "pineapple"] - labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories) - transform = transforms.LabelToOneHot() - ohe_labels = transform(labels) - assert isinstance(ohe_labels, datapoints.OneHotLabel) - assert ohe_labels.shape == (4, 3) - assert ohe_labels.categories == labels.categories == categories - - -class TestPermuteDimensions: - @pytest.mark.parametrize( - ("dims", "inverse_dims"), - [ - ( - {Image: (2, 1, 0), Video: None}, - {Image: (2, 1, 0), Video: None}, - ), - ( - {Image: (2, 1, 0), Video: (1, 2, 3, 0)}, - {Image: (2, 1, 0), Video: (3, 0, 1, 2)}, - ), - ], - ) - def test_call(self, dims, inverse_dims): - sample = dict( - image=make_image(), - bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, - ) - - transform = transforms.PermuteDimensions(dims) - transformed_sample = transform(sample) - - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] - - if check_type(value, (Image, is_simple_tensor, Video)): - if transform.dims.get(value_type) is not None: - assert 
transformed_value.permute(inverse_dims[value_type]).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value - - @pytest.mark.filterwarnings("error") - def test_plain_tensor_call(self): - tensor = torch.empty((2, 3, 4)) - transform = transforms.PermuteDimensions(dims=(1, 2, 0)) - - assert transform(tensor).shape == (3, 4, 2) - - @pytest.mark.parametrize("other_type", [Image, Video]) - def test_plain_tensor_warning(self, other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) - - -class TestTransposeDimensions: - @pytest.mark.parametrize( - "dims", - [ - (-1, -2), - {Image: (1, 2), Video: None}, - ], - ) - def test_call(self, dims): - sample = dict( - image=make_image(), - bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, - ) - - transform = transforms.TransposeDimensions(dims) - transformed_sample = transform(sample) - - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] - - transposed_dims = transform.dims.get(value_type) - if check_type(value, (Image, is_simple_tensor, Video)): - if transposed_dims is not None: - assert transformed_value.transpose(*transposed_dims).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value - - @pytest.mark.filterwarnings("error") - def test_plain_tensor_call(self): - tensor = torch.empty((2, 3, 4)) - transform = transforms.TransposeDimensions(dims=(0, 2)) - - assert transform(tensor).shape == (4, 3, 2) - - @pytest.mark.parametrize("other_type", [Image, Video]) - def test_plain_tensor_warning(self, other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) - - -import importlib.machinery -import importlib.util -from pathlib import Path - - -def import_transforms_from_references(reference): - HERE = Path(__file__).parent - PROJECT_ROOT = HERE.parent - - loader = importlib.machinery.SourceFileLoader( - "transforms", str(PROJECT_ROOT / "references" / reference / "transforms.py") - ) - spec = importlib.util.spec_from_loader("transforms", loader) - module = importlib.util.module_from_spec(spec) - loader.exec_module(module) - return module - - -det_transforms = import_transforms_from_references("detection") - - -def test_fixed_sized_crop_against_detection_reference(): - def make_datapoints(): - size = (600, 800) - num_objects = 22 - - pil_image = to_image_pil(make_image(size=size, color_space="RGB")) - target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), - } - - yield (pil_image, target) - - tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) - target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), - } - - yield (tensor_image, target) - - datapoint_image = make_image(size=size, color_space="RGB") - target = { - "boxes": 
make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), - "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), - } - - yield (datapoint_image, target) - - t = transforms.FixedSizeCrop((1024, 1024), fill=0) - t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0) - - for dp in make_datapoints(): - # We should use prototype transform first as reference transform performs inplace target update - torch.manual_seed(12) - output = t(dp) - - torch.manual_seed(12) - expected_output = t_ref(*dp) - - assert_equal(expected_output, output) diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py deleted file mode 100644 index 200f5cd9552..00000000000 --- a/torchvision/prototype/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import datapoints, models, transforms, utils diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py deleted file mode 100644 index 604628b2540..00000000000 --- a/torchvision/prototype/datapoints/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ._label import Label, OneHotLabel diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py deleted file mode 100644 index 7ed2f7522b0..00000000000 --- a/torchvision/prototype/datapoints/_label.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -from typing import Any, Optional, Sequence, Type, TypeVar, Union - -import torch -from torch.utils._pytree import tree_map - -from torchvision.datapoints._datapoint import Datapoint - - -L = TypeVar("L", bound="_LabelBase") - - -class _LabelBase(Datapoint): - categories: Optional[Sequence[str]] - - @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: - label_base = tensor.as_subclass(cls) - label_base.categories = categories - return label_base - - def __new__( - cls: Type[L], - data: Any, - *, - categories: Optional[Sequence[str]] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str, int]] = None, - requires_grad: Optional[bool] = None, - ) -> L: - tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return cls._wrap(tensor, categories=categories) - - @classmethod - def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L: - return cls._wrap( - tensor, - categories=categories if categories is not None else other.categories, - ) - - @classmethod - def from_category( - cls: Type[L], - category: str, - *, - categories: Sequence[str], - **kwargs: Any, - ) -> L: - return cls(categories.index(category), categories=categories, **kwargs) - - -class Label(_LabelBase): - def to_categories(self) -> Any: - if self.categories is None: - raise RuntimeError("Label does not have categories") - - return tree_map(lambda idx: self.categories[idx], self.tolist()) - - -class OneHotLabel(_LabelBase): - def __new__( - cls, - data: Any, - *, - categories: Optional[Sequence[str]] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, - ) -> OneHotLabel: - one_hot_label = super().__new__( - cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad - ) - - if categories is not None and len(categories) != one_hot_label.shape[-1]: - raise ValueError() - - return 
one_hot_label diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py deleted file mode 100644 index 848d9135c2f..00000000000 --- a/torchvision/prototype/datasets/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -try: - import torchdata -except ModuleNotFoundError: - raise ModuleNotFoundError( - "`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). " - "You can install it with `pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu" - ) from None - -from . import utils -from ._home import home - -# Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip -from ._folder import from_data_folder, from_image_folder -from ._builtin import * diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py deleted file mode 100644 index f6f06c60a21..00000000000 --- a/torchvision/prototype/datasets/_api.py +++ /dev/null @@ -1,65 +0,0 @@ -import pathlib -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union - -from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset -from torchvision.prototype.utils._internal import add_suggestion - - -T = TypeVar("T") -D = TypeVar("D", bound=Type[Dataset]) - -BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} - - -def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: - def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: - BUILTIN_INFOS[name] = fn() - return fn - - return wrapper - - -BUILTIN_DATASETS = {} - - -def register_dataset(name: str) -> Callable[[D], D]: - def wrapper(dataset_cls: D) -> D: - BUILTIN_DATASETS[name] = dataset_cls - return dataset_cls - - return wrapper - - -def list_datasets() -> List[str]: - return sorted(BUILTIN_DATASETS.keys()) - - -def find(dct: Dict[str, T], name: str) -> T: - name = name.lower() - try: - return dct[name] - except KeyError as error: - raise ValueError( - add_suggestion( - f"Unknown dataset '{name}'.", - word=name, - possibilities=dct.keys(), - alternative_hint=lambda _: ( - "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." - ), - ) - ) from error - - -def info(name: str) -> Dict[str, Any]: - return find(BUILTIN_INFOS, name) - - -def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset: - dataset_cls = find(BUILTIN_DATASETS, name) - - if root is None: - root = pathlib.Path(home()) / name - - return dataset_cls(root, **config) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md deleted file mode 100644 index 05d61c6870e..00000000000 --- a/torchvision/prototype/datasets/_builtin/README.md +++ /dev/null @@ -1,340 +0,0 @@ -# How to add new built-in prototype datasets - -As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means -that this document will also change a lot. - -If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented -there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out. - -Finally, `from torchvision.prototype import datasets` is implied below. 
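To illustrate how the registration machinery from `_api.py` above ties into the public entry points, here is a small usage sketch; `"my-dataset"` and the `split` option are the placeholder names this README uses, not a real registered dataset.

```py
from torchvision.prototype import datasets

# every class decorated with @register_dataset is listed here
print(datasets.list_datasets())

# the static dictionary returned by the function decorated with @register_info("my-dataset")
info = datasets.info("my-dataset")

# instantiates the registered Dataset class; `root` defaults to `home() / "my-dataset"`
dataset = datasets.load("my-dataset", split="train")
```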
- -## Implementation - -Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` -that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that -module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in -detail below: - -```python -import pathlib -from typing import Any, BinaryIO, Dict, List, Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe -from torchvision.prototype.datasets.utils import Dataset, OnlineResource - -from .._api import register_dataset, register_info - -NAME = "my-dataset" - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - ... - ) - -@register_dataset(NAME) -class MyDataset(Dataset): - def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None: - ... - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - ... - - def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]: - ... - - def __len__(self) -> int: - ... -``` - -In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a -dictionary of static information. The most common use case is to provide human-readable categories. -[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. - -Finally, both the dataset class and the info function need to be registered on the API with the respective decorators. -With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively. - -### `__init__(self, root, *, ..., skip_integrity_check = False)` - -Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the -base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as -setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke -the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with -an underscore. - -If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base -class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically -checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to -avoid missing dependencies at import time. - -### `_resources(self)` - -Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be -build. The download will happen automatically. - -Currently, the following `OnlineResource`'s are supported: - -- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL. -- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`. -- `ManualDownloadResource`: Used files are not publicly accessible and requires instructions how to download them - manually. If the file does not exist, an error will be raised with the supplied instructions. -- `KaggleDownloadResource`: Used for files that are available on Kaggle. 
This inherits from `ManualDownloadResource`. - -Although optional in general, all resources used in the built-in datasets should comprise -[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the -download. You can compute the checksum with system utilities e.g `sha256-sum`, or this snippet: - -```python -import hashlib - -def sha256sum(path, chunk_size=1024 * 1024): - checksum = hashlib.sha256() - with open(path, "rb") as f: - for chunk in iter(lambda: f.read(chunk_size), b""): - checksum.update(chunk) - print(checksum.hexdigest()) -``` - -### `_datapipe(self, resource_dps)` - -This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared -to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone -that is working with them rather than on them, `IterDataPipe`'s behave just as generators, i.e. you can't do anything -with them besides iterating. - -Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are: - -- `Mapper`: Apply a callable to every item in the datapipe. -- `Filter`: Keep only items that satisfy a condition. -- `Demultiplexer`: Split a datapipe into multiple ones. -- `IterKeyZipper`: Merge two datapipes into one. - -All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable -needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated -to add one. See the MNIST or CelebA datasets for example. - -`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return -value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain -tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one -of such tuples for the file specified by the resource. - -Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and -`Grouper`. There are two issues with that: - -1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely. -2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime. - -Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than -trying to zip already loaded images. - -There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and -`hint_sharding`. As the name implies they only hint at a location in the datapipe graph where shuffling and sharding -should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` -and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`. - -Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the -names (yet!). - -### `__len__` - -This returns an integer denoting the number of samples that can be drawn from the dataset. Please use -[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the -readability. For example, `1_281_167` vs. `1281167`. 
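Before moving on to the `__len__` examples below, here is a condensed sketch of how the `_datapipe()` building blocks described above typically combine for a dataset backed by a single archive. `Filter` and `Mapper` come from `torchdata.datapipes.iter`, the hints from `torchvision.prototype.datasets.utils._internal`, and `_is_relevant_file` / `_prepare_sample` are placeholder helpers, not part of the base class:

```py
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
    dp = resource_dps[0]                     # (path, file handle) tuples from the archive
    dp = Filter(dp, self._is_relevant_file)  # keep only the files we actually need
    dp = hint_shuffling(dp)                  # shuffling hint, a no-op by default
    dp = hint_sharding(dp)                   # sharding hint, always placed after shuffling
    return Mapper(dp, self._prepare_sample)  # turn each item into a sample dictionary
```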
- -If there are only two different numbers, a simple `if` / `else` is fine: - -```py -def __len__(self): - return 12_345 if self._split == "train" else 6_789 -``` - -If there are more options, using a dictionary usually is the most readable option: - -```py -def __len__(self): - return { - "train": 3, - "val": 2, - "test": 1, - }[self._split] -``` - -If the number of samples depends on more than one parameter, you can use tuples as dictionary keys: - -```py -def __len__(self): - return { - ("train", "bar"): 4, - ("train", "baz"): 3, - ("test", "bar"): 2, - ("test", "baz"): 1, - }[(self._split, self._foo)] -``` - -The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the -development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way -is to define a dummy method like - -```py -def __len__(self): - return 1 -``` - -and only fill it with the correct data if the implementation is otherwise finished. -[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples. - -## Tests - -To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data. -This mock-up should resemble the original data as close as necessary, while containing only few examples. - -To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the -same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function". -Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset -will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options, -you can use the `combinations_grid()` helper function, e.g. -`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`. - -In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass -the `name` parameter to `@register_mock` - -```py -# this is defined in torchvision/prototype/datasets/_builtin -@register_dataset("my-dataset") -class MyDataset(Dataset): - ... - -@register_mock(name="my-dataset", configs=...) -def my_dataset(root, config): - ... -``` - -The mock data function receives two arguments: - -- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data - needs to be placed. -- `config`: The configuration to generate the data for. This is one of the dictionaries defined in - `@register_mock(configs=...)` - -The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if -the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of -the current `config`. Although this seems odd at first, this is important. Consider the following original data setup: - -``` -root -├── test -│ ├── test_image0.jpg -│ ... -└── train - ├── train_image0.jpg - ... -``` - -For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to -load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is -present. 
With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in -`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data for -the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data. - -For datasets that are ported from the old API, we already have some mock data in -[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the test case corresponding test case there -and have a look at the `inject_fake_data` function. There are a few differences though: - -- `tmp_dir` corresponds to `root`, but is a `str` rather than a - [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like - `folder = pathlib.Path(tmp_dir)`. This is not needed. -- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the - new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files - specified in the dataset. -- As explained in the paragraph above, the generated data is often "incomplete" and only valid for given the config. - Make sure you follow the instructions above. - -The function should return an integer indicating the number of samples in the dataset for the current `config`. -Preferably, this number should be different for different `config`'s to have more confidence in the dataset -implementation. - -Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py -k {name}`. - -## FAQ - -### How do I start? - -Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do -`return resources_dp[0]` to get started. Then import the dataset class in -`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be -instantiable via `datasets.load("mydataset")`. On a separate script, try something like - -```py -from torchvision.prototype import datasets - -dataset = datasets.load("mydataset") -for sample in dataset: - print(sample) # this is the content of an item in datapipe returned by _datapipe() - break -# Or you can also inspect the sample in a debugger -``` - -This will give you an idea of what the first datapipe in `resources_dp` contains. You can also do that with -`resources_dp[1]` or `resources_dp[2]` (etc.) if they exist. Then follow the instructions above to manipulate these -datapipes and return the appropriate dictionary format. - -### How do I handle a dataset that defines many categories? - -As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more -categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a -category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file` -function and pass it `$NAME`. - -In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where -each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method -should return a sequence of strings representing the category names. In the method body, you'll have to manually load -the resources, e.g. 
- -```py -resources = self._resources() -dp = resources[0].load(self._root) -``` - -Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes -sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that. - -To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`. - -### What if a resource file forms an I/O bottleneck? - -In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if -the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the -`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be -preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also -accepts `"decompress"` and `"extract"` to handle these common scenarios. - -### How do I compute the number of samples? - -Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way -than to iterate over the dataset and count the number of samples: - -```py -import itertools -from torchvision.prototype import datasets - - -def combinations_grid(**kwargs): - return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] - - -# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there -configs = combinations_grid(split=("train", "test"), foo=("bar", "baz")) - -for config in configs: - dataset = datasets.load("my-dataset", **config) - - num_samples = 0 - for _ in dataset: - num_samples += 1 - - print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples) -``` - -To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation -files. 
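As a footnote to the I/O bottleneck question above, a hedged sketch of the two `preprocess` variants. The URL and checksum are placeholders, and this assumes `HttpResource` forwards `preprocess` to the base `OnlineResource` the same way the `GDriveResource` instances in the built-in datasets do:

```py
import pathlib

from torchvision.prototype.datasets.utils import HttpResource

# Built-in shortcut: decompress the archive once, right after the download.
archive = HttpResource(
    "https://example.com/data.tar.gz",  # placeholder URL
    sha256="<sha256 of the archive>",   # placeholder checksum
    preprocess="decompress",
)

# Custom preprocessing: a Callable[[pathlib.Path], pathlib.Path] that receives the
# downloaded file and returns the path that should be loaded afterwards.
def repack(path: pathlib.Path) -> pathlib.Path:
    repacked = path.with_suffix(".repacked.tar")
    ...  # e.g. extract only the needed members or re-compress with a cheaper codec
    return repacked

archive = HttpResource(
    "https://example.com/data.tar.gz",
    sha256="<sha256 of the archive>",
    preprocess=repack,
)
```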
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py deleted file mode 100644 index d84e9af9fc4..00000000000 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .caltech import Caltech101, Caltech256 -from .celeba import CelebA -from .cifar import Cifar10, Cifar100 -from .clevr import CLEVR -from .coco import Coco -from .country211 import Country211 -from .cub200 import CUB200 -from .dtd import DTD -from .eurosat import EuroSAT -from .fer2013 import FER2013 -from .food101 import Food101 -from .gtsrb import GTSRB -from .imagenet import ImageNet -from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST -from .oxford_iiit_pet import OxfordIIITPet -from .pcam import PCAM -from .sbd import SBD -from .semeion import SEMEION -from .stanford_cars import StanfordCars -from .svhn import SVHN -from .usps import USPS -from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py deleted file mode 100644 index f3882361638..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ /dev/null @@ -1,212 +0,0 @@ -import pathlib -import re -from typing import Any, BinaryIO, Dict, List, Tuple, Union - -import numpy as np - -import torch -from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.datapoints import BoundingBox -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - read_categories_file, - read_mat, -) - -from .._api import register_dataset, register_info - - -@register_info("caltech101") -def _caltech101_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("caltech101")) - - -@register_dataset("caltech101") -class Caltech101(Dataset): - """ - - **homepage**: https://data.caltech.edu/records/20086 - - **dependencies**: - - <scipy `https://scipy.org/`>_ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - skip_integrity_check: bool = False, - ) -> None: - self._categories = _caltech101_info()["categories"] - - super().__init__( - root, - dependencies=("scipy",), - skip_integrity_check=skip_integrity_check, - ) - - def _resources(self) -> List[OnlineResource]: - images = GDriveResource( - "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp", - file_name="101_ObjectCategories.tar.gz", - sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", - preprocess="decompress", - ) - anns = GDriveResource( - "175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m", - file_name="Annotations.tar", - sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8", - ) - return [images, anns] - - _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg") - _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat") - _ANNS_CATEGORY_MAP = { - "Faces_2": "Faces", - "Faces_3": "Faces_easy", - "Motorbikes_16": "Motorbikes", - "Airplanes_Side_2": "airplanes", - } - - def _is_not_background_image(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.parent.name != "BACKGROUND_Google" - - def _is_ann(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return bool(self._ANNS_NAME_PATTERN.match(path.name)) - - def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: - 
path = pathlib.Path(data[0]) - - category = path.parent.name - id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] - - return category, id - - def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: - path = pathlib.Path(data[0]) - - category = path.parent.name - if category in self._ANNS_CATEGORY_MAP: - category = self._ANNS_CATEGORY_MAP[category] - - id = self._ANNS_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] - - return category, id - - def _prepare_sample( - self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]] - ) -> Dict[str, Any]: - key, (image_data, ann_data) = data - category, _ = key - image_path, image_buffer = image_data - ann_path, ann_buffer = ann_data - - image = EncodedImage.from_file(image_buffer) - ann = read_mat(ann_buffer) - - return dict( - label=Label.from_category(category, categories=self._categories), - image_path=image_path, - image=image, - ann_path=ann_path, - bounding_box=BoundingBox( - ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], - format="xyxy", - spatial_size=image.spatial_size, - ), - contour=torch.as_tensor(ann["obj_contour"].T), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, anns_dp = resource_dps - - images_dp = Filter(images_dp, self._is_not_background_image) - images_dp = hint_shuffling(images_dp) - images_dp = hint_sharding(images_dp) - - anns_dp = Filter(anns_dp, self._is_ann) - - dp = IterKeyZipper( - images_dp, - anns_dp, - key_fn=self._images_key_fn, - ref_key_fn=self._anns_key_fn, - buffer_size=INFINITE_BUFFER_SIZE, - keep_key=True, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 8677 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, self._is_not_background_image) - - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) - - -@register_info("caltech256") -def _caltech256_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("caltech256")) - - -@register_dataset("caltech256") -class Caltech256(Dataset): - """ - - **homepage**: https://data.caltech.edu/records/20087 - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - skip_integrity_check: bool = False, - ) -> None: - self._categories = _caltech256_info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - GDriveResource( - "1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", - file_name="256_ObjectCategories.tar", - sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e", - ) - ] - - def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.name != "RENAME2" - - def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: - path, buffer = data - - return dict( - path=path, - image=EncodedImage.from_file(buffer), - label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Filter(dp, self._is_not_rogue_file) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 30607 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = 
resources[0].load(self._root) - dir_names = {pathlib.Path(path).parent.name for path, _ in dp} - - return [name.split(".")[1] for name in sorted(dir_names)] diff --git a/torchvision/prototype/datasets/_builtin/caltech101.categories b/torchvision/prototype/datasets/_builtin/caltech101.categories deleted file mode 100644 index d5c18654b4e..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech101.categories +++ /dev/null @@ -1,101 +0,0 @@ -Faces -Faces_easy -Leopards -Motorbikes -accordion -airplanes -anchor -ant -barrel -bass -beaver -binocular -bonsai -brain -brontosaurus -buddha -butterfly -camera -cannon -car_side -ceiling_fan -cellphone -chair -chandelier -cougar_body -cougar_face -crab -crayfish -crocodile -crocodile_head -cup -dalmatian -dollar_bill -dolphin -dragonfly -electric_guitar -elephant -emu -euphonium -ewer -ferry -flamingo -flamingo_head -garfield -gerenuk -gramophone -grand_piano -hawksbill -headphone -hedgehog -helicopter -ibis -inline_skate -joshua_tree -kangaroo -ketch -lamp -laptop -llama -lobster -lotus -mandolin -mayfly -menorah -metronome -minaret -nautilus -octopus -okapi -pagoda -panda -pigeon -pizza -platypus -pyramid -revolver -rhino -rooster -saxophone -schooner -scissors -scorpion -sea_horse -snoopy -soccer_ball -stapler -starfish -stegosaurus -stop_sign -strawberry -sunflower -tick -trilobite -umbrella -watch -water_lilly -wheelchair -wild_cat -windsor_chair -wrench -yin_yang diff --git a/torchvision/prototype/datasets/_builtin/caltech256.categories b/torchvision/prototype/datasets/_builtin/caltech256.categories deleted file mode 100644 index 82128efba97..00000000000 --- a/torchvision/prototype/datasets/_builtin/caltech256.categories +++ /dev/null @@ -1,257 +0,0 @@ -ak47 -american-flag -backpack -baseball-bat -baseball-glove -basketball-hoop -bat -bathtub -bear -beer-mug -billiards -binoculars -birdbath -blimp -bonsai-101 -boom-box -bowling-ball -bowling-pin -boxing-glove -brain-101 -breadmaker -buddha-101 -bulldozer -butterfly -cactus -cake -calculator -camel -cannon -canoe -car-tire -cartman -cd -centipede -cereal-box -chandelier-101 -chess-board -chimp -chopsticks -cockroach -coffee-mug -coffin -coin -comet -computer-keyboard -computer-monitor -computer-mouse -conch -cormorant -covered-wagon -cowboy-hat -crab-101 -desk-globe -diamond-ring -dice -dog -dolphin-101 -doorknob -drinking-straw -duck -dumb-bell -eiffel-tower -electric-guitar-101 -elephant-101 -elk -ewer-101 -eyeglasses -fern -fighter-jet -fire-extinguisher -fire-hydrant -fire-truck -fireworks -flashlight -floppy-disk -football-helmet -french-horn -fried-egg -frisbee -frog -frying-pan -galaxy -gas-pump -giraffe -goat -golden-gate-bridge -goldfish -golf-ball -goose -gorilla -grand-piano-101 -grapes -grasshopper -guitar-pick -hamburger -hammock -harmonica -harp -harpsichord -hawksbill-101 -head-phones -helicopter-101 -hibiscus -homer-simpson -horse -horseshoe-crab -hot-air-balloon -hot-dog -hot-tub -hourglass -house-fly -human-skeleton -hummingbird -ibis-101 -ice-cream-cone -iguana -ipod -iris -jesus-christ -joy-stick -kangaroo-101 -kayak -ketch-101 -killer-whale -knife -ladder -laptop-101 -lathe -leopards-101 -license-plate -lightbulb -light-house -lightning -llama-101 -mailbox -mandolin -mars -mattress -megaphone -menorah-101 -microscope -microwave -minaret -minotaur -motorbikes-101 -mountain-bike -mushroom -mussels -necktie -octopus -ostrich -owl -palm-pilot -palm-tree -paperclip -paper-shredder -pci-card -penguin -people -pez-dispenser -photocopier -picnic-table -playing-card 
-porcupine -pram -praying-mantis -pyramid -raccoon -radio-telescope -rainbow -refrigerator -revolver-101 -rifle -rotary-phone -roulette-wheel -saddle -saturn -school-bus -scorpion-101 -screwdriver -segway -self-propelled-lawn-mower -sextant -sheet-music -skateboard -skunk -skyscraper -smokestack -snail -snake -sneaker -snowmobile -soccer-ball -socks -soda-can -spaghetti -speed-boat -spider -spoon -stained-glass -starfish-101 -steering-wheel -stirrups -sunflower-101 -superman -sushi -swan -swiss-army-knife -sword -syringe -tambourine -teapot -teddy-bear -teepee -telephone-box -tennis-ball -tennis-court -tennis-racket -theodolite -toaster -tomato -tombstone -top-hat -touring-bike -tower-pisa -traffic-light -treadmill -triceratops -tricycle -trilobite-101 -tripod -t-shirt -tuning-fork -tweezer -umbrella-101 -unicorn -vcr -video-projector -washing-machine -watch-101 -waterfall -watermelon -welding-mask -wheelbarrow -windmill -wine-bottle -xylophone -yarmulke -yo-yo -zebra -airplanes-101 -car-side-101 -faces-easy-101 -greyhound -tennis-shoes -toad -clutter diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py deleted file mode 100644 index 2c819468778..00000000000 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ /dev/null @@ -1,200 +0,0 @@ -import csv -import pathlib -from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import torch -from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper -from torchvision.datapoints import BoundingBox -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, -) - -from .._api import register_dataset, register_info - -csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) - - -class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): - def __init__( - self, - datapipe: IterDataPipe[Tuple[Any, BinaryIO]], - *, - fieldnames: Optional[Sequence[str]] = None, - ) -> None: - self.datapipe = datapipe - self.fieldnames = fieldnames - - def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: - for _, file in self.datapipe: - try: - lines = (line.decode() for line in file) - - if self.fieldnames: - fieldnames = self.fieldnames - else: - # The first row is skipped, because it only contains the number of samples - next(lines) - - # Empty field names are filtered out, because some files have an extra white space after the header - # line, which is recognized as extra column - fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name] - # Some files do not include a label for the image ID column - if fieldnames[0] != "image_id": - fieldnames.insert(0, "image_id") - - for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"): - yield line.pop("image_id"), line - finally: - file.close() - - -NAME = "celeba" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict() - - -@register_dataset(NAME) -class CelebA(Dataset): - """ - - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", 
"test")) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - splits = GDriveResource( - "0B7EVK8r0v71pY0NSMzRuSXJEVkk", - sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", - file_name="list_eval_partition.txt", - ) - images = GDriveResource( - "0B7EVK8r0v71pZjFTYXZWM3FlRnM", - sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", - file_name="img_align_celeba.zip", - ) - identities = GDriveResource( - "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", - sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0", - file_name="identity_CelebA.txt", - ) - attributes = GDriveResource( - "0B7EVK8r0v71pblRyaVFSWGxPY0U", - sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", - file_name="list_attr_celeba.txt", - ) - bounding_boxes = GDriveResource( - "0B7EVK8r0v71pbThiMVRxWXZ4dU0", - sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", - file_name="list_bbox_celeba.txt", - ) - landmarks = GDriveResource( - "0B7EVK8r0v71pd0FJY3Blby1HUTQ", - sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", - file_name="list_landmarks_align_celeba.txt", - ) - return [splits, images, identities, attributes, bounding_boxes, landmarks] - - def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool: - split_id = { - "train": "0", - "val": "1", - "test": "2", - }[self._split] - return data[1]["split_id"] == split_id - - def _prepare_sample( - self, - data: Tuple[ - Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]], - Tuple[ - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - Tuple[str, Dict[str, str]], - ], - ], - ) -> Dict[str, Any]: - split_and_image_data, ann_data = data - _, (_, image_data) = split_and_image_data - path, buffer = image_data - - image = EncodedImage.from_file(buffer) - (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data - - return dict( - path=path, - image=image, - identity=Label(int(identity["identity"])), - attributes={attr: value == "1" for attr, value in attributes.items()}, - bounding_box=BoundingBox( - [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], - format="xywh", - spatial_size=image.spatial_size, - ), - landmarks={ - landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) - for landmark in {key[:-2] for key in landmarks.keys()} - }, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps - - splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, self._filter_split) - splits_dp = hint_shuffling(splits_dp) - splits_dp = hint_sharding(splits_dp) - - anns_dp = Zipper( - *[ - CelebACSVParser(dp, fieldnames=fieldnames) - for dp, fieldnames in ( - (identities_dp, ("image_id", "identity")), - (attributes_dp, None), - (bounding_boxes_dp, None), - (landmarks_dp, None), - ) - ] - ) - - dp = IterKeyZipper( - splits_dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - keep_key=True, - ) - dp = IterKeyZipper( - dp, - anns_dp, - key_fn=getitem(0), - ref_key_fn=getitem(0, 0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 162_770, - "val": 19_867, - 
"test": 19_962, - }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py deleted file mode 100644 index 7d178291992..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ /dev/null @@ -1,142 +0,0 @@ -import abc -import io -import pathlib -import pickle -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union - -import numpy as np -from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - - -class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): - def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: - self.datapipe = datapipe - self.labels_key = labels_key - - def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: - for mapping in self.datapipe: - image_arrays = mapping["data"].reshape((-1, 3, 32, 32)) - category_idcs = mapping[self.labels_key] - yield from iter(zip(image_arrays, category_idcs)) - - -class _CifarBase(Dataset): - _FILE_NAME: str - _SHA256: str - _LABELS_KEY: str - _META_FILE_NAME: str - _CATEGORIES_KEY: str - _categories: List[str] - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]: - pass - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", - sha256=self._SHA256, - ) - ] - - def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: - _, file = data - content = cast(Dict[str, Any], pickle.load(file, encoding="latin1")) - file.close() - return content - - def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: - image_array, category_idx = data - return dict( - image=Image(image_array), - label=Label(category_idx, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Filter(dp, self._is_data_file) - dp = Mapper(dp, self._unpickle) - dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 50_000 if self._split == "train" else 10_000 - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) - dp = Mapper(dp, self._unpickle) - - return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) - - -@register_info("cifar10") -def _cifar10_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("cifar10")) - - -@register_dataset("cifar10") -class Cifar10(_CifarBase): - """ - - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html - """ - - _FILE_NAME = "cifar-10-python.tar.gz" - _SHA256 = 
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" - _LABELS_KEY = "labels" - _META_FILE_NAME = "batches.meta" - _CATEGORIES_KEY = "label_names" - _categories = _cifar10_info()["categories"] - - def _is_data_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.name.startswith("data" if self._split == "train" else "test") - - -@register_info("cifar100") -def _cifar100_info() -> Dict[str, Any]: - return dict(categories=read_categories_file("cifar100")) - - -@register_dataset("cifar100") -class Cifar100(_CifarBase): - """ - - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html - """ - - _FILE_NAME = "cifar-100-python.tar.gz" - _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" - _LABELS_KEY = "fine_labels" - _META_FILE_NAME = "meta" - _CATEGORIES_KEY = "fine_label_names" - _categories = _cifar100_info()["categories"] - - def _is_data_file(self, data: Tuple[str, Any]) -> bool: - path = pathlib.Path(data[0]) - return path.name == self._split diff --git a/torchvision/prototype/datasets/_builtin/cifar10.categories b/torchvision/prototype/datasets/_builtin/cifar10.categories deleted file mode 100644 index fa30c22b95d..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar10.categories +++ /dev/null @@ -1,10 +0,0 @@ -airplane -automobile -bird -cat -deer -dog -frog -horse -ship -truck diff --git a/torchvision/prototype/datasets/_builtin/cifar100.categories b/torchvision/prototype/datasets/_builtin/cifar100.categories deleted file mode 100644 index 7f7bf51d1ab..00000000000 --- a/torchvision/prototype/datasets/_builtin/cifar100.categories +++ /dev/null @@ -1,100 +0,0 @@ -apple -aquarium_fish -baby -bear -beaver -bed -bee -beetle -bicycle -bottle -bowl -boy -bridge -bus -butterfly -camel -can -castle -caterpillar -cattle -chair -chimpanzee -clock -cloud -cockroach -couch -crab -crocodile -cup -dinosaur -dolphin -elephant -flatfish -forest -fox -girl -hamster -house -kangaroo -keyboard -lamp -lawn_mower -leopard -lion -lizard -lobster -man -maple_tree -motorcycle -mountain -mouse -mushroom -oak_tree -orange -orchid -otter -palm_tree -pear -pickup_truck -pine_tree -plain -plate -poppy -porcupine -possum -rabbit -raccoon -ray -road -rocket -rose -sea -seal -shark -shrew -skunk -skyscraper -snail -snake -spider -squirrel -streetcar -sunflower -sweet_pepper -table -tank -telephone -television -tiger -tractor -train -trout -tulip -turtle -wardrobe -whale -willow_tree -wolf -woman -worm diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py deleted file mode 100644 index e282635684e..00000000000 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ /dev/null @@ -1,107 +0,0 @@ -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, -) - -from .._api import register_dataset, register_info - -NAME = "clevr" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict() - - -@register_dataset(NAME) -class CLEVR(Dataset): - """ - - **homepage**: 
https://cs.stanford.edu/people/jcjohns/clevr/ - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - archive = HttpResource( - "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip", - sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1", - ) - return [archive] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.parent.name == "scenes": - return 1 - else: - return None - - def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool: - key, _ = data - return key == "scenes" - - def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]: - return data, None - - def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]: - image_data, scenes_data = data - path, buffer = image_data - - return dict( - path=path, - image=EncodedImage.from_file(buffer), - label=Label(len(scenes_data["objects"])) if scenes_data else None, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, scenes_dp = Demultiplexer( - archive_dp, - 2, - self._classify_archive, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - images_dp = Filter(images_dp, path_comparator("parent.name", self._split)) - images_dp = hint_shuffling(images_dp) - images_dp = hint_sharding(images_dp) - - if self._split != "test": - scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json")) - scenes_dp = JsonParser(scenes_dp) - scenes_dp = Mapper(scenes_dp, getitem(1, "scenes")) - scenes_dp = UnBatcher(scenes_dp) - - dp = IterKeyZipper( - images_dp, - scenes_dp, - key_fn=path_accessor("name"), - ref_key_fn=getitem("image_filename"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - else: - for _, file in scenes_dp: - file.close() - dp = Mapper(images_dp, self._add_empty_anns) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 70_000 if self._split == "train" else 15_000 diff --git a/torchvision/prototype/datasets/_builtin/coco.categories b/torchvision/prototype/datasets/_builtin/coco.categories deleted file mode 100644 index 27e612f6d7d..00000000000 --- a/torchvision/prototype/datasets/_builtin/coco.categories +++ /dev/null @@ -1,91 +0,0 @@ -__background__,N/A -person,person -bicycle,vehicle -car,vehicle -motorcycle,vehicle -airplane,vehicle -bus,vehicle -train,vehicle -truck,vehicle -boat,vehicle -traffic light,outdoor -fire hydrant,outdoor -N/A,N/A -stop sign,outdoor -parking meter,outdoor -bench,outdoor -bird,animal -cat,animal -dog,animal -horse,animal -sheep,animal -cow,animal -elephant,animal -bear,animal -zebra,animal -giraffe,animal -N/A,N/A -backpack,accessory -umbrella,accessory -N/A,N/A -N/A,N/A -handbag,accessory -tie,accessory -suitcase,accessory -frisbee,sports -skis,sports -snowboard,sports -sports ball,sports -kite,sports -baseball bat,sports -baseball glove,sports -skateboard,sports -surfboard,sports -tennis racket,sports -bottle,kitchen -N/A,N/A -wine glass,kitchen -cup,kitchen -fork,kitchen -knife,kitchen -spoon,kitchen -bowl,kitchen -banana,food -apple,food -sandwich,food -orange,food 
-broccoli,food -carrot,food -hot dog,food -pizza,food -donut,food -cake,food -chair,furniture -couch,furniture -potted plant,furniture -bed,furniture -N/A,N/A -dining table,furniture -N/A,N/A -N/A,N/A -toilet,furniture -N/A,N/A -tv,electronic -laptop,electronic -mouse,electronic -remote,electronic -keyboard,electronic -cell phone,electronic -microwave,appliance -oven,appliance -toaster,appliance -sink,appliance -refrigerator,appliance -N/A,N/A -book,indoor -clock,indoor -vase,indoor -scissors,indoor -teddy bear,indoor -hair drier,indoor -toothbrush,indoor diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py deleted file mode 100644 index 6616b4e3491..00000000000 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ /dev/null @@ -1,274 +0,0 @@ -import pathlib -import re -from collections import defaultdict, OrderedDict -from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union - -import torch -from torchdata.datapipes.iter import ( - Demultiplexer, - Filter, - Grouper, - IterDataPipe, - IterKeyZipper, - JsonParser, - Mapper, - UnBatcher, -) -from torchvision.datapoints import BoundingBox, Mask -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - MappingIterator, - path_accessor, - read_categories_file, -) - -from .._api import register_dataset, register_info - - -NAME = "coco" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - categories, super_categories = zip(*read_categories_file(NAME)) - return dict(categories=categories, super_categories=super_categories) - - -@register_dataset(NAME) -class Coco(Dataset): - """ - - **homepage**: https://cocodataset.org/ - - **dependencies**: - - <pycocotools `https://github.com/cocodataset/cocoapi`>_ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - year: str = "2017", - annotations: Optional[str] = "instances", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val"}) - self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) - self._annotations = ( - self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys()) - if annotations is not None - else None - ) - - info = _info() - categories, super_categories = info["categories"], info["super_categories"] - self._categories = categories - self._category_to_super_category = dict(zip(categories, super_categories)) - - super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check) - - _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" - - _IMAGES_CHECKSUMS = { - ("2014", "train"): "ede4087e640bddba550e090eae701092534b554b42b05ac33f0300b984b31775", - ("2014", "val"): "fe9be816052049c34717e077d9e34aa60814a55679f804cd043e3cbee3b9fde0", - ("2017", "train"): "69a8bb58ea5f8f99d24875f21416de2e9ded3178e903f1f7603e283b9e06d929", - ("2017", "val"): "4f7e2ccb2866ec5041993c9cf2a952bbed69647b115d0f74da7ce8f4bef82f05", - } - - _META_URL_BASE = "http://images.cocodataset.org/annotations" - - _META_CHECKSUMS = { - "2014": "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009", - "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", - } - - def _resources(self) -> List[OnlineResource]: - images = 
HttpResource( - f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip", - sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)], - ) - meta = HttpResource( - f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip", - sha256=self._META_CHECKSUMS[self._year], - ) - return [images, meta] - - def _segmentation_to_mask( - self, segmentation: Any, *, is_crowd: bool, spatial_size: Tuple[int, int] - ) -> torch.Tensor: - from pycocotools import mask - - if is_crowd: - segmentation = mask.frPyObjects(segmentation, *spatial_size) - else: - segmentation = mask.merge(mask.frPyObjects(segmentation, *spatial_size)) - - return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) - - def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: - spatial_size = (image_meta["height"], image_meta["width"]) - labels = [ann["category_id"] for ann in anns] - return dict( - segmentations=Mask( - torch.stack( - [ - self._segmentation_to_mask( - ann["segmentation"], is_crowd=ann["iscrowd"], spatial_size=spatial_size - ) - for ann in anns - ] - ) - ), - areas=torch.as_tensor([ann["area"] for ann in anns]), - crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool), - bounding_boxes=BoundingBox( - [ann["bbox"] for ann in anns], - format="xywh", - spatial_size=spatial_size, - ), - labels=Label(labels, categories=self._categories), - super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], - ann_ids=[ann["id"] for ann in anns], - ) - - def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: - return dict( - captions=[ann["caption"] for ann in anns], - ann_ids=[ann["id"] for ann in anns], - ) - - _ANN_DECODERS = OrderedDict( - [ - ("instances", _decode_instances_anns), - ("captions", _decode_captions_ann), - ] - ) - - _META_FILE_PATTERN = re.compile( - rf"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json" - ) - - def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: - match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) - return bool( - match - and match["split"] == self._split - and match["year"] == self._year - and match["annotations"] == self._annotations - ) - - def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: - key, _ = data - if key == "images": - return 0 - elif key == "annotations": - return 1 - else: - return None - - def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: - path, buffer = data - return dict( - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _prepare_sample( - self, - data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], - ) -> Dict[str, Any]: - ann_data, image_data = data - anns, image_meta = ann_data - - sample = self._prepare_image(image_data) - # this method is only called if we have annotations - annotations = cast(str, self._annotations) - sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) - return sample - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, meta_dp = resource_dps - - if self._annotations is None: - dp = hint_shuffling(images_dp) - dp = hint_sharding(dp) - dp = hint_shuffling(dp) - return Mapper(dp, self._prepare_image) - - meta_dp = Filter(meta_dp, self._filter_meta_files) - meta_dp = JsonParser(meta_dp) - meta_dp = Mapper(meta_dp, getitem(1)) - meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = 
MappingIterator(meta_dp) - images_meta_dp, anns_meta_dp = Demultiplexer( - meta_dp, - 2, - self._classify_meta, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - images_meta_dp = Mapper(images_meta_dp, getitem(1)) - images_meta_dp = UnBatcher(images_meta_dp) - - anns_meta_dp = Mapper(anns_meta_dp, getitem(1)) - anns_meta_dp = UnBatcher(anns_meta_dp) - anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE) - anns_meta_dp = hint_shuffling(anns_meta_dp) - anns_meta_dp = hint_sharding(anns_meta_dp) - - anns_dp = IterKeyZipper( - anns_meta_dp, - images_meta_dp, - key_fn=getitem(0, "image_id"), - ref_key_fn=getitem("id"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - anns_dp, - images_dp, - key_fn=getitem(1, "file_name"), - ref_key_fn=path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266), - ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081), - ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952), - ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137), - }[(self._split, self._year)][ - self._annotations # type: ignore[index] - ] - - def _generate_categories(self) -> Tuple[Tuple[str, str]]: - self._annotations = "instances" - resources = self._resources() - - dp = resources[1].load(self._root) - dp = Filter(dp, self._filter_meta_files) - dp = JsonParser(dp) - - _, meta = next(iter(dp)) - # List[Tuple[super_category, id, category]] - label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]] - - # COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the - # full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the total - # number of categories is 90 rather than 91. - _, ids, _ = zip(*label_data) - missing_ids = set(range(1, max(ids) + 1)) - set(ids) - label_data.extend([("N/A", id, "N/A") for id in missing_ids]) - - # We also add a background category to be used during segmentation. 
- label_data.append(("N/A", 0, "__background__")) - - super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1])) - - return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories))) diff --git a/torchvision/prototype/datasets/_builtin/country211.categories b/torchvision/prototype/datasets/_builtin/country211.categories deleted file mode 100644 index 6fc3e99a185..00000000000 --- a/torchvision/prototype/datasets/_builtin/country211.categories +++ /dev/null @@ -1,211 +0,0 @@ -AD -AE -AF -AG -AI -AL -AM -AO -AQ -AR -AT -AU -AW -AX -AZ -BA -BB -BD -BE -BF -BG -BH -BJ -BM -BN -BO -BQ -BR -BS -BT -BW -BY -BZ -CA -CD -CF -CH -CI -CK -CL -CM -CN -CO -CR -CU -CV -CW -CY -CZ -DE -DK -DM -DO -DZ -EC -EE -EG -ES -ET -FI -FJ -FK -FO -FR -GA -GB -GD -GE -GF -GG -GH -GI -GL -GM -GP -GR -GS -GT -GU -GY -HK -HN -HR -HT -HU -ID -IE -IL -IM -IN -IQ -IR -IS -IT -JE -JM -JO -JP -KE -KG -KH -KN -KP -KR -KW -KY -KZ -LA -LB -LC -LI -LK -LR -LT -LU -LV -LY -MA -MC -MD -ME -MF -MG -MK -ML -MM -MN -MO -MQ -MR -MT -MU -MV -MW -MX -MY -MZ -NA -NC -NG -NI -NL -NO -NP -NZ -OM -PA -PE -PF -PG -PH -PK -PL -PR -PS -PT -PW -PY -QA -RE -RO -RS -RU -RW -SA -SB -SC -SD -SE -SG -SH -SI -SJ -SK -SL -SM -SN -SO -SS -SV -SX -SY -SZ -TG -TH -TJ -TL -TM -TN -TO -TR -TT -TW -TZ -UA -UG -US -UY -UZ -VA -VE -VG -VI -VN -VU -WS -XK -YE -ZA -ZM -ZW diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py deleted file mode 100644 index 0f4b3d769dc..00000000000 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ /dev/null @@ -1,81 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Tuple, Union - -from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - -NAME = "country211" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class Country211(Dataset): - """ - - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - self._split_folder_name = "valid" if split == "val" else split - - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - "https://openaipublic.azureedge.net/clip/data/country211.tgz", - sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c", - ) - ] - - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: - path, buffer = data - category = pathlib.Path(path).parent.name - return dict( - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: - return pathlib.Path(data[0]).parent.parent.name == split - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp 
= Filter(dp, path_comparator("parent.parent.name", self._split_folder_name)) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 31_650, - "val": 10_550, - "test": 21_100, - }[self._split] - - def _generate_categories(self) -> List[str]: - resources = self._resources() - dp = resources[0].load(self._root) - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) diff --git a/torchvision/prototype/datasets/_builtin/cub200.categories b/torchvision/prototype/datasets/_builtin/cub200.categories deleted file mode 100644 index f91754c930c..00000000000 --- a/torchvision/prototype/datasets/_builtin/cub200.categories +++ /dev/null @@ -1,200 +0,0 @@ -Black_footed_Albatross -Laysan_Albatross -Sooty_Albatross -Groove_billed_Ani -Crested_Auklet -Least_Auklet -Parakeet_Auklet -Rhinoceros_Auklet -Brewer_Blackbird -Red_winged_Blackbird -Rusty_Blackbird -Yellow_headed_Blackbird -Bobolink -Indigo_Bunting -Lazuli_Bunting -Painted_Bunting -Cardinal -Spotted_Catbird -Gray_Catbird -Yellow_breasted_Chat -Eastern_Towhee -Chuck_will_Widow -Brandt_Cormorant -Red_faced_Cormorant -Pelagic_Cormorant -Bronzed_Cowbird -Shiny_Cowbird -Brown_Creeper -American_Crow -Fish_Crow -Black_billed_Cuckoo -Mangrove_Cuckoo -Yellow_billed_Cuckoo -Gray_crowned_Rosy_Finch -Purple_Finch -Northern_Flicker -Acadian_Flycatcher -Great_Crested_Flycatcher -Least_Flycatcher -Olive_sided_Flycatcher -Scissor_tailed_Flycatcher -Vermilion_Flycatcher -Yellow_bellied_Flycatcher -Frigatebird -Northern_Fulmar -Gadwall -American_Goldfinch -European_Goldfinch -Boat_tailed_Grackle -Eared_Grebe -Horned_Grebe -Pied_billed_Grebe -Western_Grebe -Blue_Grosbeak -Evening_Grosbeak -Pine_Grosbeak -Rose_breasted_Grosbeak -Pigeon_Guillemot -California_Gull -Glaucous_winged_Gull -Heermann_Gull -Herring_Gull -Ivory_Gull -Ring_billed_Gull -Slaty_backed_Gull -Western_Gull -Anna_Hummingbird -Ruby_throated_Hummingbird -Rufous_Hummingbird -Green_Violetear -Long_tailed_Jaeger -Pomarine_Jaeger -Blue_Jay -Florida_Jay -Green_Jay -Dark_eyed_Junco -Tropical_Kingbird -Gray_Kingbird -Belted_Kingfisher -Green_Kingfisher -Pied_Kingfisher -Ringed_Kingfisher -White_breasted_Kingfisher -Red_legged_Kittiwake -Horned_Lark -Pacific_Loon -Mallard -Western_Meadowlark -Hooded_Merganser -Red_breasted_Merganser -Mockingbird -Nighthawk -Clark_Nutcracker -White_breasted_Nuthatch -Baltimore_Oriole -Hooded_Oriole -Orchard_Oriole -Scott_Oriole -Ovenbird -Brown_Pelican -White_Pelican -Western_Wood_Pewee -Sayornis -American_Pipit -Whip_poor_Will -Horned_Puffin -Common_Raven -White_necked_Raven -American_Redstart -Geococcyx -Loggerhead_Shrike -Great_Grey_Shrike -Baird_Sparrow -Black_throated_Sparrow -Brewer_Sparrow -Chipping_Sparrow -Clay_colored_Sparrow -House_Sparrow -Field_Sparrow -Fox_Sparrow -Grasshopper_Sparrow -Harris_Sparrow -Henslow_Sparrow -Le_Conte_Sparrow -Lincoln_Sparrow -Nelson_Sharp_tailed_Sparrow -Savannah_Sparrow -Seaside_Sparrow -Song_Sparrow -Tree_Sparrow -Vesper_Sparrow -White_crowned_Sparrow -White_throated_Sparrow -Cape_Glossy_Starling -Bank_Swallow -Barn_Swallow -Cliff_Swallow -Tree_Swallow -Scarlet_Tanager -Summer_Tanager -Artic_Tern -Black_Tern -Caspian_Tern -Common_Tern -Elegant_Tern -Forsters_Tern -Least_Tern -Green_tailed_Towhee -Brown_Thrasher -Sage_Thrasher -Black_capped_Vireo -Blue_headed_Vireo -Philadelphia_Vireo -Red_eyed_Vireo -Warbling_Vireo -White_eyed_Vireo -Yellow_throated_Vireo -Bay_breasted_Warbler -Black_and_white_Warbler -Black_throated_Blue_Warbler 
-Blue_winged_Warbler -Canada_Warbler -Cape_May_Warbler -Cerulean_Warbler -Chestnut_sided_Warbler -Golden_winged_Warbler -Hooded_Warbler -Kentucky_Warbler -Magnolia_Warbler -Mourning_Warbler -Myrtle_Warbler -Nashville_Warbler -Orange_crowned_Warbler -Palm_Warbler -Pine_Warbler -Prairie_Warbler -Prothonotary_Warbler -Swainson_Warbler -Tennessee_Warbler -Wilson_Warbler -Worm_eating_Warbler -Yellow_Warbler -Northern_Waterthrush -Louisiana_Waterthrush -Bohemian_Waxwing -Cedar_Waxwing -American_Three_toed_Woodpecker -Pileated_Woodpecker -Red_bellied_Woodpecker -Red_cockaded_Woodpecker -Red_headed_Woodpecker -Downy_Woodpecker -Bewick_Wren -Cactus_Wren -Carolina_Wren -House_Wren -Marsh_Wren -Rock_Wren -Winter_Wren -Common_Yellowthroat diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py deleted file mode 100644 index bc41ba028c5..00000000000 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ /dev/null @@ -1,265 +0,0 @@ -import csv -import functools -import pathlib -from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union - -import torch -from torchdata.datapipes.iter import ( - CSVDictParser, - CSVParser, - Demultiplexer, - Filter, - IterDataPipe, - IterKeyZipper, - LineReader, - Mapper, -) -from torchdata.datapipes.map import IterToMapConverter -from torchvision.datapoints import BoundingBox -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, - read_mat, -) - -from .._api import register_dataset, register_info - -csv.register_dialect("cub200", delimiter=" ") - - -NAME = "cub200" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class CUB200(Dataset): - """ - - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - year: str = "2011", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - self._year = self._verify_str_arg(year, "year", ("2010", "2011")) - - self._categories = _info()["categories"] - - super().__init__( - root, - # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473 - # dependencies=("scipy",), - skip_integrity_check=skip_integrity_check, - ) - - def _resources(self) -> List[OnlineResource]: - if self._year == "2011": - archive = GDriveResource( - "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45", - file_name="CUB_200_2011.tgz", - sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081", - preprocess="decompress", - ) - segmentations = GDriveResource( - "1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP", - file_name="segmentations.tgz", - sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f", - preprocess="decompress", - ) - return [archive, segmentations] - else: # self._year == "2010" - split = GDriveResource( - "1vZuZPqha0JjmwkdaS_XtYryE3Jf5Q1AC", - file_name="lists.tgz", - sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428", - preprocess="decompress", - ) - images = GDriveResource( - "1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx", - file_name="images.tgz", - 
sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e", - preprocess="decompress", - ) - anns = GDriveResource( - "16NsbTpMs5L6hT4hUJAmpW2u7wH326WTR", - file_name="annotations.tgz", - sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1", - preprocess="decompress", - ) - return [split, images, anns] - - def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.name == "train_test_split.txt": - return 1 - elif path.name == "images.txt": - return 2 - elif path.name == "bounding_boxes.txt": - return 3 - else: - return None - - def _2011_extract_file_name(self, rel_posix_path: str) -> str: - return rel_posix_path.rsplit("/", maxsplit=1)[1] - - def _2011_filter_split(self, row: List[str]) -> bool: - _, split_id = row - return { - "0": "test", - "1": "train", - }[split_id] == self._split - - def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: - path = pathlib.Path(data[0]) - return path.with_suffix(".jpg").name - - def _2011_prepare_ann( - self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int] - ) -> Dict[str, Any]: - _, (bounding_box_data, segmentation_data) = data - segmentation_path, segmentation_buffer = segmentation_data - return dict( - bounding_box=BoundingBox( - [float(part) for part in bounding_box_data[1:]], format="xywh", spatial_size=spatial_size - ), - segmentation_path=segmentation_path, - segmentation=EncodedImage.from_file(segmentation_buffer), - ) - - def _2010_split_key(self, data: str) -> str: - return data.rsplit("/", maxsplit=1)[1] - - def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]: - path = pathlib.Path(data[0]) - return path.with_suffix(".jpg").name, data - - def _2010_prepare_ann( - self, data: Tuple[str, Tuple[str, BinaryIO]], spatial_size: Tuple[int, int] - ) -> Dict[str, Any]: - _, (path, buffer) = data - content = read_mat(buffer) - return dict( - ann_path=path, - bounding_box=BoundingBox( - [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], - format="xyxy", - spatial_size=spatial_size, - ), - segmentation=torch.as_tensor(content["seg"]), - ) - - def _prepare_sample( - self, - data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any], - *, - prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]], - ) -> Dict[str, Any]: - data, anns_data = data - _, image_data = data - path, buffer = image_data - - image = EncodedImage.from_file(buffer) - - return dict( - prepare_ann_fn(anns_data, image.spatial_size), - image=image, - label=Label( - int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1, - categories=self._categories, - ), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - prepare_ann_fn: Callable - if self._year == "2011": - archive_dp, segmentations_dp = resource_dps - images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer( - archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - image_files_dp = CSVParser(image_files_dp, dialect="cub200") - image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1) - image_files_map = IterToMapConverter(image_files_dp) - - split_dp = CSVParser(split_dp, dialect="cub200") - split_dp = Filter(split_dp, self._2011_filter_split) - split_dp = Mapper(split_dp, getitem(0)) - split_dp = Mapper(split_dp, image_files_map.__getitem__) - - 
bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200") - bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0) - - anns_dp = IterKeyZipper( - bounding_boxes_dp, - segmentations_dp, - key_fn=getitem(0), - ref_key_fn=self._2011_segmentation_key, - keep_key=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - prepare_ann_fn = self._2011_prepare_ann - else: # self._year == "2010" - split_dp, images_dp, anns_dp = resource_dps - - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True, return_path=False) - split_dp = Mapper(split_dp, self._2010_split_key) - - anns_dp = Mapper(anns_dp, self._2010_anns_key) - - prepare_ann_fn = self._2010_prepare_ann - - split_dp = hint_shuffling(split_dp) - split_dp = hint_sharding(split_dp) - - dp = IterKeyZipper( - split_dp, - images_dp, - getitem(), - path_accessor("name"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - dp, - anns_dp, - getitem(0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn)) - - def __len__(self) -> int: - return { - ("train", "2010"): 3_000, - ("test", "2010"): 3_033, - ("train", "2011"): 5_994, - ("test", "2011"): 5_794, - }[(self._split, self._year)] - - def _generate_categories(self) -> List[str]: - self._year = "2011" - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "classes.txt")) - dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200") - - return [row["category"].split(".")[1] for row in dp] diff --git a/torchvision/prototype/datasets/_builtin/dtd.categories b/torchvision/prototype/datasets/_builtin/dtd.categories deleted file mode 100644 index 7f3df8a2b00..00000000000 --- a/torchvision/prototype/datasets/_builtin/dtd.categories +++ /dev/null @@ -1,47 +0,0 @@ -banded -blotchy -braided -bubbly -bumpy -chequered -cobwebbed -cracked -crosshatched -crystalline -dotted -fibrous -flecked -freckled -frilly -gauzy -grid -grooved -honeycombed -interlaced -knitted -lacelike -lined -marbled -matted -meshed -paisley -perforated -pitted -pleated -polka-dotted -porous -potholed -scaly -smeared -spiralled -sprinkled -stained -stratified -striped -studded -swirly -veined -waffled -woven -wrinkled -zigzagged diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py deleted file mode 100644 index 6ddab2af79d..00000000000 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ /dev/null @@ -1,139 +0,0 @@ -import enum -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - - -NAME = "dtd" - - -class DTDDemux(enum.IntEnum): - SPLIT = 0 - JOINT_CATEGORIES = 1 - IMAGES = 2 - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class DTD(Dataset): - """DTD Dataset. 
- homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - fold: int = 1, - skip_validation_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - - if not (1 <= fold <= 10): - raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}") - self._fold = fold - - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_validation_check) - - def _resources(self) -> List[OnlineResource]: - archive = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", - sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", - preprocess="decompress", - ) - return [archive] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.parent.name == "labels": - if path.name == "labels_joint_anno.txt": - return DTDDemux.JOINT_CATEGORIES - - return DTDDemux.SPLIT - elif path.parents[1].name == "images": - return DTDDemux.IMAGES - else: - return None - - def _image_key_fn(self, data: Tuple[str, Any]) -> str: - path = pathlib.Path(data[0]) - # The split files contain hardcoded posix paths for the images, e.g. banded/banded_0001.jpg - return str(path.relative_to(path.parents[1]).as_posix()) - - def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: - (_, joint_categories_data), image_data = data - _, *joint_categories = joint_categories_data - path, buffer = image_data - - category = pathlib.Path(path).parent.name - - return dict( - joint_categories={category for category in joint_categories if category}, - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - - splits_dp, joint_categories_dp, images_dp = Demultiplexer( - archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt")) - splits_dp = LineReader(splits_dp, decode=True, return_path=False) - splits_dp = hint_shuffling(splits_dp) - splits_dp = hint_sharding(splits_dp) - - joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ") - - dp = IterKeyZipper( - splits_dp, - joint_categories_dp, - key_fn=getitem(), - ref_key_fn=getitem(0), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = IterKeyZipper( - dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=self._image_key_fn, - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def _filter_images(self, data: Tuple[str, Any]) -> bool: - return self._classify_archive(data) == DTDDemux.IMAGES - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, self._filter_images) - - return sorted({pathlib.Path(path).parent.name for path, _ in dp}) - - def __len__(self) -> int: - return 1_880 # All splits have the same length diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py deleted file mode 100644 index 463eed79d70..00000000000 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ /dev/null @@ -1,66 +0,0 @@ -import pathlib -from typing import Any, Dict, List, 
Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - -from .._api import register_dataset, register_info - -NAME = "eurosat" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - categories=( - "AnnualCrop", - "Forest", - "HerbaceousVegetation", - "Highway", - "Industrial", - "Pasture", - "PermanentCrop", - "Residential", - "River", - "SeaLake", - ) - ) - - -@register_dataset(NAME) -class EuroSAT(Dataset): - """EuroSAT Dataset. - homepage="https://github.com/phelber/eurosat", - """ - - def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - "https://madm.dfki.de/files/sentinel/EuroSAT.zip", - sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd", - ) - ] - - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: - path, buffer = data - category = pathlib.Path(path).parent.name - return dict( - label=Label.from_category(category, categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 27_000 diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py deleted file mode 100644 index 17f092aa328..00000000000 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ /dev/null @@ -1,64 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Union - -import torch -from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - -from .._api import register_dataset, register_info - -NAME = "fer2013" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")) - - -@register_dataset(NAME) -class FER2013(Dataset): - """FER 2013 Dataset - homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _CHECKSUMS = { - "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10", - "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3", - } - - def _resources(self) -> List[OnlineResource]: - archive = KaggleDownloadResource( - "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", - 
file_name=f"{self._split}.csv.zip", - sha256=self._CHECKSUMS[self._split], - ) - return [archive] - - def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: - label_id = data.get("emotion") - - return dict( - image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)), - label=Label(int(label_id), categories=self._categories) if label_id is not None else None, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = CSVDictParser(dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 28_709 if self._split == "train" else 3_589 diff --git a/torchvision/prototype/datasets/_builtin/food101.categories b/torchvision/prototype/datasets/_builtin/food101.categories deleted file mode 100644 index 59f252ddff4..00000000000 --- a/torchvision/prototype/datasets/_builtin/food101.categories +++ /dev/null @@ -1,101 +0,0 @@ -apple_pie -baby_back_ribs -baklava -beef_carpaccio -beef_tartare -beet_salad -beignets -bibimbap -bread_pudding -breakfast_burrito -bruschetta -caesar_salad -cannoli -caprese_salad -carrot_cake -ceviche -cheesecake -cheese_plate -chicken_curry -chicken_quesadilla -chicken_wings -chocolate_cake -chocolate_mousse -churros -clam_chowder -club_sandwich -crab_cakes -creme_brulee -croque_madame -cup_cakes -deviled_eggs -donuts -dumplings -edamame -eggs_benedict -escargots -falafel -filet_mignon -fish_and_chips -foie_gras -french_fries -french_onion_soup -french_toast -fried_calamari -fried_rice -frozen_yogurt -garlic_bread -gnocchi -greek_salad -grilled_cheese_sandwich -grilled_salmon -guacamole -gyoza -hamburger -hot_and_sour_soup -hot_dog -huevos_rancheros -hummus -ice_cream -lasagna -lobster_bisque -lobster_roll_sandwich -macaroni_and_cheese -macarons -miso_soup -mussels -nachos -omelette -onion_rings -oysters -pad_thai -paella -pancakes -panna_cotta -peking_duck -pho -pizza -pork_chop -poutine -prime_rib -pulled_pork_sandwich -ramen -ravioli -red_velvet_cake -risotto -samosa -sashimi -scallops -seaweed_salad -shrimp_and_grits -spaghetti_bolognese -spaghetti_carbonara -spring_rolls -steak -strawberry_shortcake -sushi -tacos -takoyaki -tiramisu -tuna_tartare -waffles diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py deleted file mode 100644 index f3054d8fb13..00000000000 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ /dev/null @@ -1,97 +0,0 @@ -from pathlib import Path -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - - -NAME = "food101" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class Food101(Dataset): - """Food 101 dataset - homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", - """ - - def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = 
False) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz", - sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4", - preprocess="decompress", - ) - ] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = Path(data[0]) - if path.parents[1].name == "images": - return 0 - elif path.parents[0].name == "meta": - return 1 - else: - return None - - def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]: - id, (path, buffer) = data - return dict( - label=Label.from_category(id.split("/", 1)[0], categories=self._categories), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _image_key(self, data: Tuple[str, Any]) -> str: - path = Path(data[0]) - return path.relative_to(path.parents[1]).with_suffix("").as_posix() - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, split_dp = Demultiplexer( - archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True, return_path=False) - split_dp = hint_sharding(split_dp) - split_dp = hint_shuffling(split_dp) - - dp = IterKeyZipper( - split_dp, - images_dp, - key_fn=getitem(), - ref_key_fn=self._image_key, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - return Mapper(dp, self._prepare_sample) - - def _generate_categories(self) -> List[str]: - resources = self._resources() - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "classes.txt")) - dp = LineReader(dp, decode=True, return_path=False) - return list(dp) - - def __len__(self) -> int: - return 75_750 if self._split == "train" else 25_250 diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py deleted file mode 100644 index 85116ca3860..00000000000 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ /dev/null @@ -1,112 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper -from torchvision.datapoints import BoundingBox -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_comparator, -) - -from .._api import register_dataset, register_info - -NAME = "gtsrb" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict( - categories=[f"{label:05d}" for label in range(43)], - ) - - -@register_dataset(NAME) -class GTSRB(Dataset): - """GTSRB Dataset - - homepage="https://benchmark.ini.rub.de" - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_ROOT = 
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" - _URLS = { - "train": f"{_URL_ROOT}GTSRB-Training_fixed.zip", - "test": f"{_URL_ROOT}GTSRB_Final_Test_Images.zip", - "test_ground_truth": f"{_URL_ROOT}GTSRB_Final_Test_GT.zip", - } - _CHECKSUMS = { - "train": "df4144942083645bd60b594de348aa6930126c3e0e5de09e39611630abf8455a", - "test": "48ba6fab7e877eb64eaf8de99035b0aaecfbc279bee23e35deca4ac1d0a837fa", - "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", - } - - def _resources(self) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - - if self._split == "test": - rsrcs.append( - HttpResource( - self._URLS["test_ground_truth"], - sha256=self._CHECKSUMS["test_ground_truth"], - ) - ) - - return rsrcs - - def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - if path.suffix == ".ppm": - return 0 - elif path.suffix == ".csv": - return 1 - else: - return None - - def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]: - (path, buffer), csv_info = data - label = int(csv_info["ClassId"]) - - bounding_box = BoundingBox( - [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], - format="xyxy", - spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])), - ) - - return { - "path": path, - "image": EncodedImage.from_file(buffer), - "label": Label(label, categories=self._categories), - "bounding_box": bounding_box, - } - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - if self._split == "train": - images_dp, ann_dp = Demultiplexer( - resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - else: - images_dp, ann_dp = resource_dps - images_dp = Filter(images_dp, path_comparator("suffix", ".ppm")) - - # The order of the image files in the .zip archives perfectly match the order of the entries in the - # (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper. 
- ann_dp = CSVDictParser(ann_dp, delimiter=";") - dp = Zipper(images_dp, ann_dp) - - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 26_640 if self._split == "train" else 12_630 diff --git a/torchvision/prototype/datasets/_builtin/imagenet.categories b/torchvision/prototype/datasets/_builtin/imagenet.categories deleted file mode 100644 index 7b6006ff57f..00000000000 --- a/torchvision/prototype/datasets/_builtin/imagenet.categories +++ /dev/null @@ -1,1000 +0,0 @@ -tench,n01440764 -goldfish,n01443537 -great white shark,n01484850 -tiger shark,n01491361 -hammerhead,n01494475 -electric ray,n01496331 -stingray,n01498041 -cock,n01514668 -hen,n01514859 -ostrich,n01518878 -brambling,n01530575 -goldfinch,n01531178 -house finch,n01532829 -junco,n01534433 -indigo bunting,n01537544 -robin,n01558993 -bulbul,n01560419 -jay,n01580077 -magpie,n01582220 -chickadee,n01592084 -water ouzel,n01601694 -kite,n01608432 -bald eagle,n01614925 -vulture,n01616318 -great grey owl,n01622779 -European fire salamander,n01629819 -common newt,n01630670 -eft,n01631663 -spotted salamander,n01632458 -axolotl,n01632777 -bullfrog,n01641577 -tree frog,n01644373 -tailed frog,n01644900 -loggerhead,n01664065 -leatherback turtle,n01665541 -mud turtle,n01667114 -terrapin,n01667778 -box turtle,n01669191 -banded gecko,n01675722 -common iguana,n01677366 -American chameleon,n01682714 -whiptail,n01685808 -agama,n01687978 -frilled lizard,n01688243 -alligator lizard,n01689811 -Gila monster,n01692333 -green lizard,n01693334 -African chameleon,n01694178 -Komodo dragon,n01695060 -African crocodile,n01697457 -American alligator,n01698640 -triceratops,n01704323 -thunder snake,n01728572 -ringneck snake,n01728920 -hognose snake,n01729322 -green snake,n01729977 -king snake,n01734418 -garter snake,n01735189 -water snake,n01737021 -vine snake,n01739381 -night snake,n01740131 -boa constrictor,n01742172 -rock python,n01744401 -Indian cobra,n01748264 -green mamba,n01749939 -sea snake,n01751748 -horned viper,n01753488 -diamondback,n01755581 -sidewinder,n01756291 -trilobite,n01768244 -harvestman,n01770081 -scorpion,n01770393 -black and gold garden spider,n01773157 -barn spider,n01773549 -garden spider,n01773797 -black widow,n01774384 -tarantula,n01774750 -wolf spider,n01775062 -tick,n01776313 -centipede,n01784675 -black grouse,n01795545 -ptarmigan,n01796340 -ruffed grouse,n01797886 -prairie chicken,n01798484 -peacock,n01806143 -quail,n01806567 -partridge,n01807496 -African grey,n01817953 -macaw,n01818515 -sulphur-crested cockatoo,n01819313 -lorikeet,n01820546 -coucal,n01824575 -bee eater,n01828970 -hornbill,n01829413 -hummingbird,n01833805 -jacamar,n01843065 -toucan,n01843383 -drake,n01847000 -red-breasted merganser,n01855032 -goose,n01855672 -black swan,n01860187 -tusker,n01871265 -echidna,n01872401 -platypus,n01873310 -wallaby,n01877812 -koala,n01882714 -wombat,n01883070 -jellyfish,n01910747 -sea anemone,n01914609 -brain coral,n01917289 -flatworm,n01924916 -nematode,n01930112 -conch,n01943899 -snail,n01944390 -slug,n01945685 -sea slug,n01950731 -chiton,n01955084 -chambered nautilus,n01968897 -Dungeness crab,n01978287 -rock crab,n01978455 -fiddler crab,n01980166 -king crab,n01981276 -American lobster,n01983481 -spiny lobster,n01984695 -crayfish,n01985128 -hermit crab,n01986214 -isopod,n01990800 -white stork,n02002556 -black stork,n02002724 -spoonbill,n02006656 -flamingo,n02007558 -little blue heron,n02009229 -American egret,n02009912 -bittern,n02011460 -crane,n02012849 
-limpkin,n02013706 -European gallinule,n02017213 -American coot,n02018207 -bustard,n02018795 -ruddy turnstone,n02025239 -red-backed sandpiper,n02027492 -redshank,n02028035 -dowitcher,n02033041 -oystercatcher,n02037110 -pelican,n02051845 -king penguin,n02056570 -albatross,n02058221 -grey whale,n02066245 -killer whale,n02071294 -dugong,n02074367 -sea lion,n02077923 -Chihuahua,n02085620 -Japanese spaniel,n02085782 -Maltese dog,n02085936 -Pekinese,n02086079 -Shih-Tzu,n02086240 -Blenheim spaniel,n02086646 -papillon,n02086910 -toy terrier,n02087046 -Rhodesian ridgeback,n02087394 -Afghan hound,n02088094 -basset,n02088238 -beagle,n02088364 -bloodhound,n02088466 -bluetick,n02088632 -black-and-tan coonhound,n02089078 -Walker hound,n02089867 -English foxhound,n02089973 -redbone,n02090379 -borzoi,n02090622 -Irish wolfhound,n02090721 -Italian greyhound,n02091032 -whippet,n02091134 -Ibizan hound,n02091244 -Norwegian elkhound,n02091467 -otterhound,n02091635 -Saluki,n02091831 -Scottish deerhound,n02092002 -Weimaraner,n02092339 -Staffordshire bullterrier,n02093256 -American Staffordshire terrier,n02093428 -Bedlington terrier,n02093647 -Border terrier,n02093754 -Kerry blue terrier,n02093859 -Irish terrier,n02093991 -Norfolk terrier,n02094114 -Norwich terrier,n02094258 -Yorkshire terrier,n02094433 -wire-haired fox terrier,n02095314 -Lakeland terrier,n02095570 -Sealyham terrier,n02095889 -Airedale,n02096051 -cairn,n02096177 -Australian terrier,n02096294 -Dandie Dinmont,n02096437 -Boston bull,n02096585 -miniature schnauzer,n02097047 -giant schnauzer,n02097130 -standard schnauzer,n02097209 -Scotch terrier,n02097298 -Tibetan terrier,n02097474 -silky terrier,n02097658 -soft-coated wheaten terrier,n02098105 -West Highland white terrier,n02098286 -Lhasa,n02098413 -flat-coated retriever,n02099267 -curly-coated retriever,n02099429 -golden retriever,n02099601 -Labrador retriever,n02099712 -Chesapeake Bay retriever,n02099849 -German short-haired pointer,n02100236 -vizsla,n02100583 -English setter,n02100735 -Irish setter,n02100877 -Gordon setter,n02101006 -Brittany spaniel,n02101388 -clumber,n02101556 -English springer,n02102040 -Welsh springer spaniel,n02102177 -cocker spaniel,n02102318 -Sussex spaniel,n02102480 -Irish water spaniel,n02102973 -kuvasz,n02104029 -schipperke,n02104365 -groenendael,n02105056 -malinois,n02105162 -briard,n02105251 -kelpie,n02105412 -komondor,n02105505 -Old English sheepdog,n02105641 -Shetland sheepdog,n02105855 -collie,n02106030 -Border collie,n02106166 -Bouvier des Flandres,n02106382 -Rottweiler,n02106550 -German shepherd,n02106662 -Doberman,n02107142 -miniature pinscher,n02107312 -Greater Swiss Mountain dog,n02107574 -Bernese mountain dog,n02107683 -Appenzeller,n02107908 -EntleBucher,n02108000 -boxer,n02108089 -bull mastiff,n02108422 -Tibetan mastiff,n02108551 -French bulldog,n02108915 -Great Dane,n02109047 -Saint Bernard,n02109525 -Eskimo dog,n02109961 -malamute,n02110063 -Siberian husky,n02110185 -dalmatian,n02110341 -affenpinscher,n02110627 -basenji,n02110806 -pug,n02110958 -Leonberg,n02111129 -Newfoundland,n02111277 -Great Pyrenees,n02111500 -Samoyed,n02111889 -Pomeranian,n02112018 -chow,n02112137 -keeshond,n02112350 -Brabancon griffon,n02112706 -Pembroke,n02113023 -Cardigan,n02113186 -toy poodle,n02113624 -miniature poodle,n02113712 -standard poodle,n02113799 -Mexican hairless,n02113978 -timber wolf,n02114367 -white wolf,n02114548 -red wolf,n02114712 -coyote,n02114855 -dingo,n02115641 -dhole,n02115913 -African hunting dog,n02116738 -hyena,n02117135 -red fox,n02119022 -kit 
fox,n02119789 -Arctic fox,n02120079 -grey fox,n02120505 -tabby,n02123045 -tiger cat,n02123159 -Persian cat,n02123394 -Siamese cat,n02123597 -Egyptian cat,n02124075 -cougar,n02125311 -lynx,n02127052 -leopard,n02128385 -snow leopard,n02128757 -jaguar,n02128925 -lion,n02129165 -tiger,n02129604 -cheetah,n02130308 -brown bear,n02132136 -American black bear,n02133161 -ice bear,n02134084 -sloth bear,n02134418 -mongoose,n02137549 -meerkat,n02138441 -tiger beetle,n02165105 -ladybug,n02165456 -ground beetle,n02167151 -long-horned beetle,n02168699 -leaf beetle,n02169497 -dung beetle,n02172182 -rhinoceros beetle,n02174001 -weevil,n02177972 -fly,n02190166 -bee,n02206856 -ant,n02219486 -grasshopper,n02226429 -cricket,n02229544 -walking stick,n02231487 -cockroach,n02233338 -mantis,n02236044 -cicada,n02256656 -leafhopper,n02259212 -lacewing,n02264363 -dragonfly,n02268443 -damselfly,n02268853 -admiral,n02276258 -ringlet,n02277742 -monarch,n02279972 -cabbage butterfly,n02280649 -sulphur butterfly,n02281406 -lycaenid,n02281787 -starfish,n02317335 -sea urchin,n02319095 -sea cucumber,n02321529 -wood rabbit,n02325366 -hare,n02326432 -Angora,n02328150 -hamster,n02342885 -porcupine,n02346627 -fox squirrel,n02356798 -marmot,n02361337 -beaver,n02363005 -guinea pig,n02364673 -sorrel,n02389026 -zebra,n02391049 -hog,n02395406 -wild boar,n02396427 -warthog,n02397096 -hippopotamus,n02398521 -ox,n02403003 -water buffalo,n02408429 -bison,n02410509 -ram,n02412080 -bighorn,n02415577 -ibex,n02417914 -hartebeest,n02422106 -impala,n02422699 -gazelle,n02423022 -Arabian camel,n02437312 -llama,n02437616 -weasel,n02441942 -mink,n02442845 -polecat,n02443114 -black-footed ferret,n02443484 -otter,n02444819 -skunk,n02445715 -badger,n02447366 -armadillo,n02454379 -three-toed sloth,n02457408 -orangutan,n02480495 -gorilla,n02480855 -chimpanzee,n02481823 -gibbon,n02483362 -siamang,n02483708 -guenon,n02484975 -patas,n02486261 -baboon,n02486410 -macaque,n02487347 -langur,n02488291 -colobus,n02488702 -proboscis monkey,n02489166 -marmoset,n02490219 -capuchin,n02492035 -howler monkey,n02492660 -titi,n02493509 -spider monkey,n02493793 -squirrel monkey,n02494079 -Madagascar cat,n02497673 -indri,n02500267 -Indian elephant,n02504013 -African elephant,n02504458 -lesser panda,n02509815 -giant panda,n02510455 -barracouta,n02514041 -eel,n02526121 -coho,n02536864 -rock beauty,n02606052 -anemone fish,n02607072 -sturgeon,n02640242 -gar,n02641379 -lionfish,n02643566 -puffer,n02655020 -abacus,n02666196 -abaya,n02667093 -academic gown,n02669723 -accordion,n02672831 -acoustic guitar,n02676566 -aircraft carrier,n02687172 -airliner,n02690373 -airship,n02692877 -altar,n02699494 -ambulance,n02701002 -amphibian,n02704792 -analog clock,n02708093 -apiary,n02727426 -apron,n02730930 -ashcan,n02747177 -assault rifle,n02749479 -backpack,n02769748 -bakery,n02776631 -balance beam,n02777292 -balloon,n02782093 -ballpoint,n02783161 -Band Aid,n02786058 -banjo,n02787622 -bannister,n02788148 -barbell,n02790996 -barber chair,n02791124 -barbershop,n02791270 -barn,n02793495 -barometer,n02794156 -barrel,n02795169 -barrow,n02797295 -baseball,n02799071 -basketball,n02802426 -bassinet,n02804414 -bassoon,n02804610 -bathing cap,n02807133 -bath towel,n02808304 -bathtub,n02808440 -beach wagon,n02814533 -beacon,n02814860 -beaker,n02815834 -bearskin,n02817516 -beer bottle,n02823428 -beer glass,n02823750 -bell cote,n02825657 -bib,n02834397 -bicycle-built-for-two,n02835271 -bikini,n02837789 -binder,n02840245 -binoculars,n02841315 -birdhouse,n02843684 -boathouse,n02859443 -bobsled,n02860847 
-bolo tie,n02865351 -bonnet,n02869837 -bookcase,n02870880 -bookshop,n02871525 -bottlecap,n02877765 -bow,n02879718 -bow tie,n02883205 -brass,n02892201 -brassiere,n02892767 -breakwater,n02894605 -breastplate,n02895154 -broom,n02906734 -bucket,n02909870 -buckle,n02910353 -bulletproof vest,n02916936 -bullet train,n02917067 -butcher shop,n02927161 -cab,n02930766 -caldron,n02939185 -candle,n02948072 -cannon,n02950826 -canoe,n02951358 -can opener,n02951585 -cardigan,n02963159 -car mirror,n02965783 -carousel,n02966193 -carpenter's kit,n02966687 -carton,n02971356 -car wheel,n02974003 -cash machine,n02977058 -cassette,n02978881 -cassette player,n02979186 -castle,n02980441 -catamaran,n02981792 -CD player,n02988304 -cello,n02992211 -cellular telephone,n02992529 -chain,n02999410 -chainlink fence,n03000134 -chain mail,n03000247 -chain saw,n03000684 -chest,n03014705 -chiffonier,n03016953 -chime,n03017168 -china cabinet,n03018349 -Christmas stocking,n03026506 -church,n03028079 -cinema,n03032252 -cleaver,n03041632 -cliff dwelling,n03042490 -cloak,n03045698 -clog,n03047690 -cocktail shaker,n03062245 -coffee mug,n03063599 -coffeepot,n03063689 -coil,n03065424 -combination lock,n03075370 -computer keyboard,n03085013 -confectionery,n03089624 -container ship,n03095699 -convertible,n03100240 -corkscrew,n03109150 -cornet,n03110669 -cowboy boot,n03124043 -cowboy hat,n03124170 -cradle,n03125729 -construction crane,n03126707 -crash helmet,n03127747 -crate,n03127925 -crib,n03131574 -Crock Pot,n03133878 -croquet ball,n03134739 -crutch,n03141823 -cuirass,n03146219 -dam,n03160309 -desk,n03179701 -desktop computer,n03180011 -dial telephone,n03187595 -diaper,n03188531 -digital clock,n03196217 -digital watch,n03197337 -dining table,n03201208 -dishrag,n03207743 -dishwasher,n03207941 -disk brake,n03208938 -dock,n03216828 -dogsled,n03218198 -dome,n03220513 -doormat,n03223299 -drilling platform,n03240683 -drum,n03249569 -drumstick,n03250847 -dumbbell,n03255030 -Dutch oven,n03259280 -electric fan,n03271574 -electric guitar,n03272010 -electric locomotive,n03272562 -entertainment center,n03290653 -envelope,n03291819 -espresso maker,n03297495 -face powder,n03314780 -feather boa,n03325584 -file,n03337140 -fireboat,n03344393 -fire engine,n03345487 -fire screen,n03347037 -flagpole,n03355925 -flute,n03372029 -folding chair,n03376595 -football helmet,n03379051 -forklift,n03384352 -fountain,n03388043 -fountain pen,n03388183 -four-poster,n03388549 -freight car,n03393912 -French horn,n03394916 -frying pan,n03400231 -fur coat,n03404251 -garbage truck,n03417042 -gasmask,n03424325 -gas pump,n03425413 -goblet,n03443371 -go-kart,n03444034 -golf ball,n03445777 -golfcart,n03445924 -gondola,n03447447 -gong,n03447721 -gown,n03450230 -grand piano,n03452741 -greenhouse,n03457902 -grille,n03459775 -grocery store,n03461385 -guillotine,n03467068 -hair slide,n03476684 -hair spray,n03476991 -half track,n03478589 -hammer,n03481172 -hamper,n03482405 -hand blower,n03483316 -hand-held computer,n03485407 -handkerchief,n03485794 -hard disc,n03492542 -harmonica,n03494278 -harp,n03495258 -harvester,n03496892 -hatchet,n03498962 -holster,n03527444 -home theater,n03529860 -honeycomb,n03530642 -hook,n03532672 -hoopskirt,n03534580 -horizontal bar,n03535780 -horse cart,n03538406 -hourglass,n03544143 -iPod,n03584254 -iron,n03584829 -jack-o'-lantern,n03590841 -jean,n03594734 -jeep,n03594945 -jersey,n03595614 -jigsaw puzzle,n03598930 -jinrikisha,n03599486 -joystick,n03602883 -kimono,n03617480 -knee pad,n03623198 -knot,n03627232 -lab coat,n03630383 -ladle,n03633091 
-lampshade,n03637318 -laptop,n03642806 -lawn mower,n03649909 -lens cap,n03657121 -letter opener,n03658185 -library,n03661043 -lifeboat,n03662601 -lighter,n03666591 -limousine,n03670208 -liner,n03673027 -lipstick,n03676483 -Loafer,n03680355 -lotion,n03690938 -loudspeaker,n03691459 -loupe,n03692522 -lumbermill,n03697007 -magnetic compass,n03706229 -mailbag,n03709823 -mailbox,n03710193 -maillot,n03710637 -tank suit,n03710721 -manhole cover,n03717622 -maraca,n03720891 -marimba,n03721384 -mask,n03724870 -matchstick,n03729826 -maypole,n03733131 -maze,n03733281 -measuring cup,n03733805 -medicine chest,n03742115 -megalith,n03743016 -microphone,n03759954 -microwave,n03761084 -military uniform,n03763968 -milk can,n03764736 -minibus,n03769881 -miniskirt,n03770439 -minivan,n03770679 -missile,n03773504 -mitten,n03775071 -mixing bowl,n03775546 -mobile home,n03776460 -Model T,n03777568 -modem,n03777754 -monastery,n03781244 -monitor,n03782006 -moped,n03785016 -mortar,n03786901 -mortarboard,n03787032 -mosque,n03788195 -mosquito net,n03788365 -motor scooter,n03791053 -mountain bike,n03792782 -mountain tent,n03792972 -mouse,n03793489 -mousetrap,n03794056 -moving van,n03796401 -muzzle,n03803284 -nail,n03804744 -neck brace,n03814639 -necklace,n03814906 -nipple,n03825788 -notebook,n03832673 -obelisk,n03837869 -oboe,n03838899 -ocarina,n03840681 -odometer,n03841143 -oil filter,n03843555 -organ,n03854065 -oscilloscope,n03857828 -overskirt,n03866082 -oxcart,n03868242 -oxygen mask,n03868863 -packet,n03871628 -paddle,n03873416 -paddlewheel,n03874293 -padlock,n03874599 -paintbrush,n03876231 -pajama,n03877472 -palace,n03877845 -panpipe,n03884397 -paper towel,n03887697 -parachute,n03888257 -parallel bars,n03888605 -park bench,n03891251 -parking meter,n03891332 -passenger car,n03895866 -patio,n03899768 -pay-phone,n03902125 -pedestal,n03903868 -pencil box,n03908618 -pencil sharpener,n03908714 -perfume,n03916031 -Petri dish,n03920288 -photocopier,n03924679 -pick,n03929660 -pickelhaube,n03929855 -picket fence,n03930313 -pickup,n03930630 -pier,n03933933 -piggy bank,n03935335 -pill bottle,n03937543 -pillow,n03938244 -ping-pong ball,n03942813 -pinwheel,n03944341 -pirate,n03947888 -pitcher,n03950228 -plane,n03954731 -planetarium,n03956157 -plastic bag,n03958227 -plate rack,n03961711 -plow,n03967562 -plunger,n03970156 -Polaroid camera,n03976467 -pole,n03976657 -police van,n03977966 -poncho,n03980874 -pool table,n03982430 -pop bottle,n03983396 -pot,n03991062 -potter's wheel,n03992509 -power drill,n03995372 -prayer rug,n03998194 -printer,n04004767 -prison,n04005630 -projectile,n04008634 -projector,n04009552 -puck,n04019541 -punching bag,n04023962 -purse,n04026417 -quill,n04033901 -quilt,n04033995 -racer,n04037443 -racket,n04039381 -radiator,n04040759 -radio,n04041544 -radio telescope,n04044716 -rain barrel,n04049303 -recreational vehicle,n04065272 -reel,n04067472 -reflex camera,n04069434 -refrigerator,n04070727 -remote control,n04074963 -restaurant,n04081281 -revolver,n04086273 -rifle,n04090263 -rocking chair,n04099969 -rotisserie,n04111531 -rubber eraser,n04116512 -rugby ball,n04118538 -rule,n04118776 -running shoe,n04120489 -safe,n04125021 -safety pin,n04127249 -saltshaker,n04131690 -sandal,n04133789 -sarong,n04136333 -sax,n04141076 -scabbard,n04141327 -scale,n04141975 -school bus,n04146614 -schooner,n04147183 -scoreboard,n04149813 -screen,n04152593 -screw,n04153751 -screwdriver,n04154565 -seat belt,n04162706 -sewing machine,n04179913 -shield,n04192698 -shoe shop,n04200800 -shoji,n04201297 -shopping basket,n04204238 -shopping 
cart,n04204347 -shovel,n04208210 -shower cap,n04209133 -shower curtain,n04209239 -ski,n04228054 -ski mask,n04229816 -sleeping bag,n04235860 -slide rule,n04238763 -sliding door,n04239074 -slot,n04243546 -snorkel,n04251144 -snowmobile,n04252077 -snowplow,n04252225 -soap dispenser,n04254120 -soccer ball,n04254680 -sock,n04254777 -solar dish,n04258138 -sombrero,n04259630 -soup bowl,n04263257 -space bar,n04264628 -space heater,n04265275 -space shuttle,n04266014 -spatula,n04270147 -speedboat,n04273569 -spider web,n04275548 -spindle,n04277352 -sports car,n04285008 -spotlight,n04286575 -stage,n04296562 -steam locomotive,n04310018 -steel arch bridge,n04311004 -steel drum,n04311174 -stethoscope,n04317175 -stole,n04325704 -stone wall,n04326547 -stopwatch,n04328186 -stove,n04330267 -strainer,n04332243 -streetcar,n04335435 -stretcher,n04336792 -studio couch,n04344873 -stupa,n04346328 -submarine,n04347754 -suit,n04350905 -sundial,n04355338 -sunglass,n04355933 -sunglasses,n04356056 -sunscreen,n04357314 -suspension bridge,n04366367 -swab,n04367480 -sweatshirt,n04370456 -swimming trunks,n04371430 -swing,n04371774 -switch,n04372370 -syringe,n04376876 -table lamp,n04380533 -tank,n04389033 -tape player,n04392985 -teapot,n04398044 -teddy,n04399382 -television,n04404412 -tennis ball,n04409515 -thatch,n04417672 -theater curtain,n04418357 -thimble,n04423845 -thresher,n04428191 -throne,n04429376 -tile roof,n04435653 -toaster,n04442312 -tobacco shop,n04443257 -toilet seat,n04447861 -torch,n04456115 -totem pole,n04458633 -tow truck,n04461696 -toyshop,n04462240 -tractor,n04465501 -trailer truck,n04467665 -tray,n04476259 -trench coat,n04479046 -tricycle,n04482393 -trimaran,n04483307 -tripod,n04485082 -triumphal arch,n04486054 -trolleybus,n04487081 -trombone,n04487394 -tub,n04493381 -turnstile,n04501370 -typewriter keyboard,n04505470 -umbrella,n04507155 -unicycle,n04509417 -upright,n04515003 -vacuum,n04517823 -vase,n04522168 -vault,n04523525 -velvet,n04525038 -vending machine,n04525305 -vestment,n04532106 -viaduct,n04532670 -violin,n04536866 -volleyball,n04540053 -waffle iron,n04542943 -wall clock,n04548280 -wallet,n04548362 -wardrobe,n04550184 -warplane,n04552348 -washbasin,n04553703 -washer,n04554684 -water bottle,n04557648 -water jug,n04560804 -water tower,n04562935 -whiskey jug,n04579145 -whistle,n04579432 -wig,n04584207 -window screen,n04589890 -window shade,n04590129 -Windsor tie,n04591157 -wine bottle,n04591713 -wing,n04592741 -wok,n04596742 -wooden spoon,n04597913 -wool,n04599235 -worm fence,n04604644 -wreck,n04606251 -yawl,n04612504 -yurt,n04613696 -web site,n06359193 -comic book,n06596364 -crossword puzzle,n06785654 -street sign,n06794110 -traffic light,n06874185 -book jacket,n07248320 -menu,n07565083 -plate,n07579787 -guacamole,n07583066 -consomme,n07584110 -hot pot,n07590611 -trifle,n07613480 -ice cream,n07614500 -ice lolly,n07615774 -French loaf,n07684084 -bagel,n07693725 -pretzel,n07695742 -cheeseburger,n07697313 -hotdog,n07697537 -mashed potato,n07711569 -head cabbage,n07714571 -broccoli,n07714990 -cauliflower,n07715103 -zucchini,n07716358 -spaghetti squash,n07716906 -acorn squash,n07717410 -butternut squash,n07717556 -cucumber,n07718472 -artichoke,n07718747 -bell pepper,n07720875 -cardoon,n07730033 -mushroom,n07734744 -Granny Smith,n07742313 -strawberry,n07745940 -orange,n07747607 -lemon,n07749582 -fig,n07753113 -pineapple,n07753275 -banana,n07753592 -jackfruit,n07754684 -custard apple,n07760859 -pomegranate,n07768694 -hay,n07802026 -carbonara,n07831146 -chocolate sauce,n07836838 -dough,n07860988 -meat 
loaf,n07871810 -pizza,n07873807 -potpie,n07875152 -burrito,n07880968 -red wine,n07892512 -espresso,n07920052 -cup,n07930864 -eggnog,n07932039 -alp,n09193705 -bubble,n09229709 -cliff,n09246464 -coral reef,n09256479 -geyser,n09288635 -lakeside,n09332890 -promontory,n09399592 -sandbar,n09421951 -seashore,n09428293 -valley,n09468604 -volcano,n09472597 -ballplayer,n09835506 -groom,n10148035 -scuba diver,n10565667 -rapeseed,n11879895 -daisy,n11939491 -yellow lady's slipper,n12057211 -corn,n12144580 -acorn,n12267677 -hip,n12620546 -buckeye,n12768682 -coral fungus,n12985857 -agaric,n12998815 -gyromitra,n13037406 -stinkhorn,n13040303 -earthstar,n13044778 -hen-of-the-woods,n13052670 -bolete,n13054560 -ear,n13133613 -toilet tissue,n15075141 diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py deleted file mode 100644 index 5e2db41e1d0..00000000000 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ /dev/null @@ -1,223 +0,0 @@ -import enum -import pathlib -import re - -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union - -from torchdata.datapipes.iter import ( - Demultiplexer, - Enumerator, - Filter, - IterDataPipe, - IterKeyZipper, - LineReader, - Mapper, - TarArchiveLoader, -) -from torchdata.datapipes.map import IterToMapConverter -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - read_categories_file, - read_mat, -) - -from .._api import register_dataset, register_info - -NAME = "imagenet" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - categories, wnids = zip(*read_categories_file(NAME)) - return dict(categories=categories, wnids=wnids) - - -class ImageNetResource(ManualDownloadResource): - def __init__(self, **kwargs: Any) -> None: - super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) - - -class ImageNetDemux(enum.IntEnum): - META = 0 - LABEL = 1 - - -class CategoryAndWordNetIDExtractor(IterDataPipe): - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } - - def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None: - self.datapipe = datapipe - - def __iter__(self) -> Iterator[Tuple[str, str]]: - for _, stream in self.datapipe: - synsets = read_mat(stream, squeeze_me=True)["synsets"] - for _, wnid, category, _, num_children, *_ in synsets: - if num_children > 0: - # we are looking at a superclass that has no direct instance - continue - - yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid - - -@register_dataset(NAME) -class ImageNet(Dataset): - """ - - **homepage**: https://www.image-net.org/ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - - info = _info() - categories, wnids = info["categories"], info["wnids"] - self._categories = categories - self._wnids = wnids - self._wnid_to_category = dict(zip(wnids, categories)) - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _IMAGES_CHECKSUMS = { - "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", - "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0", - "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", - } - - def _resources(self) -> List[OnlineResource]: - name = "test_v10102019" if self._split == "test" else self._split - images = ImageNetResource( - file_name=f"ILSVRC2012_img_{name}.tar", - sha256=self._IMAGES_CHECKSUMS[name], - ) - resources: List[OnlineResource] = [images] - - if self._split == "val": - devkit = ImageNetResource( - file_name="ILSVRC2012_devkit_t12.tar.gz", - sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", - ) - resources.append(devkit) - - return resources - - _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG") - - def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: - path = pathlib.Path(data[0]) - wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] - label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) - return (label, wnid), data - - def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: - return None, data - - def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: - return { - "meta.mat": ImageNetDemux.META, - "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, - }.get(pathlib.Path(data[0]).name) - - _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") - - def _val_test_image_key(self, path: pathlib.Path) -> int: - return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index] - - def _prepare_val_data( - self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] - ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: - label_data, image_data = data - _, wnid = label_data - label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) - return (label, wnid), image_data - - def _prepare_sample( - self, - data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]], - ) -> Dict[str, Any]: - label_data, (path, buffer) = 
data - - return dict( - dict(zip(("label", "wnid"), label_data if label_data else (None, None))), - path=path, - image=EncodedImage.from_file(buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - if self._split in {"train", "test"}: - dp = resource_dps[0] - - # the train archive is a tar of tars - if self._split == "train": - dp = TarArchiveLoader(dp) - - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data) - else: # config.split == "val": - images_dp, devkit_dp = resource_dps - - meta_dp, label_dp = Demultiplexer( - devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - # We cannot use self._wnids here, since we use a different order than the dataset - meta_dp = CategoryAndWordNetIDExtractor(meta_dp) - wnid_dp = Mapper(meta_dp, getitem(1)) - wnid_dp = Enumerator(wnid_dp, 1) - wnid_map = IterToMapConverter(wnid_dp) - - label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, int) - label_dp = Mapper(label_dp, wnid_map.__getitem__) - label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) - label_dp = hint_shuffling(label_dp) - label_dp = hint_sharding(label_dp) - - dp = IterKeyZipper( - label_dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=path_accessor(self._val_test_image_key), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = Mapper(dp, self._prepare_val_data) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[self._split] - - def _filter_meta(self, data: Tuple[str, Any]) -> bool: - return self._classifiy_devkit(data) == ImageNetDemux.META - - def _generate_categories(self) -> List[Tuple[str, ...]]: - self._split = "val" - resources = self._resources() - - devkit_dp = resources[1].load(self._root) - meta_dp = Filter(devkit_dp, self._filter_meta) - meta_dp = CategoryAndWordNetIDExtractor(meta_dp) - - categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp)) - categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) - return categories_and_wnids diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py deleted file mode 100644 index 8f22a33ae01..00000000000 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ /dev/null @@ -1,419 +0,0 @@ -import abc -import functools -import operator -import pathlib -import string -from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import torch -from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE -from torchvision.prototype.utils._internal import fromfile - -from .._api import register_dataset, register_info - - -prod = functools.partial(functools.reduce, operator.mul) - - -class MNISTFileReader(IterDataPipe[torch.Tensor]): - _DTYPE_MAP = { - 8: torch.uint8, - 9: torch.int8, - 11: torch.int16, - 12: torch.int32, - 13: torch.float32, - 14: torch.float64, - } - - def __init__( - self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: 
Optional[int] - ) -> None: - self.datapipe = datapipe - self.start = start - self.stop = stop - - def __iter__(self) -> Iterator[torch.Tensor]: - for _, file in self.datapipe: - try: - read = functools.partial(fromfile, file, byte_order="big") - - magic = int(read(dtype=torch.int32, count=1)) - dtype = self._DTYPE_MAP[magic // 256] - ndim = magic % 256 - 1 - - num_samples = int(read(dtype=torch.int32, count=1)) - shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else [] - count = prod(shape) if shape else 1 - - start = self.start or 0 - stop = min(self.stop, num_samples) if self.stop else num_samples - - if start: - num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 - file.seek(num_bytes_per_value * count * start, 1) - - for _ in range(stop - start): - yield read(dtype=dtype, count=count).reshape(shape) - finally: - file.close() - - -class _MNISTBase(Dataset): - _URL_BASE: Union[str, Sequence[str]] - - @abc.abstractmethod - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - pass - - def _resources(self) -> List[OnlineResource]: - (images_file, images_sha256), ( - labels_file, - labels_sha256, - ) = self._files_and_checksums() - - url_bases = self._URL_BASE - if isinstance(url_bases, str): - url_bases = (url_bases,) - - images_urls = [f"{url_base}/{images_file}" for url_base in url_bases] - images = HttpResource(images_urls[0], sha256=images_sha256, mirrors=images_urls[1:]) - - labels_urls = [f"{url_base}/{labels_file}" for url_base in url_bases] - labels = HttpResource(labels_urls[0], sha256=labels_sha256, mirrors=labels_urls[1:]) - - return [images, labels] - - def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: - return None, None - - _categories: List[str] - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - image, label = data - return dict( - image=Image(image), - label=Label(label, dtype=torch.int64, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, labels_dp = resource_dps - start, stop = self.start_and_stop() - - images_dp = Decompressor(images_dp) - images_dp = MNISTFileReader(images_dp, start=start, stop=stop) - - labels_dp = Decompressor(labels_dp) - labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop) - - dp = Zipper(images_dp, labels_dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - -@register_info("mnist") -def _mnist_info() -> Dict[str, Any]: - return dict( - categories=[str(label) for label in range(10)], - ) - - -@register_dataset("mnist") -class MNIST(_MNISTBase): - """ - - **homepage**: http://yann.lecun.com/exdb/mnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE: Union[str, Sequence[str]] = ( - "http://yann.lecun.com/exdb/mnist", - "https://ossci-datasets.s3.amazonaws.com/mnist", - ) - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609", - "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c", - "t10k-images-idx3-ubyte.gz": "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6", - "t10k-labels-idx1-ubyte.gz": 
"f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", - } - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "train" if self._split == "train" else "t10k" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - return (images_file, self._CHECKSUMS[images_file]), ( - labels_file, - self._CHECKSUMS[labels_file], - ) - - _categories = _mnist_info()["categories"] - - def __len__(self) -> int: - return 60_000 if self._split == "train" else 10_000 - - -@register_info("fashionmnist") -def _fashionmnist_info() -> Dict[str, Any]: - return dict( - categories=[ - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ], - ) - - -@register_dataset("fashionmnist") -class FashionMNIST(MNIST): - """ - - **homepage**: https://github.com/zalandoresearch/fashion-mnist - """ - - _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com" - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84", - "train-labels-idx1-ubyte.gz": "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845", - "t10k-images-idx3-ubyte.gz": "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073", - "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5", - } - - _categories = _fashionmnist_info()["categories"] - - -@register_info("kmnist") -def _kmnist_info() -> Dict[str, Any]: - return dict( - categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], - ) - - -@register_dataset("kmnist") -class KMNIST(MNIST): - """ - - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en - """ - - _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist" - _CHECKSUMS = { - "train-images-idx3-ubyte.gz": "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4", - "train-labels-idx1-ubyte.gz": "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17", - "t10k-images-idx3-ubyte.gz": "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5", - "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", - } - - _categories = _kmnist_info()["categories"] - - -@register_info("emnist") -def _emnist_info() -> Dict[str, Any]: - return dict( - categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), - ) - - -@register_dataset("emnist") -class EMNIST(_MNISTBase): - """ - - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - image_set: str = "Balanced", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test")) - self._image_set = self._verify_str_arg( - image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST") - ) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}" - images_file = f"{prefix}-images-idx3-ubyte.gz" - labels_file = f"{prefix}-labels-idx1-ubyte.gz" - # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them - 
return (images_file, ""), (labels_file, "") - - def _resources(self) -> List[OnlineResource]: - return [ - HttpResource( - f"{self._URL_BASE}/emnist-gzip.zip", - sha256="909a2a39c5e86bdd7662425e9b9c4a49bb582bf8d0edad427f3c3a9d0c6f7259", - ) - ] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - (images_file, _), (labels_file, _) = self._files_and_checksums() - if path.name == images_file: - return 0 - elif path.name == labels_file: - return 1 - else: - return None - - _categories = _emnist_info()["categories"] - - _LABEL_OFFSETS = { - 38: 1, - 39: 1, - 40: 1, - 41: 1, - 42: 1, - 43: 6, - 44: 8, - 45: 8, - 46: 9, - } - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). - # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, - # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For - # example, since there is no 'c', 'd' corresponds to - # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing), - # and at the same time corresponds to - # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) - # in self._categories. Thus, we need to add 1 to the label to correct this. - if self._image_set in ("Balanced", "By_Merge"): - image, label = data - label += self._LABEL_OFFSETS.get(int(label), 0) - data = (image, label) - return super()._prepare_sample(data) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - images_dp, labels_dp = Demultiplexer( - archive_dp, - 2, - self._classify_archive, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - return super()._datapipe([images_dp, labels_dp]) - - def __len__(self) -> int: - return { - ("train", "Balanced"): 112_800, - ("train", "By_Merge"): 697_932, - ("train", "By_Class"): 697_932, - ("train", "Letters"): 124_800, - ("train", "Digits"): 240_000, - ("train", "MNIST"): 60_000, - ("test", "Balanced"): 18_800, - ("test", "By_Merge"): 116_323, - ("test", "By_Class"): 116_323, - ("test", "Letters"): 20_800, - ("test", "Digits"): 40_000, - ("test", "MNIST"): 10_000, - }[(self._split, self._image_set)] - - -@register_info("qmnist") -def _qmnist_info() -> Dict[str, Any]: - return dict( - categories=[str(label) for label in range(10)], - ) - - -@register_dataset("qmnist") -class QMNIST(_MNISTBase): - """ - - **homepage**: https://github.com/facebookresearch/qmnist - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist")) - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master" - _CHECKSUMS = { - "qmnist-train-images-idx3-ubyte.gz": "9e26a7bf1683614e065d7b76460ccd52807165b3f22561fb782bd9f38c52b51d", - "qmnist-train-labels-idx2-int.gz": "2c05dc77f6b916b38e455e97ab129a42a444f3dbef09b278a366f82904e0dd9f", - "qmnist-test-images-idx3-ubyte.gz": "43fc22bf7498b8fc98de98369d72f752d0deabc280a43a7bcc364ab19e57b375", - "qmnist-test-labels-idx2-int.gz": "9fbcbe594c3766fdf4f0b15c5165dc0d1e57ac604e01422608bb72c906030d06", - 
"xnist-images-idx3-ubyte.xz": "f075553993026d4359ded42208eff77a1941d3963c1eff49d6015814f15f0984", - "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f", - } - - def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}" - suffix = "xz" if self._split == "nist" else "gz" - images_file = f"{prefix}-images-idx3-ubyte.{suffix}" - labels_file = f"{prefix}-labels-idx2-int.{suffix}" - return (images_file, self._CHECKSUMS[images_file]), ( - labels_file, - self._CHECKSUMS[labels_file], - ) - - def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: - start: Optional[int] - stop: Optional[int] - if self._split == "test10k": - start = 0 - stop = 10000 - elif self._split == "test50k": - start = 10000 - stop = None - else: - start = stop = None - - return start, stop - - _categories = _emnist_info()["categories"] - - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: - image, ann = data - label, *extra_anns = ann - sample = super()._prepare_sample((image, label)) - - sample.update( - dict( - zip( - ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"), - [int(value) for value in extra_anns[:5]], - ) - ) - ) - sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) - return sample - - def __len__(self) -> int: - return { - "train": 60_000, - "test": 60_000, - "test10k": 10_000, - "test50k": 50_000, - "nist": 402_953, - }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories b/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories deleted file mode 100644 index 36d29465b04..00000000000 --- a/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories +++ /dev/null @@ -1,37 +0,0 @@ -Abyssinian -American Bulldog -American Pit Bull Terrier -Basset Hound -Beagle -Bengal -Birman -Bombay -Boxer -British Shorthair -Chihuahua -Egyptian Mau -English Cocker Spaniel -English Setter -German Shorthaired -Great Pyrenees -Havanese -Japanese Chin -Keeshond -Leonberger -Maine Coon -Miniature Pinscher -Newfoundland -Persian -Pomeranian -Pug -Ragdoll -Russian Blue -Saint Bernard -Samoyed -Scottish Terrier -Shiba Inu -Siamese -Sphynx -Staffordshire Bull Terrier -Wheaten Terrier -Yorkshire Terrier diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py deleted file mode 100644 index fbc7d30c292..00000000000 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ /dev/null @@ -1,146 +0,0 @@ -import enum -import pathlib -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - - -NAME = "oxford-iiit-pet" - - -class OxfordIIITPetDemux(enum.IntEnum): - SPLIT_AND_CLASSIFICATION = 0 - SEGMENTATIONS = 1 - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return 
dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class OxfordIIITPet(Dataset): - """Oxford IIIT Pet Dataset - homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", - """ - - def __init__( - self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"trainval", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - images = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", - sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d", - preprocess="decompress", - ) - anns = HttpResource( - "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", - sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91", - preprocess="decompress", - ) - return [images, anns] - - def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: - return { - "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION, - "trimaps": OxfordIIITPetDemux.SEGMENTATIONS, - }.get(pathlib.Path(data[0]).parent.name) - - def _filter_images(self, data: Tuple[str, Any]) -> bool: - return pathlib.Path(data[0]).suffix == ".jpg" - - def _filter_segmentations(self, data: Tuple[str, Any]) -> bool: - return not pathlib.Path(data[0]).name.startswith(".") - - def _prepare_sample( - self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]] - ) -> Dict[str, Any]: - ann_data, image_data = data - classification_data, segmentation_data = ann_data - segmentation_path, segmentation_buffer = segmentation_data - image_path, image_buffer = image_data - - return dict( - label=Label(int(classification_data["label"]) - 1, categories=self._categories), - species="cat" if classification_data["species"] == "1" else "dog", - segmentation_path=segmentation_path, - segmentation=EncodedImage.from_file(segmentation_buffer), - image_path=image_path, - image=EncodedImage.from_file(image_buffer), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - images_dp, anns_dp = resource_dps - - images_dp = Filter(images_dp, self._filter_images) - - split_and_classification_dp, segmentations_dp = Demultiplexer( - anns_dp, - 2, - self._classify_anns, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt")) - split_and_classification_dp = CSVDictParser( - split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" " - ) - split_and_classification_dp = hint_shuffling(split_and_classification_dp) - split_and_classification_dp = hint_sharding(split_and_classification_dp) - - segmentations_dp = Filter(segmentations_dp, self._filter_segmentations) - - anns_dp = IterKeyZipper( - split_and_classification_dp, - segmentations_dp, - key_fn=getitem("image_id"), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - - dp = IterKeyZipper( - anns_dp, - images_dp, - key_fn=getitem(0, "image_id"), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION - - def _generate_categories(self) -> List[str]: - 
resources = self._resources() - - dp = resources[1].load(self._root) - dp = Filter(dp, self._filter_split_and_classification_anns) - dp = Filter(dp, path_comparator("name", "trainval.txt")) - dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") - - raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp} - raw_categories, _ = zip( - *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1])) - ) - return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories] - - def __len__(self) -> int: - return 3_680 if self._split == "trainval" else 3_669 diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py deleted file mode 100644 index 4de5ae2765b..00000000000 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ /dev/null @@ -1,129 +0,0 @@ -import io -import pathlib -from collections import namedtuple -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - -from .._api import register_dataset, register_info - - -NAME = "pcam" - - -class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): - def __init__( - self, - datapipe: IterDataPipe[Tuple[str, io.IOBase]], - key: Optional[str] = None, # Note: this key thing might be very specific to the PCAM dataset - ) -> None: - self.datapipe = datapipe - self.key = key - - def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: - import h5py - - for _, handle in self.datapipe: - try: - with h5py.File(handle) as data: - if self.key is not None: - data = data[self.key] - yield from data - finally: - handle.close() - - -_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256")) - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=["0", "1"]) - - -@register_dataset(NAME) -class PCAM(Dataset): - # TODO write proper docstring - """PCAM Dataset - - homepage="https://github.com/basveeling/pcam" - """ - - def __init__( - self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",)) - - _RESOURCES = { - "train": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_train_x.h5.gz", - gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", - sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_train_y.h5.gz", - gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG", - sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a", - ), - ), - "test": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_test_x.h5.gz", - gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_", - sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_test_y.h5.gz", - gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP", - 
sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef", - ), - ), - "val": ( - _Resource( # Images - file_name="camelyonpatch_level_2_split_valid_x.h5.gz", - gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3", - sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2", - ), - _Resource( # Targets - file_name="camelyonpatch_level_2_split_valid_y.h5.gz", - gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO", - sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c", - ), - ), - } - - def _resources(self) -> List[OnlineResource]: - return [ # = [images resource, targets resource] - GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress") - for file_name, gdrive_id, sha256 in self._RESOURCES[self._split] - ] - - def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: - image, target = data # They're both numpy arrays at this point - - return { - "image": Image(image.transpose(2, 0, 1)), - "label": Label(target.item(), categories=self._categories), - } - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - - images_dp, targets_dp = resource_dps - - images_dp = PCAMH5Reader(images_dp, key="x") - targets_dp = PCAMH5Reader(targets_dp, key="y") - - dp = Zipper(images_dp, targets_dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 262_144 if self._split == "train" else 32_768 diff --git a/torchvision/prototype/datasets/_builtin/sbd.categories b/torchvision/prototype/datasets/_builtin/sbd.categories deleted file mode 100644 index 8420ab35ede..00000000000 --- a/torchvision/prototype/datasets/_builtin/sbd.categories +++ /dev/null @@ -1,20 +0,0 @@ -aeroplane -bicycle -bird -boat -bottle -bus -car -cat -chair -cow -diningtable -dog -horse -motorbike -person -pottedplant -sheep -sofa -train -tvmonitor diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py deleted file mode 100644 index 97986b58b5d..00000000000 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ /dev/null @@ -1,165 +0,0 @@ -import pathlib -import re -from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, - read_mat, -) - -from .._api import register_dataset, register_info - -NAME = "sbd" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class SBD(Dataset): - """ - - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html - - **dependencies**: - - <scipy `https://scipy.org`>_ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) - - self._categories = _info()["categories"] - - super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - resources = [ - HttpResource( - 
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz", - sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53", - ) - ] - if self._split == "train_noval": - resources.append( - HttpResource( - "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt", - sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432", - ) - ) - return resources # type: ignore[return-value] - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - path = pathlib.Path(data[0]) - parent, grandparent, *_ = path.parents - - if grandparent.name == "dataset": - if parent.name == "img": - return 0 - elif parent.name == "cls": - return 1 - - if parent.name == "dataset" and self._split != "train_noval": - return 2 - - return None - - def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: - split_and_image_data, ann_data = data - _, image_data = split_and_image_data - image_path, image_buffer = image_data - ann_path, ann_buffer = ann_data - - anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"] - - return dict( - image_path=image_path, - image=EncodedImage.from_file(image_buffer), - ann_path=ann_path, - # the boundaries are stored in sparse CSC format, which is not supported by PyTorch - boundaries=torch.as_tensor( - np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()]) - ), - segmentation=torch.as_tensor(anns["Segmentation"].item()), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - if self._split == "train_noval": - archive_dp, split_dp = resource_dps - images_dp, anns_dp = Demultiplexer( - archive_dp, - 2, - self._classify_archive, - buffer_size=INFINITE_BUFFER_SIZE, - drop_none=True, - ) - else: - archive_dp = resource_dps[0] - images_dp, anns_dp, split_dp = Demultiplexer( - archive_dp, - 3, - self._classify_archive, - buffer_size=INFINITE_BUFFER_SIZE, - drop_none=True, - ) - - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - split_dp = LineReader(split_dp, decode=True) - split_dp = hint_shuffling(split_dp) - split_dp = hint_sharding(split_dp) - - dp = split_dp - for level, data_dp in enumerate((images_dp, anns_dp)): - dp = IterKeyZipper( - dp, - data_dp, - key_fn=getitem(*[0] * level, 1), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 8_498, - "val": 2_857, - "train_noval": 5_623, - }[self._split] - - def _generate_categories(self) -> Tuple[str, ...]: - resources = self._resources() - - dp = resources[0].load(self._root) - dp = Filter(dp, path_comparator("name", "category_names.m")) - dp = LineReader(dp) - dp = Mapper(dp, bytes.decode, input_col=1) - lines = tuple(zip(*iter(dp)))[1] - - pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)") - categories_and_labels = cast( - List[Tuple[str, ...]], - [ - pattern.match(line).groups() # type: ignore[union-attr] - # the first and last line contain no information - for line in lines[1:-1] - ], - ) - categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1])) - categories, _ = zip(*categories_and_labels) - - return categories diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py deleted file mode 100644 index 92e1b93b410..00000000000 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ 
/dev/null @@ -1,55 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Tuple, Union - -import torch -from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import OneHotLabel -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - -from .._api import register_dataset, register_info - -NAME = "semeion" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=[str(i) for i in range(10)]) - - -@register_dataset(NAME) -class SEMEION(Dataset): - """Semeion dataset - homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", - """ - - def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: - - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - def _resources(self) -> List[OnlineResource]: - data = HttpResource( - "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data", - sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1", - ) - return [data] - - def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: - image_data, label_data = data[:256], data[256:-1] - - return dict( - image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)), - label=OneHotLabel([int(label) for label in label_data], categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = CSVParser(dp, delimiter=" ") - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 1_593 diff --git a/torchvision/prototype/datasets/_builtin/stanford-cars.categories b/torchvision/prototype/datasets/_builtin/stanford-cars.categories deleted file mode 100644 index e54040f7779..00000000000 --- a/torchvision/prototype/datasets/_builtin/stanford-cars.categories +++ /dev/null @@ -1,196 +0,0 @@ -AM General Hummer SUV 2000 -Acura RL Sedan 2012 -Acura TL Sedan 2012 -Acura TL Type-S 2008 -Acura TSX Sedan 2012 -Acura Integra Type R 2001 -Acura ZDX Hatchback 2012 -Aston Martin V8 Vantage Convertible 2012 -Aston Martin V8 Vantage Coupe 2012 -Aston Martin Virage Convertible 2012 -Aston Martin Virage Coupe 2012 -Audi RS 4 Convertible 2008 -Audi A5 Coupe 2012 -Audi TTS Coupe 2012 -Audi R8 Coupe 2012 -Audi V8 Sedan 1994 -Audi 100 Sedan 1994 -Audi 100 Wagon 1994 -Audi TT Hatchback 2011 -Audi S6 Sedan 2011 -Audi S5 Convertible 2012 -Audi S5 Coupe 2012 -Audi S4 Sedan 2012 -Audi S4 Sedan 2007 -Audi TT RS Coupe 2012 -BMW ActiveHybrid 5 Sedan 2012 -BMW 1 Series Convertible 2012 -BMW 1 Series Coupe 2012 -BMW 3 Series Sedan 2012 -BMW 3 Series Wagon 2012 -BMW 6 Series Convertible 2007 -BMW X5 SUV 2007 -BMW X6 SUV 2012 -BMW M3 Coupe 2012 -BMW M5 Sedan 2010 -BMW M6 Convertible 2010 -BMW X3 SUV 2012 -BMW Z4 Convertible 2012 -Bentley Continental Supersports Conv. 
Convertible 2012 -Bentley Arnage Sedan 2009 -Bentley Mulsanne Sedan 2011 -Bentley Continental GT Coupe 2012 -Bentley Continental GT Coupe 2007 -Bentley Continental Flying Spur Sedan 2007 -Bugatti Veyron 16.4 Convertible 2009 -Bugatti Veyron 16.4 Coupe 2009 -Buick Regal GS 2012 -Buick Rainier SUV 2007 -Buick Verano Sedan 2012 -Buick Enclave SUV 2012 -Cadillac CTS-V Sedan 2012 -Cadillac SRX SUV 2012 -Cadillac Escalade EXT Crew Cab 2007 -Chevrolet Silverado 1500 Hybrid Crew Cab 2012 -Chevrolet Corvette Convertible 2012 -Chevrolet Corvette ZR1 2012 -Chevrolet Corvette Ron Fellows Edition Z06 2007 -Chevrolet Traverse SUV 2012 -Chevrolet Camaro Convertible 2012 -Chevrolet HHR SS 2010 -Chevrolet Impala Sedan 2007 -Chevrolet Tahoe Hybrid SUV 2012 -Chevrolet Sonic Sedan 2012 -Chevrolet Express Cargo Van 2007 -Chevrolet Avalanche Crew Cab 2012 -Chevrolet Cobalt SS 2010 -Chevrolet Malibu Hybrid Sedan 2010 -Chevrolet TrailBlazer SS 2009 -Chevrolet Silverado 2500HD Regular Cab 2012 -Chevrolet Silverado 1500 Classic Extended Cab 2007 -Chevrolet Express Van 2007 -Chevrolet Monte Carlo Coupe 2007 -Chevrolet Malibu Sedan 2007 -Chevrolet Silverado 1500 Extended Cab 2012 -Chevrolet Silverado 1500 Regular Cab 2012 -Chrysler Aspen SUV 2009 -Chrysler Sebring Convertible 2010 -Chrysler Town and Country Minivan 2012 -Chrysler 300 SRT-8 2010 -Chrysler Crossfire Convertible 2008 -Chrysler PT Cruiser Convertible 2008 -Daewoo Nubira Wagon 2002 -Dodge Caliber Wagon 2012 -Dodge Caliber Wagon 2007 -Dodge Caravan Minivan 1997 -Dodge Ram Pickup 3500 Crew Cab 2010 -Dodge Ram Pickup 3500 Quad Cab 2009 -Dodge Sprinter Cargo Van 2009 -Dodge Journey SUV 2012 -Dodge Dakota Crew Cab 2010 -Dodge Dakota Club Cab 2007 -Dodge Magnum Wagon 2008 -Dodge Challenger SRT8 2011 -Dodge Durango SUV 2012 -Dodge Durango SUV 2007 -Dodge Charger Sedan 2012 -Dodge Charger SRT-8 2009 -Eagle Talon Hatchback 1998 -FIAT 500 Abarth 2012 -FIAT 500 Convertible 2012 -Ferrari FF Coupe 2012 -Ferrari California Convertible 2012 -Ferrari 458 Italia Convertible 2012 -Ferrari 458 Italia Coupe 2012 -Fisker Karma Sedan 2012 -Ford F-450 Super Duty Crew Cab 2012 -Ford Mustang Convertible 2007 -Ford Freestar Minivan 2007 -Ford Expedition EL SUV 2009 -Ford Edge SUV 2012 -Ford Ranger SuperCab 2011 -Ford GT Coupe 2006 -Ford F-150 Regular Cab 2012 -Ford F-150 Regular Cab 2007 -Ford Focus Sedan 2007 -Ford E-Series Wagon Van 2012 -Ford Fiesta Sedan 2012 -GMC Terrain SUV 2012 -GMC Savana Van 2012 -GMC Yukon Hybrid SUV 2012 -GMC Acadia SUV 2012 -GMC Canyon Extended Cab 2012 -Geo Metro Convertible 1993 -HUMMER H3T Crew Cab 2010 -HUMMER H2 SUT Crew Cab 2009 -Honda Odyssey Minivan 2012 -Honda Odyssey Minivan 2007 -Honda Accord Coupe 2012 -Honda Accord Sedan 2012 -Hyundai Veloster Hatchback 2012 -Hyundai Santa Fe SUV 2012 -Hyundai Tucson SUV 2012 -Hyundai Veracruz SUV 2012 -Hyundai Sonata Hybrid Sedan 2012 -Hyundai Elantra Sedan 2007 -Hyundai Accent Sedan 2012 -Hyundai Genesis Sedan 2012 -Hyundai Sonata Sedan 2012 -Hyundai Elantra Touring Hatchback 2012 -Hyundai Azera Sedan 2012 -Infiniti G Coupe IPL 2012 -Infiniti QX56 SUV 2011 -Isuzu Ascender SUV 2008 -Jaguar XK XKR 2012 -Jeep Patriot SUV 2012 -Jeep Wrangler SUV 2012 -Jeep Liberty SUV 2012 -Jeep Grand Cherokee SUV 2012 -Jeep Compass SUV 2012 -Lamborghini Reventon Coupe 2008 -Lamborghini Aventador Coupe 2012 -Lamborghini Gallardo LP 570-4 Superleggera 2012 -Lamborghini Diablo Coupe 2001 -Land Rover Range Rover SUV 2012 -Land Rover LR2 SUV 2012 -Lincoln Town Car Sedan 2011 -MINI Cooper Roadster Convertible 2012 -Maybach 
Landaulet Convertible 2012 -Mazda Tribute SUV 2011 -McLaren MP4-12C Coupe 2012 -Mercedes-Benz 300-Class Convertible 1993 -Mercedes-Benz C-Class Sedan 2012 -Mercedes-Benz SL-Class Coupe 2009 -Mercedes-Benz E-Class Sedan 2012 -Mercedes-Benz S-Class Sedan 2012 -Mercedes-Benz Sprinter Van 2012 -Mitsubishi Lancer Sedan 2012 -Nissan Leaf Hatchback 2012 -Nissan NV Passenger Van 2012 -Nissan Juke Hatchback 2012 -Nissan 240SX Coupe 1998 -Plymouth Neon Coupe 1999 -Porsche Panamera Sedan 2012 -Ram C/V Cargo Van Minivan 2012 -Rolls-Royce Phantom Drophead Coupe Convertible 2012 -Rolls-Royce Ghost Sedan 2012 -Rolls-Royce Phantom Sedan 2012 -Scion xD Hatchback 2012 -Spyker C8 Convertible 2009 -Spyker C8 Coupe 2009 -Suzuki Aerio Sedan 2007 -Suzuki Kizashi Sedan 2012 -Suzuki SX4 Hatchback 2012 -Suzuki SX4 Sedan 2012 -Tesla Model S Sedan 2012 -Toyota Sequoia SUV 2012 -Toyota Camry Sedan 2012 -Toyota Corolla Sedan 2012 -Toyota 4Runner SUV 2012 -Volkswagen Golf Hatchback 2012 -Volkswagen Golf Hatchback 1991 -Volkswagen Beetle Hatchback 2012 -Volvo C30 Hatchback 2012 -Volvo 240 Sedan 1993 -Volvo XC90 SUV 2007 -smart fortwo Convertible 2012 diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py deleted file mode 100644 index a76b2dba270..00000000000 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ /dev/null @@ -1,117 +0,0 @@ -import pathlib -from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union - -from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.datapoints import BoundingBox -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, - path_comparator, - read_categories_file, - read_mat, -) - -from .._api import register_dataset, register_info - - -class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]): - def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None: - self.datapipe = datapipe - - def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: - for _, file in self.datapipe: - data = read_mat(file, squeeze_me=True) - for ann in data["annotations"]: - yield tuple(ann) # type: ignore[misc] - - -NAME = "stanford-cars" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class StanfordCars(Dataset): - """Stanford Cars dataset. 
- homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", - dependencies=scipy - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) - - _URL_ROOT = "https://ai.stanford.edu/~jkrause/" - _URLS = { - "train": f"{_URL_ROOT}car196/cars_train.tgz", - "test": f"{_URL_ROOT}car196/cars_test.tgz", - "cars_test_annos_withlabels": f"{_URL_ROOT}car196/cars_test_annos_withlabels.mat", - "car_devkit": f"{_URL_ROOT}cars/car_devkit.tgz", - } - - _CHECKSUM = { - "train": "b97deb463af7d58b6bfaa18b2a4de9829f0f79e8ce663dfa9261bf7810e9accd", - "test": "bffea656d6f425cba3c91c6d83336e4c5f86c6cffd8975b0f375d3a10da8e243", - "cars_test_annos_withlabels": "790f75be8ea34eeded134cc559332baf23e30e91367e9ddca97d26ed9b895f05", - "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288", - } - - def _resources(self) -> List[OnlineResource]: - resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])] - if self._split == "train": - resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"])) - - else: - resources.append( - HttpResource( - self._URLS["cars_test_annos_withlabels"], sha256=self._CHECKSUM["cars_test_annos_withlabels"] - ) - ) - return resources - - def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, int, int, str]]) -> Dict[str, Any]: - image, target = data - path, buffer = image - image = EncodedImage.from_file(buffer) - - return dict( - path=path, - image=image, - label=Label(target[4] - 1, categories=self._categories), - bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - - images_dp, targets_dp = resource_dps - if self._split == "train": - targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat")) - targets_dp = StanfordCarsLabelReader(targets_dp) - dp = Zipper(images_dp, targets_dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def _generate_categories(self) -> List[str]: - resources = self._resources() - - devkit_dp = resources[1].load(self._root) - meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat")) - _, meta_file = next(iter(meta_dp)) - - return list(read_mat(meta_file, squeeze_me=True)["class_names"]) - - def __len__(self) -> int: - return 8_144 if self._split == "train" else 8_041 diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py deleted file mode 100644 index 94de4cf42c3..00000000000 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ /dev/null @@ -1,84 +0,0 @@ -import pathlib -from typing import Any, BinaryIO, Dict, List, Tuple, Union - -import numpy as np -from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat - -from .._api import register_dataset, register_info - -NAME = "svhn" 
- - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=[str(c) for c in range(10)]) - - -@register_dataset(NAME) -class SVHN(Dataset): - """SVHN Dataset. - homepage="http://ufldl.stanford.edu/housenumbers/", - dependencies = scipy - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"}) - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) - - _CHECKSUMS = { - "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8", - "test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b", - "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3", - } - - def _resources(self) -> List[OnlineResource]: - data = HttpResource( - f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat", - sha256=self._CHECKSUMS[self._split], - ) - - return [data] - - def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]: - _, buffer = data - content = read_mat(buffer) - return list( - zip( - content["X"].transpose((3, 0, 1, 2)), - content["y"].squeeze(), - ) - ) - - def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]: - image_array, label_array = data - - return dict( - image=Image(image_array.transpose((2, 0, 1))), - label=Label(int(label_array) % 10, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = resource_dps[0] - dp = Mapper(dp, self._read_images_and_labels) - dp = UnBatcher(dp) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - "train": 73_257, - "test": 26_032, - "extra": 531_131, - }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py deleted file mode 100644 index b5486669e21..00000000000 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ /dev/null @@ -1,70 +0,0 @@ -import pathlib -from typing import Any, Dict, List, Union - -import torch -from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper -from torchvision.datapoints import Image -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - -from .._api import register_dataset, register_info - -NAME = "usps" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=[str(c) for c in range(10)]) - - -@register_dataset(NAME) -class USPS(Dataset): - """USPS Dataset - homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - skip_integrity_check: bool = False, - ) -> None: - self._split = self._verify_str_arg(split, "split", {"train", "test"}) - - self._categories = _info()["categories"] - super().__init__(root, skip_integrity_check=skip_integrity_check) - - _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" - - _RESOURCES = { - "train": HttpResource( - f"{_URL}/usps.bz2", 
sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f" - ), - "test": HttpResource( - f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e" - ), - } - - def _resources(self) -> List[OnlineResource]: - return [USPS._RESOURCES[self._split]] - - def _prepare_sample(self, line: str) -> Dict[str, Any]: - label, *values = line.strip().split(" ") - values = [float(value.split(":")[1]) for value in values] - pixels = torch.tensor(values).add_(1).div_(2) - return dict( - image=Image(pixels.reshape(16, 16)), - label=Label(int(label) - 1, categories=self._categories), - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - dp = Decompressor(resource_dps[0]) - dp = LineReader(dp, decode=True, return_path=False) - dp = hint_shuffling(dp) - dp = hint_sharding(dp) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return 7_291 if self._split == "train" else 2_007 diff --git a/torchvision/prototype/datasets/_builtin/voc.categories b/torchvision/prototype/datasets/_builtin/voc.categories deleted file mode 100644 index febc0012ab3..00000000000 --- a/torchvision/prototype/datasets/_builtin/voc.categories +++ /dev/null @@ -1,21 +0,0 @@ -__background__ -aeroplane -bicycle -bird -boat -bottle -bus -car -cat -chair -cow -diningtable -dog -horse -motorbike -person -pottedplant -sheep -sofa -train -tvmonitor diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py deleted file mode 100644 index a13cfb764e4..00000000000 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ /dev/null @@ -1,222 +0,0 @@ -import enum -import functools -import pathlib -from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union -from xml.etree import ElementTree - -from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.datapoints import BoundingBox -from torchvision.datasets import VOCDetection -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import ( - getitem, - hint_sharding, - hint_shuffling, - INFINITE_BUFFER_SIZE, - path_accessor, - path_comparator, - read_categories_file, -) - -from .._api import register_dataset, register_info - -NAME = "voc" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - return dict(categories=read_categories_file(NAME)) - - -@register_dataset(NAME) -class VOC(Dataset): - """ - - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ - """ - - def __init__( - self, - root: Union[str, pathlib.Path], - *, - split: str = "train", - year: str = "2012", - task: str = "detection", - skip_integrity_check: bool = False, - ) -> None: - self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) - if split == "test" and year != "2007": - raise ValueError("`split='test'` is only available for `year='2007'`") - else: - self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) - self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) - - self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass" - self._split_folder = "Main" if task == "detection" else "Segmentation" - - self._categories = _info()["categories"] - - super().__init__(root, skip_integrity_check=skip_integrity_check) - - 
_TRAIN_VAL_ARCHIVES = { - "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), - "2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"), - "2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"), - "2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"), - "2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"), - "2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"), - } - _TEST_ARCHIVES = { - "2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892") - } - - def _resources(self) -> List[OnlineResource]: - file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year] - archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256) - return [archive] - - def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool: - path = pathlib.Path(data[0]) - return name in path.parent.parts[-depth:] - - class _Demux(enum.IntEnum): - SPLIT = 0 - IMAGES = 1 - ANNS = 2 - - def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: - if self._is_in_folder(data, name="ImageSets", depth=2): - return self._Demux.SPLIT - elif self._is_in_folder(data, name="JPEGImages"): - return self._Demux.IMAGES - elif self._is_in_folder(data, name=self._anns_folder): - return self._Demux.ANNS - else: - return None - - def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: - ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"]) - buffer.close() - return ann - - def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: - anns = self._parse_detection_ann(buffer) - instances = anns["object"] - return dict( - bounding_boxes=BoundingBox( - [ - [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")] - for instance in instances - ], - format="xyxy", - spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), - ), - labels=Label( - [self._categories.index(instance["name"]) for instance in instances], categories=self._categories - ), - ) - - def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: - return dict(segmentation=EncodedImage.from_file(buffer)) - - def _prepare_sample( - self, - data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], - ) -> Dict[str, Any]: - split_and_image_data, ann_data = data - _, image_data = split_and_image_data - image_path, image_buffer = image_data - ann_path, ann_buffer = ann_data - - return dict( - (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer), - image_path=image_path, - image=EncodedImage.from_file(image_buffer), - ann_path=ann_path, - ) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] - split_dp, images_dp, anns_dp = Demultiplexer( - archive_dp, - 3, - self._classify_archive, - drop_none=True, - buffer_size=INFINITE_BUFFER_SIZE, - ) - - split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder)) - split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) - 
split_dp = LineReader(split_dp, decode=True) - split_dp = hint_shuffling(split_dp) - split_dp = hint_sharding(split_dp) - - dp = split_dp - for level, data_dp in enumerate((images_dp, anns_dp)): - dp = IterKeyZipper( - dp, - data_dp, - key_fn=getitem(*[0] * level, 1), - ref_key_fn=path_accessor("stem"), - buffer_size=INFINITE_BUFFER_SIZE, - ) - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { - ("train", "2007", "detection"): 2_501, - ("train", "2007", "segmentation"): 209, - ("train", "2008", "detection"): 2_111, - ("train", "2008", "segmentation"): 511, - ("train", "2009", "detection"): 3_473, - ("train", "2009", "segmentation"): 749, - ("train", "2010", "detection"): 4_998, - ("train", "2010", "segmentation"): 964, - ("train", "2011", "detection"): 5_717, - ("train", "2011", "segmentation"): 1_112, - ("train", "2012", "detection"): 5_717, - ("train", "2012", "segmentation"): 1_464, - ("val", "2007", "detection"): 2_510, - ("val", "2007", "segmentation"): 213, - ("val", "2008", "detection"): 2_221, - ("val", "2008", "segmentation"): 512, - ("val", "2009", "detection"): 3_581, - ("val", "2009", "segmentation"): 750, - ("val", "2010", "detection"): 5_105, - ("val", "2010", "segmentation"): 964, - ("val", "2011", "detection"): 5_823, - ("val", "2011", "segmentation"): 1_111, - ("val", "2012", "detection"): 5_823, - ("val", "2012", "segmentation"): 1_449, - ("trainval", "2007", "detection"): 5_011, - ("trainval", "2007", "segmentation"): 422, - ("trainval", "2008", "detection"): 4_332, - ("trainval", "2008", "segmentation"): 1_023, - ("trainval", "2009", "detection"): 7_054, - ("trainval", "2009", "segmentation"): 1_499, - ("trainval", "2010", "detection"): 10_103, - ("trainval", "2010", "segmentation"): 1_928, - ("trainval", "2011", "detection"): 11_540, - ("trainval", "2011", "segmentation"): 2_223, - ("trainval", "2012", "detection"): 11_540, - ("trainval", "2012", "segmentation"): 2_913, - ("test", "2007", "detection"): 4_952, - ("test", "2007", "segmentation"): 210, - }[(self._split, self._year, self._task)] - - def _filter_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_archive(data) == self._Demux.ANNS - - def _generate_categories(self) -> List[str]: - self._task = "detection" - resources = self._resources() - - archive_dp = resources[0].load(self._root) - dp = Filter(archive_dp, self._filter_anns) - dp = Mapper(dp, self._parse_detection_ann, input_col=1) - - categories = sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) - # We add a background category to be used during segmentation - categories.insert(0, "__background__") - - return categories diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py deleted file mode 100644 index 0a37df03add..00000000000 --- a/torchvision/prototype/datasets/_folder.py +++ /dev/null @@ -1,66 +0,0 @@ -import functools -import os -import os.path -import pathlib -from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union - -from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper -from torchvision.prototype.datapoints import Label -from torchvision.prototype.datasets.utils import EncodedData, EncodedImage -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling - - -__all__ = ["from_data_folder", "from_image_folder"] - - -def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: - rel_path = pathlib.Path(path).relative_to(root) - return 
rel_path.is_dir() or rel_path.parent != pathlib.Path(".") - - -def _prepare_sample( - data: Tuple[str, BinaryIO], - *, - root: pathlib.Path, - categories: List[str], -) -> Dict[str, Any]: - path, buffer = data - category = pathlib.Path(path).relative_to(root).parts[0] - return dict( - path=path, - data=EncodedData.from_file(buffer), - label=Label.from_category(category, categories=categories), - ) - - -def from_data_folder( - root: Union[str, pathlib.Path], - *, - valid_extensions: Optional[Collection[str]] = None, - recursive: bool = True, -) -> Tuple[IterDataPipe, List[str]]: - root = pathlib.Path(root).expanduser().resolve() - categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) - masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" - dp = FileLister(str(root), recursive=recursive, masks=masks) - dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root)) - dp = hint_sharding(dp) - dp = hint_shuffling(dp) - dp = FileOpener(dp, mode="rb") - return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories - - -def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = EncodedImage(sample.pop("data").data) - return sample - - -def from_image_folder( - root: Union[str, pathlib.Path], - *, - valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), - **kwargs: Any, -) -> Tuple[IterDataPipe, List[str]]: - valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs) - return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/_home.py b/torchvision/prototype/datasets/_home.py deleted file mode 100644 index e5a89c4bdf3..00000000000 --- a/torchvision/prototype/datasets/_home.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -from typing import Optional - -import torchvision._internally_replaced_utils as _iru - - -def home(root: Optional[str] = None) -> str: - if root is not None: - _iru._HOME = root - return _iru._HOME - - root = os.getenv("TORCHVISION_DATASETS_HOME") - if root is not None: - return root - - return _iru._HOME - - -def use_sharded_dataset(use: Optional[bool] = None) -> bool: - if use is not None: - _iru._USE_SHARDED_DATASETS = use - return _iru._USE_SHARDED_DATASETS - - use = os.getenv("TORCHVISION_SHARDED_DATASETS") - if use is not None: - return use == "1" - - return _iru._USE_SHARDED_DATASETS diff --git a/torchvision/prototype/datasets/benchmark.py b/torchvision/prototype/datasets/benchmark.py deleted file mode 100644 index 104ef95c9ae..00000000000 --- a/torchvision/prototype/datasets/benchmark.py +++ /dev/null @@ -1,661 +0,0 @@ -# type: ignore - -import argparse -import collections.abc -import contextlib -import inspect -import itertools -import os -import os.path -import pathlib -import shutil -import sys -import tempfile -import time -import unittest.mock -import warnings - -import torch -from torch.utils.data import DataLoader -from torch.utils.data.dataloader_experimental import DataLoader2 -from torchvision import datasets as legacy_datasets -from torchvision.datasets.utils import extract_archive -from torchvision.prototype import datasets as new_datasets -from torchvision.transforms import PILToTensor - - -def main( - name, - *, - variant=None, - legacy=True, - new=True, - start=True, - 
iteration=True, - num_starts=3, - num_samples=10_000, - temp_root=None, - num_workers=0, -): - benchmarks = [ - benchmark - for benchmark in DATASET_BENCHMARKS - if benchmark.name == name and (variant is None or benchmark.variant == variant) - ] - if not benchmarks: - msg = f"No DatasetBenchmark available for dataset '{name}'" - if variant is not None: - msg += f" and variant '{variant}'" - raise ValueError(msg) - - for benchmark in benchmarks: - print("#" * 80) - print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else "")) - - if legacy and start: - print( - "legacy", - "cold_start", - Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts), - ) - print( - "legacy", - "warm_start", - Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts), - ) - - if legacy and iteration: - print( - "legacy", - "iteration", - Measurement.iterations_per_time( - benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples) - ), - ) - - if new and start: - print( - "new", - "cold_start", - Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts), - ) - - if new and iteration: - print( - "new", - "iteration", - Measurement.iterations_per_time( - benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples) - ), - ) - - -class DatasetBenchmark: - def __init__( - self, - name: str, - *, - variant=None, - legacy_cls=None, - new_config=None, - legacy_config_map=None, - legacy_special_options_map=None, - prepare_legacy_root=None, - ): - self.name = name - self.variant = variant - - self.new_raw_dataset = new_datasets._api.find(name) - self.legacy_cls = legacy_cls or self._find_legacy_cls() - - if new_config is None: - new_config = self.new_raw_dataset.default_config - elif isinstance(new_config, dict): - new_config = self.new_raw_dataset.info.make_config(**new_config) - self.new_config = new_config - - self.legacy_config_map = legacy_config_map - self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map - self.prepare_legacy_root = prepare_legacy_root - - def new_dataset(self, *, num_workers=0): - return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers) - - def new_cold_start(self, *, num_workers): - def fn(timer): - with timer: - dataset = self.new_dataset(num_workers=num_workers) - next(iter(dataset)) - - return fn - - def new_iteration(self, *, num_samples, num_workers): - def fn(timer): - dataset = self.new_dataset(num_workers=num_workers) - num_sample = 0 - with timer: - for _ in dataset: - num_sample += 1 - if num_sample == num_samples: - break - - return num_sample - - return fn - - def suppress_output(self): - @contextlib.contextmanager - def context_manager(): - with open(os.devnull, "w") as devnull: - with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): - yield - - return context_manager() - - def legacy_dataset(self, root, *, num_workers=0, download=None): - legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict() - - special_options = self.legacy_special_options_map(self) - if "download" in special_options and download is not None: - special_options["download"] = download - - with self.suppress_output(): - return DataLoader( - self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options), - shuffle=True, - num_workers=num_workers, - ) - - 
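For orientation, every benchmark method in this script follows the same closure-plus-timer pattern: a method returns `fn(timer)`, and only the code wrapped by the `timer` context manager is measured, so setup cost stays out of the numbers (mirroring how `Measurement.Timer` is used further down). A minimal, self-contained sketch of that pattern, assuming nothing beyond the standard library; `ToyTimer`, `toy_cold_start`, and `run_once` are illustrative names and not part of the original file:

import time

class ToyTimer:
    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.delta = time.perf_counter() - self._start

def toy_cold_start(build_dataset):
    def fn(timer):
        with timer:  # only construction plus fetching the first sample is timed
            dataset = build_dataset()
            next(iter(dataset))
    return fn

def run_once(fn):
    timer = ToyTimer()
    fn(timer)
    return timer.delta

elapsed = run_once(toy_cold_start(lambda: range(10)))  # elapsed seconds as a float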
@contextlib.contextmanager - def patch_download_and_integrity_checks(self): - patches = [ - ("download_url", dict()), - ("download_file_from_google_drive", dict()), - ("check_integrity", dict(new=lambda path, md5=None: os.path.isfile(path))), - ] - dataset_module = sys.modules[self.legacy_cls.__module__] - utils_module = legacy_datasets.utils - with contextlib.ExitStack() as stack: - for name, patch_kwargs in patches: - patch_module = dataset_module if name in dir(dataset_module) else utils_module - stack.enter_context(unittest.mock.patch(f"{patch_module.__name__}.{name}", **patch_kwargs)) - - yield stack - - def _find_resource_file_names(self): - info = self.new_raw_dataset.info - valid_options = info._valid_options - - file_names = set() - for options in ( - dict(zip(valid_options.keys(), values)) for values in itertools.product(*valid_options.values()) - ): - resources = self.new_raw_dataset.resources(info.make_config(**options)) - file_names.update([resource.file_name for resource in resources]) - - return file_names - - @contextlib.contextmanager - def legacy_root(self, temp_root): - new_root = pathlib.Path(new_datasets.home()) / self.name - legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root)) - - if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev: - warnings.warn( - "The temporary root directory for the legacy dataset was created on a different storage device than " - "the raw data that is used by the new dataset. If the devices have different I/O stats, this will " - "distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the " - "temporary directories.", - RuntimeWarning, - ) - - try: - for file_name in self._find_resource_file_names(): - (legacy_root / file_name).symlink_to(new_root / file_name) - - if self.prepare_legacy_root: - self.prepare_legacy_root(self, legacy_root) - - with self.patch_download_and_integrity_checks(): - yield legacy_root - finally: - shutil.rmtree(legacy_root) - - def legacy_cold_start(self, temp_root, *, num_workers): - def fn(timer): - with self.legacy_root(temp_root) as root: - with timer: - dataset = self.legacy_dataset(root, num_workers=num_workers) - next(iter(dataset)) - - return fn - - def legacy_warm_start(self, temp_root, *, num_workers): - def fn(timer): - with self.legacy_root(temp_root) as root: - self.legacy_dataset(root, num_workers=num_workers) - with timer: - dataset = self.legacy_dataset(root, num_workers=num_workers, download=False) - next(iter(dataset)) - - return fn - - def legacy_iteration(self, temp_root, *, num_samples, num_workers): - def fn(timer): - with self.legacy_root(temp_root) as root: - dataset = self.legacy_dataset(root, num_workers=num_workers) - with timer: - for num_sample, _ in enumerate(dataset, 1): - if num_sample == num_samples: - break - - return num_sample - - return fn - - def _find_legacy_cls(self): - legacy_clss = { - name.lower(): dataset_class - for name, dataset_class in legacy_datasets.__dict__.items() - if isinstance(dataset_class, type) and issubclass(dataset_class, legacy_datasets.VisionDataset) - } - try: - return legacy_clss[self.name] - except KeyError as error: - raise RuntimeError( - f"Can't determine the legacy dataset class for '{self.name}' automatically. " - f"Please set the 'legacy_cls' keyword argument manually." 
- ) from error - - _SPECIAL_KWARGS = { - "transform", - "target_transform", - "transforms", - "download", - } - - @staticmethod - def _legacy_special_options_map(benchmark): - available_parameters = set() - - for cls in benchmark.legacy_cls.__mro__: - if cls is legacy_datasets.VisionDataset: - break - - available_parameters.update(inspect.signature(cls.__init__).parameters) - - available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters) - - special_options = dict() - - if "download" in available_special_kwargs: - special_options["download"] = True - - if "transform" in available_special_kwargs: - special_options["transform"] = PILToTensor() - if "target_transform" in available_special_kwargs: - special_options["target_transform"] = torch.tensor - elif "transforms" in available_special_kwargs: - special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor()) - - return special_options - - -class Measurement: - @classmethod - def time(cls, fn, *, number): - results = Measurement._timeit(fn, number=number) - times = torch.tensor(tuple(zip(*results))[1]) - return cls._format(times, unit="s") - - @classmethod - def iterations_per_time(cls, fn): - num_samples, time = Measurement._timeit(fn, number=1)[0] - iterations_per_second = torch.tensor(num_samples) / torch.tensor(time) - return cls._format(iterations_per_second, unit="it/s") - - class Timer: - def __init__(self): - self._start = None - self._stop = None - - def __enter__(self): - self._start = time.perf_counter() - - def __exit__(self, exc_type, exc_val, exc_tb): - self._stop = time.perf_counter() - - @property - def delta(self): - if self._start is None: - raise RuntimeError() - elif self._stop is None: - raise RuntimeError() - return self._stop - self._start - - @classmethod - def _timeit(cls, fn, number): - results = [] - for _ in range(number): - timer = cls.Timer() - output = fn(timer) - results.append((output, timer.delta)) - return results - - @classmethod - def _format(cls, measurements, *, unit): - measurements = torch.as_tensor(measurements).to(torch.float64).flatten() - if measurements.numel() == 1: - # TODO format that into engineering format - return f"{float(measurements):.3f} {unit}" - - mean, std = Measurement._compute_mean_and_std(measurements) - # TODO format that into engineering format - return f"{mean:.3f} ± {std:.3f} {unit}" - - @classmethod - def _compute_mean_and_std(cls, t): - mean = float(t.mean()) - std = float(t.std(0, unbiased=t.numel() > 1)) - return mean, std - - -def no_split(benchmark, root): - legacy_config = dict(benchmark.new_config) - del legacy_config["split"] - return legacy_config - - -def bool_split(name="train"): - def legacy_config_map(benchmark, root): - legacy_config = dict(benchmark.new_config) - legacy_config[name] = legacy_config.pop("split") == "train" - return legacy_config - - return legacy_config_map - - -def base_folder(rel_folder=None): - if rel_folder is None: - - def rel_folder(benchmark): - return benchmark.name - - elif not callable(rel_folder): - name = rel_folder - - def rel_folder(_): - return name - - def prepare_legacy_root(benchmark, root): - files = list(root.glob("*")) - folder = root / rel_folder(benchmark) - folder.mkdir(parents=True) - for file in files: - shutil.move(str(file), str(folder)) - - return folder - - return prepare_legacy_root - - -class JointTransform: - def __init__(self, *transforms): - self.transforms = transforms - - def __call__(self, *inputs): - if len(inputs) == 1 and isinstance(inputs, 
collections.abc.Sequence): - inputs = inputs[0] - - if len(inputs) != len(self.transforms): - raise RuntimeError( - f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}." - ) - - return tuple(transform(input) for transform, input in zip(self.transforms, inputs)) - - -def caltech101_legacy_config_map(benchmark, root): - legacy_config = no_split(benchmark, root) - # The new dataset always returns the category and annotation - legacy_config["target_type"] = ("category", "annotation") - return legacy_config - - -mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw") - - -def mnist_legacy_config_map(benchmark, root): - return dict(train=benchmark.new_config.split == "train") - - -def emnist_prepare_legacy_root(benchmark, root): - folder = mnist_base_folder(benchmark, root) - shutil.move(str(folder / "emnist-gzip.zip"), str(folder / "gzip.zip")) - return folder - - -def emnist_legacy_config_map(benchmark, root): - legacy_config = mnist_legacy_config_map(benchmark, root) - legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower() - return legacy_config - - -def qmnist_legacy_config_map(benchmark, root): - legacy_config = mnist_legacy_config_map(benchmark, root) - legacy_config["what"] = benchmark.new_config.split - # The new dataset always returns the full label - legacy_config["compat"] = False - return legacy_config - - -def coco_legacy_config_map(benchmark, root): - images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config) - return dict( - root=str(root / pathlib.Path(images.file_name).stem), - annFile=str( - root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json" - ), - ) - - -def coco_prepare_legacy_root(benchmark, root): - images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config) - extract_archive(str(root / images.file_name)) - extract_archive(str(root / annotations.file_name)) - - -DATASET_BENCHMARKS = [ - DatasetBenchmark( - "caltech101", - legacy_config_map=caltech101_legacy_config_map, - prepare_legacy_root=base_folder(), - legacy_special_options_map=lambda config: dict( - download=True, - transform=PILToTensor(), - target_transform=JointTransform(torch.tensor, torch.tensor), - ), - ), - DatasetBenchmark( - "caltech256", - legacy_config_map=no_split, - prepare_legacy_root=base_folder(), - ), - DatasetBenchmark( - "celeba", - prepare_legacy_root=base_folder(), - legacy_config_map=lambda benchmark: dict( - split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split, - # The new dataset always returns all annotations - target_type=("attr", "identity", "bbox", "landmarks"), - ), - ), - DatasetBenchmark( - "cifar10", - legacy_config_map=bool_split(), - ), - DatasetBenchmark( - "cifar100", - legacy_config_map=bool_split(), - ), - DatasetBenchmark( - "emnist", - prepare_legacy_root=emnist_prepare_legacy_root, - legacy_config_map=emnist_legacy_config_map, - ), - DatasetBenchmark( - "fashionmnist", - prepare_legacy_root=mnist_base_folder, - legacy_config_map=mnist_legacy_config_map, - ), - DatasetBenchmark( - "kmnist", - prepare_legacy_root=mnist_base_folder, - legacy_config_map=mnist_legacy_config_map, - ), - DatasetBenchmark( - "mnist", - prepare_legacy_root=mnist_base_folder, - legacy_config_map=mnist_legacy_config_map, - ), - DatasetBenchmark( - "qmnist", - prepare_legacy_root=mnist_base_folder, - legacy_config_map=mnist_legacy_config_map, - ), - DatasetBenchmark( - "sbd", - 
legacy_cls=legacy_datasets.SBDataset, - legacy_config_map=lambda benchmark: dict( - image_set=benchmark.new_config.split, - mode="boundaries" if benchmark.new_config.boundaries else "segmentation", - ), - legacy_special_options_map=lambda benchmark: dict( - download=True, - transforms=JointTransform( - PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor() - ), - ), - ), - DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection), - DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet), - DatasetBenchmark( - "coco", - variant="instances", - legacy_cls=legacy_datasets.CocoDetection, - new_config=dict(split="train", annotations="instances"), - legacy_config_map=coco_legacy_config_map, - prepare_legacy_root=coco_prepare_legacy_root, - legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None), - ), - DatasetBenchmark( - "coco", - variant="captions", - legacy_cls=legacy_datasets.CocoCaptions, - new_config=dict(split="train", annotations="captions"), - legacy_config_map=coco_legacy_config_map, - prepare_legacy_root=coco_prepare_legacy_root, - legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None), - ), -] - - -def parse_args(argv=None): - parser = argparse.ArgumentParser( - prog="torchvision.prototype.datasets.benchmark.py", - description="Utility to benchmark new datasets against their legacy variants.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument("name", help="Name of the dataset to benchmark.") - parser.add_argument( - "--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked." - ) - - parser.add_argument( - "-n", - "--num-starts", - type=int, - default=3, - help="Number of warm and cold starts of each benchmark. Default to 3.", - ) - parser.add_argument( - "-N", - "--num-samples", - type=int, - default=10_000, - help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.", - ) - - parser.add_argument( - "--nl", - "--no-legacy", - dest="legacy", - action="store_false", - help="Skip legacy benchmarks.", - ) - parser.add_argument( - "--nn", - "--no-new", - dest="new", - action="store_false", - help="Skip new benchmarks.", - ) - parser.add_argument( - "--ns", - "--no-start", - dest="start", - action="store_false", - help="Skip start benchmarks.", - ) - parser.add_argument( - "--ni", - "--no-iteration", - dest="iteration", - action="store_false", - help="Skip iteration benchmarks.", - ) - - parser.add_argument( - "-t", - "--temp-root", - type=pathlib.Path, - help=( - "Root of the temporary legacy root directories. Use this if your system default temporary directory is on " - "another storage device as the raw data to avoid distortions due to differing I/O stats." - ), - ) - parser.add_argument( - "-j", - "--num-workers", - type=int, - default=0, - help=( - "Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main " - "process and thus disable multi-processing." 
- ), - ) - - return parser.parse_args(argv or sys.argv[1:]) - - -if __name__ == "__main__": - args = parse_args() - - try: - main( - args.name, - variant=args.variant, - legacy=args.legacy, - new=args.new, - start=args.start, - iteration=args.iteration, - num_starts=args.num_starts, - num_samples=args.num_samples, - temp_root=args.temp_root, - num_workers=args.num_workers, - ) - except Exception as error: - msg = str(error) - print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) - sys.exit(1) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py deleted file mode 100644 index 6d4e854fe34..00000000000 --- a/torchvision/prototype/datasets/generate_category_files.py +++ /dev/null @@ -1,61 +0,0 @@ -# type: ignore - -import argparse -import csv -import sys - -from torchvision.prototype import datasets -from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR - - -def main(*names, force=False): - for name in names: - path = BUILTIN_DIR / f"{name}.categories" - if path.exists() and not force: - continue - - dataset = datasets.load(name) - try: - categories = dataset._generate_categories() - except NotImplementedError: - continue - - with open(path, "w") as file: - writer = csv.writer(file, lineterminator="\n") - for category in categories: - writer.writerow((category,) if isinstance(category, str) else category) - - -def parse_args(argv=None): - parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py") - - parser.add_argument( - "names", - nargs="*", - type=str, - help="Names of datasets to generate category files for. If omitted, all datasets will be used.", - ) - parser.add_argument( - "-f", - "--force", - action="store_true", - help="Force regeneration of category files.", - ) - - args = parser.parse_args(argv or sys.argv[1:]) - - if not args.names: - args.names = datasets.list_datasets() - - return args - - -if __name__ == "__main__": - args = parse_args() - - try: - main(*args.names, force=args.force) - except Exception as error: - msg = str(error) - print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) - sys.exit(1) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py deleted file mode 100644 index 3fdb53eec43..00000000000 --- a/torchvision/prototype/datasets/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . 
import _internal # usort: skip -from ._dataset import Dataset -from ._encoded import EncodedData, EncodedImage -from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py deleted file mode 100644 index 0d1cc2b1560..00000000000 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ /dev/null @@ -1,57 +0,0 @@ -import abc -import importlib -import pathlib -from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union - -from torchdata.datapipes.iter import IterDataPipe -from torchvision.datasets.utils import verify_str_arg - -from ._resource import OnlineResource - - -class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): - @staticmethod - def _verify_str_arg( - value: str, - arg: Optional[str] = None, - valid_values: Optional[Collection[str]] = None, - *, - custom_msg: Optional[str] = None, - ) -> str: - return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) - - def __init__( - self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = () - ) -> None: - for dependency in dependencies: - try: - importlib.import_module(dependency) - except ModuleNotFoundError: - raise ModuleNotFoundError( - f"{type(self).__name__}() depends on the third-party package '{dependency}'. " - f"Please install it, for example with `pip install {dependency}`." - ) from None - - self._root = pathlib.Path(root).expanduser().resolve() - resources = [ - resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() - ] - self._dp = self._datapipe(resources) - - def __iter__(self) -> Iterator[Dict[str, Any]]: - yield from self._dp - - @abc.abstractmethod - def _resources(self) -> List[OnlineResource]: - pass - - @abc.abstractmethod - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - pass - - @abc.abstractmethod - def __len__(self) -> int: - pass - - def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]: - raise NotImplementedError diff --git a/torchvision/prototype/datasets/utils/_encoded.py b/torchvision/prototype/datasets/utils/_encoded.py deleted file mode 100644 index 8adc1e57acb..00000000000 --- a/torchvision/prototype/datasets/utils/_encoded.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import os -import sys -from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union - -import PIL.Image -import torch - -from torchvision.datapoints._datapoint import Datapoint -from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer - -D = TypeVar("D", bound="EncodedData") - - -class EncodedData(Datapoint): - @classmethod - def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: - return tensor.as_subclass(cls) - - def __new__( - cls, - data: Any, - *, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str, int]] = None, - requires_grad: bool = False, - ) -> EncodedData: - tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? 
- return cls._wrap(tensor) - - @classmethod - def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - return cls._wrap(tensor) - - @classmethod - def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D: - encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs) - file.close() - return encoded_data - - @classmethod - def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D: - with open(path, "rb") as file: - return cls.from_file(file, **kwargs) - - -class EncodedImage(EncodedData): - # TODO: Use @functools.cached_property if we can depend on Python 3.8 - @property - def spatial_size(self) -> Tuple[int, int]: - if not hasattr(self, "_spatial_size"): - with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image: - self._spatial_size = image.height, image.width - - return self._spatial_size diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py deleted file mode 100644 index f8a44b627e8..00000000000 --- a/torchvision/prototype/datasets/utils/_internal.py +++ /dev/null @@ -1,194 +0,0 @@ -import csv -import functools -import pathlib -import pickle -from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union - -import torch -import torch.distributed as dist -import torch.utils.data -from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler -from torchvision.prototype.utils._internal import fromfile - - -__all__ = [ - "INFINITE_BUFFER_SIZE", - "BUILTIN_DIR", - "read_mat", - "MappingIterator", - "getitem", - "path_accessor", - "path_comparator", - "read_flo", - "hint_sharding", - "hint_shuffling", -] - -K = TypeVar("K") -D = TypeVar("D") - -# pseudo-infinite until a true infinite buffer is supported by all datapipes -INFINITE_BUFFER_SIZE = 1_000_000_000 - -BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" - - -def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any: - try: - import scipy.io as sio - except ImportError as error: - raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error - - data = sio.loadmat(buffer, **kwargs) - buffer.close() - return data - - -class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]): - def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None: - self.datapipe = datapipe - self.drop_key = drop_key - - def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]: - for mapping in self.datapipe: - yield from iter(mapping.values() if self.drop_key else mapping.items()) - - -def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any: - for item in items: - obj = obj[item] - return obj - - -def getitem(*items: Any) -> Callable[[Any], Any]: - return functools.partial(_getitem_closure, items=items) - - -def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any: - for attr in attrs: - obj = getattr(obj, attr) - return obj - - -def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any: - return _getattr_closure(path, attrs=name.split(".")) - - -def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: - return getter(pathlib.Path(data[0])) - - -def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]: - if isinstance(getter, str): - getter = functools.partial(_path_attribute_accessor, name=getter) - - return functools.partial(_path_accessor_closure, 
getter=getter) - - -def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool: - return accessor(data) == value - - -def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]: - return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value) - - -class PicklerDataPipe(IterDataPipe): - def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None: - self.source_datapipe = source_datapipe - - def __iter__(self) -> Iterator[Any]: - for _, fobj in self.source_datapipe: - data = pickle.load(fobj) - for _, d in enumerate(data): - yield d - - -class SharderDataPipe(ShardingFilter): - def __init__(self, source_datapipe: IterDataPipe) -> None: - super().__init__(source_datapipe) - self.rank = 0 - self.world_size = 1 - if dist.is_available() and dist.is_initialized(): - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.apply_sharding(self.world_size, self.rank) - - def __iter__(self) -> Iterator[Any]: - num_workers = self.world_size - worker_id = self.rank - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - worker_id = worker_id + worker_info.id * num_workers - num_workers *= worker_info.num_workers - self.apply_sharding(num_workers, worker_id) - yield from super().__iter__() - - -class TakerDataPipe(IterDataPipe): - def __init__(self, source_datapipe: IterDataPipe, num_take: int) -> None: - super().__init__() - self.source_datapipe = source_datapipe - self.num_take = num_take - self.world_size = 1 - if dist.is_available() and dist.is_initialized(): - self.world_size = dist.get_world_size() - - def __iter__(self) -> Iterator[Any]: - num_workers = self.world_size - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - num_workers *= worker_info.num_workers - - # TODO: this is weird as it drops more elements than it should - num_take = self.num_take // num_workers - - for i, data in enumerate(self.source_datapipe): - if i < num_take: - yield data - else: - break - - def __len__(self) -> int: - num_take = self.num_take // self.world_size - if isinstance(self.source_datapipe, Sized): - if len(self.source_datapipe) < num_take: - num_take = len(self.source_datapipe) - # TODO: might be weird to not take `num_workers` into account - return num_take - - -def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]: - dp = IoPathFileLister(root=root) - dp = SharderDataPipe(dp) - dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE) - dp = IoPathFileOpener(dp, mode="rb") - dp = PicklerDataPipe(dp) - # dp = dp.cycle(2) - dp = TakerDataPipe(dp, dataset_size) - return dp - - -def read_flo(file: BinaryIO) -> torch.Tensor: - if file.read(4) != b"PIEH": - raise ValueError("Magic number incorrect. 
Invalid .flo file") - - width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2) - flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2) - return flow.reshape((height, width, 2)).permute((2, 0, 1)) - - -def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter: - return ShardingFilter(datapipe) - - -def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: - return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) - - -def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]: - path = BUILTIN_DIR / f"{name}.categories" - with open(path, newline="") as file: - rows = list(csv.reader(file)) - rows = [row[0] if len(row) == 1 else row for row in rows] - return rows diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py deleted file mode 100644 index af4ede38dc0..00000000000 --- a/torchvision/prototype/datasets/utils/_resource.py +++ /dev/null @@ -1,235 +0,0 @@ -import abc -import hashlib -import itertools -import pathlib -from typing import Any, Callable, IO, Literal, NoReturn, Optional, Sequence, Set, Tuple, Union -from urllib.parse import urlparse - -from torchdata.datapipes.iter import ( - FileLister, - FileOpener, - IterableWrapper, - IterDataPipe, - RarArchiveLoader, - TarArchiveLoader, - ZipArchiveLoader, -) -from torchvision.datasets.utils import ( - _decompress, - _detect_file_type, - _get_google_drive_file_id, - _get_redirect_url, - download_file_from_google_drive, - download_url, - extract_archive, -) - - -class OnlineResource(abc.ABC): - def __init__( - self, - *, - file_name: str, - sha256: Optional[str] = None, - preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None, - ) -> None: - self.file_name = file_name - self.sha256 = sha256 - - if isinstance(preprocess, str): - if preprocess == "decompress": - preprocess = self._decompress - elif preprocess == "extract": - preprocess = self._extract - else: - raise ValueError( - f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string," - f"but got {preprocess} instead." 
- ) - self._preprocess = preprocess - - @staticmethod - def _extract(file: pathlib.Path) -> None: - extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False) - - @staticmethod - def _decompress(file: pathlib.Path) -> None: - _decompress(str(file), remove_finished=True) - - def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]: - if path.is_dir(): - return FileOpener(FileLister(str(path), recursive=True), mode="rb") - - dp = FileOpener(IterableWrapper((str(path),)), mode="rb") - - archive_loader = self._guess_archive_loader(path) - if archive_loader: - dp = archive_loader(dp) - - return dp - - _ARCHIVE_LOADERS = { - ".tar": TarArchiveLoader, - ".zip": ZipArchiveLoader, - ".rar": RarArchiveLoader, - } - - def _guess_archive_loader( - self, path: pathlib.Path - ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]: - try: - _, archive_type, _ = _detect_file_type(path.name) - except RuntimeError: - return None - return self._ARCHIVE_LOADERS.get(archive_type) # type: ignore[arg-type] - - def load( - self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False - ) -> IterDataPipe[Tuple[str, IO]]: - root = pathlib.Path(root) - path = root / self.file_name - - # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories - # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which - # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz. - stem = path.name.replace("".join(path.suffixes), "") - - def find_candidates() -> Set[pathlib.Path]: - # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder - # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the - # test split of the stanford-cars dataset uses the files - # - cars_test.tgz - # - cars_test_annos_withlabels.mat - # Globbing for `"cars_test*"` picks up both. - candidates = {file for file in path.parent.glob(f"{stem}.*")} - folder_candidate = path.parent / stem - if folder_candidate.exists(): - candidates.add(folder_candidate) - - return candidates - - candidates = find_candidates() - - if not candidates: - self.download(root, skip_integrity_check=skip_integrity_check) - if self._preprocess is not None: - self._preprocess(path) - candidates = find_candidates() - - # We use the path with the fewest suffixes. This gives us the - # extracted > decompressed > raw - # priority that we want for the best I/O performance. 
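# Illustrative note (not part of the original file): for a resource downloaded as
# "foo.tar.gz", find_candidates() may return {foo.tar.gz, foo.tar, foo}; taking the
# min by len(candidate.suffixes) below picks "foo" (0 suffixes) over "foo.tar" (1)
# and "foo.tar.gz" (2), i.e. the already-extracted form is preferred.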
- return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes))) - - @abc.abstractmethod - def _download(self, root: pathlib.Path) -> None: - pass - - def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path: - root = pathlib.Path(root) - self._download(root) - path = root / self.file_name - if self.sha256 and not skip_integrity_check: - self._check_sha256(path) - return path - - def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None: - hash = hashlib.sha256() - with open(path, "rb") as file: - for chunk in iter(lambda: file.read(chunk_size), b""): - hash.update(chunk) - sha256 = hash.hexdigest() - if sha256 != self.sha256: - raise RuntimeError( - f"After the download, the SHA256 checksum of {path} didn't match the expected one: " - f"{sha256} != {self.sha256}" - ) - - -class HttpResource(OnlineResource): - def __init__( - self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any - ) -> None: - super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs) - self.url = url - self.mirrors = mirrors - self._resolved = False - - def resolve(self) -> OnlineResource: - if self._resolved: - return self - - redirect_url = _get_redirect_url(self.url) - if redirect_url == self.url: - self._resolved = True - return self - - meta = { - attr.lstrip("_"): getattr(self, attr) - for attr in ( - "file_name", - "sha256", - "_preprocess", - ) - } - - gdrive_id = _get_google_drive_file_id(redirect_url) - if gdrive_id: - return GDriveResource(gdrive_id, **meta) - - http_resource = HttpResource(redirect_url, **meta) - http_resource._resolved = True - return http_resource - - def _download(self, root: pathlib.Path) -> None: - if not self._resolved: - return self.resolve()._download(root) - - for url in itertools.chain((self.url,), self.mirrors): - - try: - download_url(url, str(root), filename=self.file_name, md5=None) - # TODO: make this more precise - except Exception: - continue - - return - else: - # TODO: make this more informative - raise RuntimeError("Download failed!") - - -class GDriveResource(OnlineResource): - def __init__(self, id: str, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.id = id - - def _download(self, root: pathlib.Path) -> None: - download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None) - - -class ManualDownloadResource(OnlineResource): - def __init__(self, instructions: str, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.instructions = instructions - - def _download(self, root: pathlib.Path) -> NoReturn: - raise RuntimeError( - f"The file {self.file_name} cannot be downloaded automatically. " - f"Please follow the instructions below and place it in {root}\n\n" - f"{self.instructions}" - ) - - -class KaggleDownloadResource(ManualDownloadResource): - def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None: - instructions = "\n".join( - ( - "1. Register and login at https://www.kaggle.com", - f"2. Navigate to {challenge_url}", - "3. Click 'Join Competition' and follow the instructions there", - "4. Navigate to the 'Data' tab", - f"5. 
Select {file_name} in the 'Data Explorer' and click the download button", - ) - ) - super().__init__(instructions, file_name=file_name, **kwargs) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py deleted file mode 100644 index 8b8eda9e9d2..00000000000 --- a/torchvision/prototype/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import depth diff --git a/torchvision/prototype/models/depth/__init__.py b/torchvision/prototype/models/depth/__init__.py deleted file mode 100644 index 0ff02953c24..00000000000 --- a/torchvision/prototype/models/depth/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import stereo diff --git a/torchvision/prototype/models/depth/stereo/__init__.py b/torchvision/prototype/models/depth/stereo/__init__.py deleted file mode 100644 index cd075ca2b9e..00000000000 --- a/torchvision/prototype/models/depth/stereo/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .raft_stereo import * -from .crestereo import * diff --git a/torchvision/prototype/models/depth/stereo/crestereo.py b/torchvision/prototype/models/depth/stereo/crestereo.py deleted file mode 100644 index 89a23aae7f2..00000000000 --- a/torchvision/prototype/models/depth/stereo/crestereo.py +++ /dev/null @@ -1,1463 +0,0 @@ -import math -from functools import partial -from typing import Callable, Dict, Iterable, List, Optional, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.models.optical_flow.raft as raft -from torch import Tensor -from torchvision.models._api import register_model, Weights, WeightsEnum -from torchvision.models._utils import handle_legacy_interface -from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow -from torchvision.ops import Conv2dNormActivation -from torchvision.prototype.transforms._presets import StereoMatching - -all = ( - "CREStereo", - "CREStereo_Base_Weights", - "crestereo_base", -) - - -class ConvexMaskPredictor(nn.Module): - def __init__( - self, - *, - in_channels: int, - hidden_size: int, - upsample_factor: int, - multiplier: float = 0.25, - ) -> None: - - super().__init__() - self.mask_head = nn.Sequential( - Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3), - # https://arxiv.org/pdf/2003.12039.pdf (Annex section B) for the - # following convolution output size - nn.Conv2d(hidden_size, upsample_factor**2 * 9, 1, padding=0), - ) - - self.multiplier = multiplier - - def forward(self, x: Tensor) -> Tensor: - x = self.mask_head(x) * self.multiplier - return x - - -def get_correlation( - left_feature: Tensor, - right_feature: Tensor, - window_size: Tuple[int, int] = (3, 3), - dilate: Tuple[int, int] = (1, 1), -) -> Tensor: - """Function that computes a correlation product between the left and right features. - - The correlation is computed in a sliding window fashion, namely the left features are fixed - and for each ``(i, j)`` location we compute the correlation with a sliding window anchored in - ``(i, j)`` from the right feature map. The sliding window selects pixels obtained in the range of the sliding - window; i.e ``(i - window_size // 2, i + window_size // 2)`` respectively ``(j - window_size // 2, j + window_size // 2)``. 
- """ - - B, C, H, W = left_feature.shape - - di_y, di_x = dilate[0], dilate[1] - pad_y, pad_x = window_size[0] // 2 * di_y, window_size[1] // 2 * di_x - - right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate") - # in order to vectorize the correlation computation over all pixel candidates - # we create multiple shifted right images which we stack on an extra dimension - right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate) - # torch unfold returns a tensor of shape [B, flattened_values, n_selections] - right_padded = right_padded.permute(0, 2, 1) - # we then reshape back into [B, n_views, C, H, W] - right_padded = right_padded.reshape(B, (window_size[0] * window_size[1]), C, H, W) - # we expand the left features for broadcasting - left_feature = left_feature.unsqueeze(1) - # this will compute an element-wise product between [B, 1, C, H, W] * [B, n_views, C, H, W] - # to obtain correlations over the pixel candidates we perform a mean on the C dimension - correlation = torch.mean(left_feature * right_padded, dim=2, keepdim=False) - # the final correlation tensor shape will be [B, n_views, H, W] - # where on the i-th position of the n_views dimension we will have - # the correlation value between the left pixel - # and the i-th candidate on the right feature map - return correlation - - - def _check_window_specs( - search_window_1d: Tuple[int, int] = (1, 9), - search_dilate_1d: Tuple[int, int] = (1, 1), - search_window_2d: Tuple[int, int] = (3, 3), - search_dilate_2d: Tuple[int, int] = (1, 1), -) -> None: - - if not np.prod(search_window_1d) == np.prod(search_window_2d): - raise ValueError( - f"The 1D and 2D windows should contain the same number of elements. " - f"1D shape: {search_window_1d} 2D shape: {search_window_2d}" - ) - if not np.prod(search_window_1d) % 2 == 1: - raise ValueError( - f"Search windows should contain an odd number of elements in them. " - f"Window of shape {search_window_1d} has {np.prod(search_window_1d)} elements." - ) - if not any(size == 1 for size in search_window_1d): - raise ValueError(f"The 1D search window should have at least one size equal to 1. 1D shape: {search_window_1d}") - if any(size == 1 for size in search_window_2d): - raise ValueError( - f"The 2D search window should have all dimensions greater than 1. 2D shape: {search_window_2d}" - ) - if any(dilate < 1 for dilate in search_dilate_1d): - raise ValueError( - f"The 1D search dilation should have all elements equal to or greater than 1. 1D shape: {search_dilate_1d}" - ) - if any(dilate < 1 for dilate in search_dilate_2d): - raise ValueError( - f"The 2D search dilation should have all elements equal to or greater than 1. 
2D shape: {search_dilate_2d}" - ) - - - class IterativeCorrelationLayer(nn.Module): - def __init__( - self, - groups: int = 4, - search_window_1d: Tuple[int, int] = (1, 9), - search_dilate_1d: Tuple[int, int] = (1, 1), - search_window_2d: Tuple[int, int] = (3, 3), - search_dilate_2d: Tuple[int, int] = (1, 1), - ) -> None: - - super().__init__() - _check_window_specs( - search_window_1d=search_window_1d, - search_dilate_1d=search_dilate_1d, - search_window_2d=search_window_2d, - search_dilate_2d=search_dilate_2d, - ) - self.search_pixels = np.prod(search_window_1d) - self.groups = groups - - # two selection tables for dealing with the window_type argument in the forward function - self.patch_sizes = { - "2d": [search_window_2d for _ in range(self.groups)], - "1d": [search_window_1d for _ in range(self.groups)], - } - - self.dilate_sizes = { - "2d": [search_dilate_2d for _ in range(self.groups)], - "1d": [search_dilate_1d for _ in range(self.groups)], - } - - def forward(self, left_feature: Tensor, right_feature: Tensor, flow: Tensor, window_type: str = "1d") -> Tensor: - """Function that computes one pass of non-offset group-wise correlation""" - coords = make_coords_grid( - left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device) - ) - - # we offset the coordinate grid in the flow direction - coords = coords + flow - coords = coords.permute(0, 2, 3, 1) - # resample right features according to the offset grid - right_feature = grid_sample(right_feature, coords, mode="bilinear", align_corners=True) - - # the window_type argument decides on how many axes - # we perform candidate search. See section 3.1 ``Deformable search window`` & Figure 4 in the paper. - patch_size_list = self.patch_sizes[window_type] - dilate_size_list = self.dilate_sizes[window_type] - - # chunking the left and right feature to perform group-wise correlation - # mechanism similar to GroupNorm. See section 3.1 ``Group-wise correlation``.
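# Illustrative note (not part of the original file): with e.g. C=256 channels and
# groups=4, each chunk below has shape [B, 64, H, W]; get_correlation then yields a
# [B, n_views, H, W] map per group, and the concatenation on dim=1 gives
# [B, groups * n_views, H, W].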
- left_groups = torch.chunk(left_feature, self.groups, dim=1) - right_groups = torch.chunk(right_feature, self.groups, dim=1) - - correlations = [] - # this boils down to rather than performing the correlation product - # over the entire C dimensions, we use subsets of C to get multiple correlation sets - for i in range(len(patch_size_list)): - correlation = get_correlation(left_groups[i], right_groups[i], patch_size_list[i], dilate_size_list[i]) - correlations.append(correlation) - final_correlations = torch.cat(correlations, dim=1) - return final_correlations - - -class AttentionOffsetCorrelationLayer(nn.Module): - def __init__( - self, - groups: int = 4, - attention_module: Optional[nn.Module] = None, - search_window_1d: Tuple[int, int] = (1, 9), - search_dilate_1d: Tuple[int, int] = (1, 1), - search_window_2d: Tuple[int, int] = (3, 3), - search_dilate_2d: Tuple[int, int] = (1, 1), - ) -> None: - super().__init__() - _check_window_specs( - search_window_1d=search_window_1d, - search_dilate_1d=search_dilate_1d, - search_window_2d=search_window_2d, - search_dilate_2d=search_dilate_2d, - ) - # convert to python scalar - self.search_pixels = int(np.prod(search_window_1d)) - self.groups = groups - - # two selection tables for dealing with the small_patch argument in the forward function - self.patch_sizes = { - "2d": [search_window_2d for _ in range(self.groups)], - "1d": [search_window_1d for _ in range(self.groups)], - } - - self.dilate_sizes = { - "2d": [search_dilate_2d for _ in range(self.groups)], - "1d": [search_dilate_1d for _ in range(self.groups)], - } - - self.attention_module = attention_module - - def forward( - self, - left_feature: Tensor, - right_feature: Tensor, - flow: Tensor, - extra_offset: Tensor, - window_type: str = "1d", - ) -> Tensor: - """Function that computes 1 pass of offsetted Group-Wise correlation - - If the class was provided with an attention layer, the left and right feature maps - will be passed through a transformer first - """ - B, C, H, W = left_feature.shape - - if self.attention_module is not None: - # prepare for transformer required input shapes - left_feature = left_feature.permute(0, 2, 3, 1).reshape(B, H * W, C) - right_feature = right_feature.permute(0, 2, 3, 1).reshape(B, H * W, C) - # this can be either self attention or cross attention, hence the tuple return - left_feature, right_feature = self.attention_module(left_feature, right_feature) - left_feature = left_feature.reshape(B, H, W, C).permute(0, 3, 1, 2) - right_feature = right_feature.reshape(B, H, W, C).permute(0, 3, 1, 2) - - left_groups = torch.chunk(left_feature, self.groups, dim=1) - right_groups = torch.chunk(right_feature, self.groups, dim=1) - - num_search_candidates = self.search_pixels - # for each pixel (i, j) we have a number of search candidates - # thus, for each candidate we should have an X-axis and Y-axis offset value - extra_offset = extra_offset.reshape(B, num_search_candidates, 2, H, W).permute(0, 1, 3, 4, 2) - - patch_size_list = self.patch_sizes[window_type] - dilate_size_list = self.dilate_sizes[window_type] - - group_channels = C // self.groups - correlations = [] - - for i in range(len(patch_size_list)): - left_group, right_group = left_groups[i], right_groups[i] - patch_size, dilate = patch_size_list[i], dilate_size_list[i] - - di_y, di_x = dilate - ps_y, ps_x = patch_size - # define the search based on the window patch shape - ry, rx = ps_y // 2 * di_y, ps_x // 2 * di_x - - # base offsets for search (i.e. 
where to look on the search index) - x_grid, y_grid = torch.meshgrid( - torch.arange(-rx, rx + 1, di_x), torch.arange(-ry, ry + 1, di_y), indexing="xy" - ) - x_grid, y_grid = x_grid.to(flow.device), y_grid.to(flow.device) - offsets = torch.stack((x_grid, y_grid)) - offsets = offsets.reshape(2, -1).permute(1, 0) - - for d in (0, 2, 3): - offsets = offsets.unsqueeze(d) - # extra offsets for search (i.e. deformed search indexes. Similar concept to deformable convolutions) - offsets = offsets + extra_offset - - coords = ( - make_coords_grid( - left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device) - ) - + flow - ) - coords = coords.permute(0, 2, 3, 1).unsqueeze(1) - coords = coords + offsets - coords = coords.reshape(B, -1, W, 2) - - right_group = grid_sample(right_group, coords, mode="bilinear", align_corners=True) - # we do not need to perform any window shifting because the grid sample op - # will return a multi-view right based on the num_search_candidates dimension in the offsets - right_group = right_group.reshape(B, group_channels, -1, H, W) - left_group = left_group.reshape(B, group_channels, -1, H, W) - correlation = torch.mean(left_group * right_group, dim=1) - correlations.append(correlation) - - final_correlation = torch.cat(correlations, dim=1) - return final_correlation - - -class AdaptiveGroupCorrelationLayer(nn.Module): - """ - Container for computing various correlation types between a left and right feature map. - This module does not contain any optimisable parameters, it's solely a collection of ops. - We wrap in a nn.Module for torch.jit.script compatibility - - Adaptive Group Correlation operations from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf - - Canonical reference implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/corr.py - """ - - def __init__( - self, - iterative_correlation_layer: IterativeCorrelationLayer, - attention_offset_correlation_layer: AttentionOffsetCorrelationLayer, - ) -> None: - super().__init__() - - self.iterative_correlation_layer = iterative_correlation_layer - self.attention_offset_correlation_layer = attention_offset_correlation_layer - - def forward( - self, - left_features: Tensor, - right_features: Tensor, - flow: torch.Tensor, - extra_offset: Optional[Tensor], - window_type: str = "1d", - iter_mode: bool = False, - ) -> Tensor: - if iter_mode or extra_offset is None: - corr = self.iterative_correlation_layer(left_features, right_features, flow, window_type) - else: - corr = self.attention_offset_correlation_layer( - left_features, right_features, flow, extra_offset, window_type - ) # type: ignore - return corr - - -def elu_feature_map(x: Tensor) -> Tensor: - """Elu feature map operation from: https://arxiv.org/pdf/2006.16236.pdf""" - return F.elu(x) + 1 - - -class LinearAttention(nn.Module): - """ - Linear attention operation from: https://arxiv.org/pdf/2006.16236.pdf - Canonical implementation reference: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py - LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py - """ - - def __init__(self, eps: float = 1e-6, feature_map_fn: Callable[[Tensor], Tensor] = elu_feature_map) -> None: - super().__init__() - self.eps = eps - self.feature_map_fn = feature_map_fn - - 
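For context, a hedged standalone sketch (not part of the original module) of the associativity trick that the forward below exploits: with a positive feature map phi, the softmax-style product (phi(Q) phi(K)^T) V is replaced by phi(Q) (phi(K)^T V), which is linear rather than quadratic in sequence length. Tensor sizes here are arbitrary and the masking / fp16 rescaling of the real module are omitted:

import torch
import torch.nn.functional as F

def phi(x):  # same role as elu_feature_map above
    return F.elu(x) + 1

N, L, S, H, D = 2, 128, 96, 4, 32
q = torch.randn(N, L, H, D)
k = torch.randn(N, S, H, D)
v = torch.randn(N, S, H, D)

# quadratic form: materialises an [N, L, S, H] attention matrix
attn = torch.einsum("nlhd,nshd->nlsh", phi(q), phi(k))
out_quadratic = torch.einsum("nlsh,nshv->nlhv", attn / attn.sum(dim=2, keepdim=True), v)

# linear form: contract keys with values first, never forming the L x S matrix
kv = torch.einsum("nshd,nshv->nhdv", phi(k), v)
z = 1.0 / torch.einsum("nlhd,nhd->nlh", phi(q), phi(k).sum(dim=1))
out_linear = torch.einsum("nlhd,nhdv,nlh->nlhv", phi(q), kv, z)

assert torch.allclose(out_quadratic, out_linear, atol=1e-4)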
def forward( - self, - queries: Tensor, - keys: Tensor, - values: Tensor, - q_mask: Optional[Tensor] = None, - kv_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - queries (torch.Tensor): [N, S1, H, D] - keys (torch.Tensor): [N, S2, H, D] - values (torch.Tensor): [N, S2, H, D] - q_mask (torch.Tensor): [N, S1] (optional) - kv_mask (torch.Tensor): [N, S2] (optional) - Returns: - queried_values (torch.Tensor): [N, S1, H, D] - """ - queries = self.feature_map_fn(queries) - keys = self.feature_map_fn(keys) - - if q_mask is not None: - queries = queries * q_mask[:, :, None, None] - if kv_mask is not None: - keys = keys * kv_mask[:, :, None, None] - values = values * kv_mask[:, :, None, None] - - # mitigates fp16 overflows - values_length = values.shape[1] - values = values / values_length - kv = torch.einsum("NSHD, NSHV -> NHDV", keys, values) - z = 1 / (torch.einsum("NLHD, NHD -> NLH", queries, keys.sum(dim=1)) + self.eps) - # rescale at the end to account for fp16 mitigation - queried_values = torch.einsum("NLHD, NHDV, NLH -> NLHV", queries, kv, z) * values_length - return queried_values - - -class SoftmaxAttention(nn.Module): - """ - A simple softmax attention operation - LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py - """ - - def __init__(self, dropout: float = 0.0) -> None: - super().__init__() - self.dropout = nn.Dropout(dropout) if dropout else nn.Identity() - - def forward( - self, - queries: Tensor, - keys: Tensor, - values: Tensor, - q_mask: Optional[Tensor] = None, - kv_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Computes classical softmax full-attention between all queries and keys. - - Args: - queries (torch.Tensor): [N, S1, H, D] - keys (torch.Tensor): [N, S2, H, D] - values (torch.Tensor): [N, S2, H, D] - q_mask (torch.Tensor): [N, S1] (optional) - kv_mask (torch.Tensor): [N, S2] (optional) - Returns: - queried_values: [N, S1, H, D] - """ - - scale_factor = 1.0 / queries.shape[3] ** 0.5 # irsqrt(D) scaling - queries = queries * scale_factor - - qk = torch.einsum("NLHD, NSHD -> NLSH", queries, keys) - if kv_mask is not None and q_mask is not None: - qk.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf")) - - attention = torch.softmax(qk, dim=2) - attention = self.dropout(attention) - - queried_values = torch.einsum("NLSH, NSHD -> NLHD", attention, values) - return queried_values - - -class PositionalEncodingSine(nn.Module): - """ - Sinusoidal positional encodings - - Using the scaling term from https://github.com/megvii-research/CREStereo/blob/master/nets/attention/position_encoding.py - Reference implementation from https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/position_encoding.py#L28-L48 - """ - - def __init__(self, dim_model: int, max_size: int = 256) -> None: - super().__init__() - self.dim_model = dim_model - self.max_size = max_size - # pre-registered for memory efficiency during forward pass - pe = self._make_pe_of_size(self.max_size) - self.register_buffer("pe", pe) - - def _make_pe_of_size(self, size: int) -> Tensor: - pe = torch.zeros((self.dim_model, *(size, size)), dtype=torch.float32) - y_positions = torch.ones((size, size)).cumsum(0).float().unsqueeze(0) - x_positions = torch.ones((size, size)).cumsum(1).float().unsqueeze(0) - div_term = torch.exp(torch.arange(0.0, self.dim_model // 2, 2) * (-math.log(10000.0) / self.dim_model // 2)) - div_term = div_term[:, 
None, None] - pe[0::4, :, :] = torch.sin(x_positions * div_term) - pe[1::4, :, :] = torch.cos(x_positions * div_term) - pe[2::4, :, :] = torch.sin(y_positions * div_term) - pe[3::4, :, :] = torch.cos(y_positions * div_term) - pe = pe.unsqueeze(0) - return pe - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: [B, C, H, W] - - Returns: - x: [B, C, H, W] - """ - torch._assert( - len(x.shape) == 4, - f"PositionalEncodingSine requires a 4-D dimensional input. Provided tensor is of shape {x.shape}", - ) - - B, C, H, W = x.shape - return x + self.pe[:, :, :H, :W] # type: ignore - - -class LocalFeatureEncoderLayer(nn.Module): - """ - LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf - Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py - """ - - def __init__( - self, - *, - dim_model: int, - num_heads: int, - attention_module: Callable[..., nn.Module] = LinearAttention, - ) -> None: - super().__init__() - - self.attention_op = attention_module() - - if not isinstance(self.attention_op, (LinearAttention, SoftmaxAttention)): - raise ValueError( - f"attention_module must be an instance of LinearAttention or SoftmaxAttention. Got {type(self.attention_op)}" - ) - - self.dim_head = dim_model // num_heads - self.num_heads = num_heads - - # multi-head attention - self.query_proj = nn.Linear(dim_model, dim_model, bias=False) - self.key_proj = nn.Linear(dim_model, dim_model, bias=False) - self.value_proj = nn.Linear(dim_model, dim_model, bias=False) - self.merge = nn.Linear(dim_model, dim_model, bias=False) - - # feed forward network - self.ffn = nn.Sequential( - nn.Linear(dim_model * 2, dim_model * 2, bias=False), - nn.ReLU(), - nn.Linear(dim_model * 2, dim_model, bias=False), - ) - - # norm layers - self.attention_norm = nn.LayerNorm(dim_model) - self.ffn_norm = nn.LayerNorm(dim_model) - - def forward( - self, x: Tensor, source: Tensor, x_mask: Optional[Tensor] = None, source_mask: Optional[Tensor] = None - ) -> Tensor: - """ - Args: - x (torch.Tensor): [B, S1, D] - source (torch.Tensor): [B, S2, D] - x_mask (torch.Tensor): [B, S1] (optional) - source_mask (torch.Tensor): [B, S2] (optional) - """ - B, S, D = x.shape - queries, keys, values = x, source, source - - queries = self.query_proj(queries).reshape(B, S, self.num_heads, self.dim_head) - keys = self.key_proj(keys).reshape(B, S, self.num_heads, self.dim_head) - values = self.value_proj(values).reshape(B, S, self.num_heads, self.dim_head) - - # attention operation - message = self.attention_op(queries, keys, values, x_mask, source_mask) - # concatenating attention heads together before passing through projection layer - message = self.merge(message.reshape(B, S, D)) - message = self.attention_norm(message) - - # ffn operation - message = self.ffn(torch.cat([x, message], dim=2)) - message = self.ffn_norm(message) - - return x + message - - -class LocalFeatureTransformer(nn.Module): - """ - LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf - Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py - """ - - def __init__( - self, - *, - dim_model: int, - num_heads: int, - attention_directions: List[str], - attention_module: Callable[..., nn.Module] = LinearAttention, - ) -> None: - super(LocalFeatureTransformer, self).__init__() - - self.attention_module = attention_module - self.attention_directions = attention_directions - for direction in attention_directions: - if direction not in ["self", 
"cross"]: - raise ValueError( - f"Attention direction {direction} unsupported. LocalFeatureTransformer accepts only ``attention_type`` in ``[self, cross]``." - ) - - self.layers = nn.ModuleList( - [ - LocalFeatureEncoderLayer(dim_model=dim_model, num_heads=num_heads, attention_module=attention_module) - for _ in attention_directions - ] - ) - - def forward( - self, - left_features: Tensor, - right_features: Tensor, - left_mask: Optional[Tensor] = None, - right_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - left_features (torch.Tensor): [N, S1, D] - right_features (torch.Tensor): [N, S2, D] - left_mask (torch.Tensor): [N, S1] (optional) - right_mask (torch.Tensor): [N, S2] (optional) - Returns: - left_features (torch.Tensor): [N, S1, D] - right_features (torch.Tensor): [N, S2, D] - """ - - torch._assert( - left_features.shape[2] == right_features.shape[2], - f"left_features and right_features should have the same embedding dimensions. left_features: {left_features.shape[2]} right_features: {right_features.shape[2]}", - ) - - for idx, layer in enumerate(self.layers): - attention_direction = self.attention_directions[idx] - - if attention_direction == "self": - left_features = layer(left_features, left_features, left_mask, left_mask) - right_features = layer(right_features, right_features, right_mask, right_mask) - - elif attention_direction == "cross": - left_features = layer(left_features, right_features, left_mask, right_mask) - right_features = layer(right_features, left_features, right_mask, left_mask) - - return left_features, right_features - - -class PyramidDownsample(nn.Module): - """ - A simple wrapper that return and Avg Pool feature pyramid based on the provided scales. - Implicitly returns the input as well. - """ - - def __init__(self, factors: Iterable[int]) -> None: - super().__init__() - self.factors = factors - - def forward(self, x: torch.Tensor) -> List[Tensor]: - results = [x] - for factor in self.factors: - results.append(F.avg_pool2d(x, kernel_size=factor, stride=factor)) - return results - - -class CREStereo(nn.Module): - """ - Implements CREStereo from the `"Practical Stereo Matching via Cascaded Recurrent Network - With Adaptive Correlation" <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_ paper. - Args: - feature_encoder (raft.FeatureEncoder): Raft-like Feature Encoder module extract low-level features from inputs. - update_block (raft.UpdateBlock): Raft-like Update Block which recursively refines a flow-map. - flow_head (raft.FlowHead): Raft-like Flow Head which predics a flow-map from some inputs. - self_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs self attention on the two feature maps. - cross_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs cross attention between the two feature maps - used in the Adaptive Group Correlation module. - feature_downsample_rates (List[int]): The downsample rates used to build a feature pyramid from the outputs of the `feature_encoder`. Default: [2, 4] - correlation_groups (int): In how many groups should the features be split when computer per-pixel correlation. Defaults 4. - search_window_1d (Tuple[int, int]): The alternate search window size in the x and y directions for the 1D case. Defaults to (1, 9). - search_dilate_1d (Tuple[int, int]): The dilation used in the `search_window_1d` when selecting pixels. 
Similar to `nn.Conv2d` dilate. Defaults to (1, 1). - search_window_2d (Tuple[int, int]): The alternate search window size in the x and y directions for the 2D case. Defaults to (3, 3). - search_dilate_2d (Tuple[int, int]): The dilation used in the `search_window_2d` when selecting pixels. Similar to `nn.Conv2d` dilate. Defaults to (1, 1). - """ - - def __init__( - self, - *, - feature_encoder: raft.FeatureEncoder, - update_block: raft.UpdateBlock, - flow_head: raft.FlowHead, - self_attn_block: LocalFeatureTransformer, - cross_attn_block: LocalFeatureTransformer, - feature_downsample_rates: Tuple[int, ...] = (2, 4), - correlation_groups: int = 4, - search_window_1d: Tuple[int, int] = (1, 9), - search_dilate_1d: Tuple[int, int] = (1, 1), - search_window_2d: Tuple[int, int] = (3, 3), - search_dilate_2d: Tuple[int, int] = (1, 1), - ) -> None: - super().__init__() - self.output_channels = 2 - - self.feature_encoder = feature_encoder - self.update_block = update_block - self.flow_head = flow_head - self.self_attn_block = self_attn_block - - # average pooling for the feature encoder outputs - self.downsampling_pyramid = PyramidDownsample(feature_downsample_rates) - self.downsampling_factors: List[int] = [feature_encoder.downsample_factor] - base_downsample_factor: int = self.downsampling_factors[0] - for rate in feature_downsample_rates: - self.downsampling_factors.append(base_downsample_factor * rate) - - # output resolution tracking - self.resolutions: List[str] = [f"1 / {factor}" for factor in self.downsampling_factors] - self.search_pixels = int(np.prod(search_window_1d)) - - # flow convex upsampling mask predictor - self.mask_predictor = ConvexMaskPredictor( - in_channels=feature_encoder.output_dim // 2, - hidden_size=feature_encoder.output_dim, - upsample_factor=feature_encoder.downsample_factor, - multiplier=0.25, - ) - - # offsets modules for offsetted feature selection - self.offset_convs = nn.ModuleDict() - self.correlation_layers = nn.ModuleDict() - - offset_conv_layer = partial( - Conv2dNormActivation, - in_channels=feature_encoder.output_dim, - out_channels=self.search_pixels * 2, - norm_layer=None, - activation_layer=None, - ) - - # populate the dicts in top to bottom order - # useful for iterating through torch.jit.script module given the network forward pass - # - # Ignore the largest resolution. We handle that separately due to torch.jit.script - # not being able to access to runtime generated keys in ModuleDicts. 
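The TorchScript limitation referenced in the comment above is that an nn.ModuleDict cannot be indexed with a key that is only known at runtime; it has to be iterated instead. A minimal sketch of that pattern, using illustrative module names, key names and channel counts rather than the ones in this file:

from typing import Dict

import torch
import torch.nn as nn


class PerResolutionOffsets(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # keys are fixed at construction time (simplified stand-ins for the "1 / N" names used above)
        self.offset_convs = nn.ModuleDict(
            {
                "1_8": nn.Conv2d(8, 2, kernel_size=3, padding=1),
                "1_16": nn.Conv2d(8, 2, kernel_size=3, padding=1),
            }
        )

    def forward(self, pyramid: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        offsets: Dict[str, torch.Tensor] = {}
        # scriptable: iterate the ModuleDict rather than indexing it with a runtime-computed key
        for resolution, conv in self.offset_convs.items():
            out = conv(pyramid[resolution])
            # squash the raw offsets into [-1, 1], as the deleted code does before sampling
            offsets[resolution] = (torch.sigmoid(out) - 0.5) * 2.0
        return offsets


scripted = torch.jit.script(PerResolutionOffsets())  # compiles because no runtime ModuleDict indexing occurs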
- # This way, we can keep a generic way of processing all pyramid levels but except - # the final one - iterative_correlation_layer = partial( - IterativeCorrelationLayer, - groups=correlation_groups, - search_window_1d=search_window_1d, - search_dilate_1d=search_dilate_1d, - search_window_2d=search_window_2d, - search_dilate_2d=search_dilate_2d, - ) - - attention_offset_correlation_layer = partial( - AttentionOffsetCorrelationLayer, - groups=correlation_groups, - search_window_1d=search_window_1d, - search_dilate_1d=search_dilate_1d, - search_window_2d=search_window_2d, - search_dilate_2d=search_dilate_2d, - ) - - for idx, resolution in enumerate(reversed(self.resolutions[1:])): - # the largest resolution does use offset convolutions for sampling grid coords - offset_conv = None if idx == len(self.resolutions) - 1 else offset_conv_layer() - if offset_conv: - self.offset_convs[resolution] = offset_conv - # only the lowest resolution uses the cross attention module when computing correlation scores - attention_module = cross_attn_block if idx == 0 else None - self.correlation_layers[resolution] = AdaptiveGroupCorrelationLayer( - iterative_correlation_layer=iterative_correlation_layer(), - attention_offset_correlation_layer=attention_offset_correlation_layer( - attention_module=attention_module - ), - ) - - # correlation layer for the largest resolution - self.max_res_correlation_layer = AdaptiveGroupCorrelationLayer( - iterative_correlation_layer=iterative_correlation_layer(), - attention_offset_correlation_layer=attention_offset_correlation_layer(), - ) - - # simple 2D Postional Encodings - self.positional_encodings = PositionalEncodingSine(feature_encoder.output_dim) - - def _get_window_type(self, iteration: int) -> str: - return "1d" if iteration % 2 == 0 else "2d" - - def forward( - self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10 - ) -> List[Tensor]: - features = torch.cat([left_image, right_image], dim=0) - features = self.feature_encoder(features) - left_features, right_features = features.chunk(2, dim=0) - - # update block network state and input context are derived from the left feature map - net, ctx = left_features.chunk(2, dim=1) - net = torch.tanh(net) - ctx = torch.relu(ctx) - - # will output lists of tensor. 
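The pyramid that the comment above refers to is simply the input tensor followed by average-pooled copies, matching what PyramidDownsample returns; a small self-contained sketch with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 256, 96, 128)     # features at the base 1/4 resolution
factors = (2, 4)                     # the default feature_downsample_rates

pyramid = [x]                        # the input itself is always returned first
for factor in factors:
    pyramid.append(F.avg_pool2d(x, kernel_size=factor, stride=factor))

print([tuple(t.shape[-2:]) for t in pyramid])
# [(96, 128), (48, 64), (24, 32)] -> the 1/4, 1/8 and 1/16 levels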
- l_pyramid = self.downsampling_pyramid(left_features) - r_pyramid = self.downsampling_pyramid(right_features) - net_pyramid = self.downsampling_pyramid(net) - ctx_pyramid = self.downsampling_pyramid(ctx) - - # we store in reversed order because we process the pyramid from top to bottom - l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)} - r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)} - net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)} - ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)} - - # offsets for sampling pixel candidates in the correlation ops - offsets: Dict[str, Tensor] = {} - for resolution, offset_conv in self.offset_convs.items(): - feature_map = l_pyramid[resolution] - offset = offset_conv(feature_map) - offsets[resolution] = (torch.sigmoid(offset) - 0.5) * 2.0 - - # the smallest resolution is prepared for passing through self attention - min_res = self.resolutions[-1] - max_res = self.resolutions[0] - - B, C, MIN_H, MIN_W = l_pyramid[min_res].shape - # add positional encodings - l_pyramid[min_res] = self.positional_encodings(l_pyramid[min_res]) - r_pyramid[min_res] = self.positional_encodings(r_pyramid[min_res]) - # reshaping for transformer - l_pyramid[min_res] = l_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C) - r_pyramid[min_res] = r_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C) - # perform self attention - l_pyramid[min_res], r_pyramid[min_res] = self.self_attn_block(l_pyramid[min_res], r_pyramid[min_res]) - # now we need to reshape back into [B, C, H, W] format - l_pyramid[min_res] = l_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2) - r_pyramid[min_res] = r_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2) - - predictions: List[Tensor] = [] - flow_estimates: Dict[str, Tensor] = {} - # we added this because of torch.script.jit - # also, the predicition prior is always going to have the - # spatial size of the features outputted by the feature encoder - flow_pred_prior: Tensor = torch.empty( - size=(B, 2, left_features.shape[2], left_features.shape[3]), - dtype=l_pyramid[max_res].dtype, - device=l_pyramid[max_res].device, - ) - - if flow_init is not None: - scale = l_pyramid[max_res].shape[2] / flow_init.shape[2] - # in CREStereo implementation they multiply with -scale instead of scale - # this can be either a downsample or an upsample based on the cascaded inference - # configuration - - # we use a -scale because the flow used inside the network is a negative flow - # from the right to the left, so we flip the flow direction - flow_estimates[max_res] = -scale * F.interpolate( - input=flow_init, - size=l_pyramid[max_res].shape[2:], - mode="bilinear", - align_corners=True, - ) - - # when not provided with a flow prior, we construct one using the lower resolution maps - else: - # initialize a zero flow with the smallest resolution - flow = torch.zeros(size=(B, 2, MIN_H, MIN_W), device=left_features.device, dtype=left_features.dtype) - - # flows from coarse resolutions are refined similarly - # we always need to fetch the next pyramid feature map as well - # when updating coarse resolutions, therefore we create a reversed - # view which has its order synced with the ModuleDict keys iterator - coarse_resolutions: List[str] = self.resolutions[::-1] # using slicing because of torch.jit.script - fine_grained_resolution = max_res - - # set the coarsest flow to the zero flow - 
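The reshaping done above before the self-attention block is the usual flattening of a spatial map into a token sequence and back; a tiny round-trip sketch with illustrative sizes:

import torch

B, C, H, W = 2, 256, 24, 32
x = torch.randn(B, C, H, W)

tokens = x.permute(0, 2, 3, 1).reshape(B, H * W, C)    # [B, H*W, C] sequence for the transformer
# ... attention would run on `tokens` here ...
restored = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2)

assert torch.equal(x, restored)     # pure re-ordering, so the round trip is lossless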
flow_estimates[coarse_resolutions[0]] = flow - - # the correlation layer ModuleDict will contain layers ordered from coarse to fine resolution - # i.e ["1 / 16", "1 / 8", "1 / 4"] - # the correlation layer ModuleDict has layers for all the resolutions except the fine one - # i.e {"1 / 16": Module, "1 / 8": Module} - # for these resolution we perform only half of the number of refinement iterations - for idx, (resolution, correlation_layer) in enumerate(self.correlation_layers.items()): - # compute the scale difference between the first pyramid scale and the current pyramid scale - scale_to_base = l_pyramid[fine_grained_resolution].shape[2] // l_pyramid[resolution].shape[2] - for it in range(num_iters // 2): - # set whether we want to search on (X, Y) axes for correlation or just on X axis - window_type = self._get_window_type(it) - # we consider this a prior, therefore we do not want to back-propagate through it - flow_estimates[resolution] = flow_estimates[resolution].detach() - - correlations = correlation_layer( - l_pyramid[resolution], # left - r_pyramid[resolution], # right - flow_estimates[resolution], - offsets[resolution], - window_type, - ) - - # update the recurrent network state and the flow deltas - net_pyramid[resolution], delta_flow = self.update_block( - net_pyramid[resolution], ctx_pyramid[resolution], correlations, flow_estimates[resolution] - ) - - # the convex upsampling weights are computed w.r.t. - # the recurrent update state - up_mask = self.mask_predictor(net_pyramid[resolution]) - flow_estimates[resolution] = flow_estimates[resolution] + delta_flow - # convex upsampling with the initial feature encoder downsampling rate - flow_pred_prior = upsample_flow( - flow_estimates[resolution], up_mask, factor=self.downsampling_factors[0] - ) - # we then bilinear upsample to the final resolution - # we use a factor that's equivalent to the difference between - # the current downsample resolution and the base downsample resolution - # - # i.e. if a 1 / 16 flow is upsampled by 4 (base downsampling) we get a 1 / 4 flow. - # therefore we have to further upscale it by the difference between - # the current level 1 / 16 and the base level 1 / 4. - # - # we use a -scale because the flow used inside the network is a negative flow - # from the right to the left, so we flip the flow direction in order to get the - # left to right flow - flow_pred = -upsample_flow(flow_pred_prior, None, factor=scale_to_base) - predictions.append(flow_pred) - - # when constructing the next resolution prior, we resample w.r.t - # to the scale of the next level in the pyramid - next_resolution = coarse_resolutions[idx + 1] - scale_to_next = l_pyramid[next_resolution].shape[2] / flow_pred_prior.shape[2] - # we use the flow_up_prior because this is a more accurate estimation of the true flow - # due to the convex upsample, which resembles a learned super-resolution module. 
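The comments above point at two easy-to-miss details when a flow map is moved to another resolution: its values must be multiplied by the same factor used to resize it (flow is measured in pixels), and the sign is flipped because the network internally estimates a right-to-left flow. A rough sketch of that resampling step, with illustrative names and assuming equal scaling on both axes:

from typing import Tuple

import torch
import torch.nn.functional as F


def rescale_flow(flow: torch.Tensor, size: Tuple[int, int], negate: bool = True) -> torch.Tensor:
    # resize the map and scale the pixel-unit flow values by the same factor
    scale = size[0] / flow.shape[2]
    resized = scale * F.interpolate(flow, size=size, mode="bilinear", align_corners=True)
    # flip the direction to convert the internal right-to-left flow into a left-to-right one
    return -resized if negate else resized


coarse_flow = torch.randn(1, 2, 16, 32)            # 2-channel flow at a coarse level
finer_flow = rescale_flow(coarse_flow, (32, 64))   # values doubled, direction flipped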
- # this is not necessarily an upsample, it can be a downsample, based on the provided configuration - flow_estimates[next_resolution] = -scale_to_next * F.interpolate( - input=flow_pred_prior, - size=l_pyramid[next_resolution].shape[2:], - mode="bilinear", - align_corners=True, - ) - - # finally we will be doing a full pass through the fine-grained resolution - # this coincides with the maximum resolution - - # we keep a separate loop here in order to avoid python control flow - # to decide how many iterations should we do based on the current resolution - # furthermore, if provided with an initial flow, there is no need to generate - # a prior estimate when moving into the final refinement stage - - for it in range(num_iters): - search_window_type = self._get_window_type(it) - - flow_estimates[max_res] = flow_estimates[max_res].detach() - # we run the fine-grained resolution correlations in iterative mode - # this means that we are using the fixed window pixel selections - # instead of the deformed ones as with the previous steps - correlations = self.max_res_correlation_layer( - l_pyramid[max_res], - r_pyramid[max_res], - flow_estimates[max_res], - extra_offset=None, - window_type=search_window_type, - iter_mode=True, - ) - - net_pyramid[max_res], delta_flow = self.update_block( - net_pyramid[max_res], ctx_pyramid[max_res], correlations, flow_estimates[max_res] - ) - - up_mask = self.mask_predictor(net_pyramid[max_res]) - flow_estimates[max_res] = flow_estimates[max_res] + delta_flow - # at the final resolution we simply do a convex upsample using the base downsample rate - flow_pred = -upsample_flow(flow_estimates[max_res], up_mask, factor=self.downsampling_factors[0]) - predictions.append(flow_pred) - - return predictions - - -def _crestereo( - *, - weights: Optional[WeightsEnum], - progress: bool, - # Feature Encoder - feature_encoder_layers: Tuple[int, int, int, int, int], - feature_encoder_strides: Tuple[int, int, int, int], - feature_encoder_block: Callable[..., nn.Module], - feature_encoder_norm_layer: Callable[..., nn.Module], - # Average Pooling Pyramid - feature_downsample_rates: Tuple[int, ...], - # Adaptive Correlation Layer - corr_groups: int, - corr_search_window_2d: Tuple[int, int], - corr_search_dilate_2d: Tuple[int, int], - corr_search_window_1d: Tuple[int, int], - corr_search_dilate_1d: Tuple[int, int], - # Flow head - flow_head_hidden_size: int, - # Recurrent block - recurrent_block_hidden_state_size: int, - recurrent_block_kernel_size: Tuple[Tuple[int, int], Tuple[int, int]], - recurrent_block_padding: Tuple[Tuple[int, int], Tuple[int, int]], - # Motion Encoder - motion_encoder_corr_layers: Tuple[int, int], - motion_encoder_flow_layers: Tuple[int, int], - motion_encoder_out_channels: int, - # Transformer Blocks - num_attention_heads: int, - num_self_attention_layers: int, - num_cross_attention_layers: int, - self_attention_module: Callable[..., nn.Module], - cross_attention_module: Callable[..., nn.Module], - **kwargs, -) -> CREStereo: - - feature_encoder = kwargs.pop("feature_encoder", None) or raft.FeatureEncoder( - block=feature_encoder_block, - layers=feature_encoder_layers, - strides=feature_encoder_strides, - norm_layer=feature_encoder_norm_layer, - ) - - if feature_encoder.output_dim % corr_groups != 0: - raise ValueError( - f"Final ``feature_encoder_layers`` size should be divisible by ``corr_groups`` argument." - f"Feature encoder output size : {feature_encoder.output_dim}, Correlation groups: {corr_groups}." 
- ) - - motion_encoder = kwargs.pop("motion_encoder", None) or raft.MotionEncoder( - in_channels_corr=corr_groups * int(np.prod(corr_search_window_1d)), - corr_layers=motion_encoder_corr_layers, - flow_layers=motion_encoder_flow_layers, - out_channels=motion_encoder_out_channels, - ) - - out_channels_context = feature_encoder_layers[-1] - recurrent_block_hidden_state_size - recurrent_block = kwargs.pop("recurrent_block", None) or raft.RecurrentBlock( - input_size=motion_encoder.out_channels + out_channels_context, - hidden_size=recurrent_block_hidden_state_size, - kernel_size=recurrent_block_kernel_size, - padding=recurrent_block_padding, - ) - - flow_head = kwargs.pop("flow_head", None) or raft.FlowHead( - in_channels=out_channels_context, hidden_size=flow_head_hidden_size - ) - - update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head) - - self_attention_module = kwargs.pop("self_attention_module", None) or LinearAttention - self_attn_block = LocalFeatureTransformer( - dim_model=feature_encoder.output_dim, - num_heads=num_attention_heads, - attention_directions=["self"] * num_self_attention_layers, - attention_module=self_attention_module, - ) - - cross_attention_module = kwargs.pop("cross_attention_module", None) or LinearAttention - cross_attn_block = LocalFeatureTransformer( - dim_model=feature_encoder.output_dim, - num_heads=num_attention_heads, - attention_directions=["cross"] * num_cross_attention_layers, - attention_module=cross_attention_module, - ) - - model = CREStereo( - feature_encoder=feature_encoder, - update_block=update_block, - flow_head=flow_head, - self_attn_block=self_attn_block, - cross_attn_block=cross_attn_block, - feature_downsample_rates=feature_downsample_rates, - correlation_groups=corr_groups, - search_window_1d=corr_search_window_1d, - search_window_2d=corr_search_window_2d, - search_dilate_1d=corr_search_dilate_1d, - search_dilate_2d=corr_search_dilate_2d, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "resize_size": (384, 512), -} - - -class CREStereo_Base_Weights(WeightsEnum): - """The metrics reported here are as follows. - - ``mae`` is the "mean-average-error" and indicates how far (in pixels) the - predicted disparity is from its true value (equivalent to ``epe``). This is averaged over all pixels - of all images. ``1px``, ``3px``, ``5px`` and indicate the percentage of pixels that have a lower - error than that of the ground truth. ``relepe`` is the "relative-end-point-error" and is the - average ``epe`` divided by the average ground truth disparity. ``fl-all`` corresponds to the average of pixels whose epe - is either <3px, or whom's ``relepe`` is lower than 0.05 (therefore higher is better). 
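The metrics described in the docstring above can all be derived from the per-pixel end-point error between a predicted and a ground-truth disparity map. A rough sketch following the thresholds quoted there (the helper name is illustrative, and the exact averaging conventions are my reading of the docstring):

import torch


def stereo_metrics(pred: torch.Tensor, gt: torch.Tensor) -> dict:
    epe = (pred - gt).abs()                                   # per-pixel end-point error, in pixels
    out = {"mae": epe.mean().item()}                          # identical to the average epe
    for t in (1, 3, 5):
        out[f"{t}px"] = (epe < t).float().mean().item()       # share of pixels below the threshold
    out["relepe"] = epe.mean().item() / gt.abs().mean().item()
    # a pixel counts towards fl-all if its epe is below 3px or its relative error is below 0.05
    rel = epe / gt.abs().clamp(min=1e-6)
    out["fl-all"] = 100.0 * ((epe < 3) | (rel < 0.05)).float().mean().item()
    return out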
- - """ - - MEGVII_V1 = Weights( - # Weights ported from https://github.com/megvii-research/CREStereo - url="https://download.pytorch.org/models/crestereo-756c8b0f.pth", - transforms=StereoMatching, - meta={ - **_COMMON_META, - "num_params": 5432948, - "recipe": "https://github.com/megvii-research/CREStereo", - "_metrics": { - "Middlebury2014-train": { - # metrics for 10 refinement iterations and 1 cascade - "mae": 0.792, - "rmse": 2.765, - "1px": 0.905, - "3px": 0.958, - "5px": 0.97, - "relepe": 0.114, - "fl-all": 90.429, - "_detailed": { - # 1 is the number of cascades - 1: { - # 2 is number of refininement iterations - 2: { - "mae": 1.704, - "rmse": 3.738, - "1px": 0.738, - "3px": 0.896, - "5px": 0.933, - "relepe": 0.157, - "fl-all": 76.464, - }, - 5: { - "mae": 0.956, - "rmse": 2.963, - "1px": 0.88, - "3px": 0.948, - "5px": 0.965, - "relepe": 0.124, - "fl-all": 88.186, - }, - 10: { - "mae": 0.792, - "rmse": 2.765, - "1px": 0.905, - "3px": 0.958, - "5px": 0.97, - "relepe": 0.114, - "fl-all": 90.429, - }, - 20: { - "mae": 0.749, - "rmse": 2.706, - "1px": 0.907, - "3px": 0.961, - "5px": 0.972, - "relepe": 0.113, - "fl-all": 90.807, - }, - }, - 2: { - 2: { - "mae": 1.702, - "rmse": 3.784, - "1px": 0.784, - "3px": 0.894, - "5px": 0.924, - "relepe": 0.172, - "fl-all": 80.313, - }, - 5: { - "mae": 0.932, - "rmse": 2.907, - "1px": 0.877, - "3px": 0.944, - "5px": 0.963, - "relepe": 0.125, - "fl-all": 87.979, - }, - 10: { - "mae": 0.773, - "rmse": 2.768, - "1px": 0.901, - "3px": 0.958, - "5px": 0.972, - "relepe": 0.117, - "fl-all": 90.43, - }, - 20: { - "mae": 0.854, - "rmse": 2.971, - "1px": 0.9, - "3px": 0.957, - "5px": 0.97, - "relepe": 0.122, - "fl-all": 90.269, - }, - }, - }, - } - }, - "_docs": """These weights were ported from the original paper. They - are trained on a dataset mixture of the author's choice.""", - }, - ) - - CRESTEREO_ETH_MBL_V1 = Weights( - # Weights ported from https://github.com/megvii-research/CREStereo - url="https://download.pytorch.org/models/crestereo-8f0e0e9a.pth", - transforms=StereoMatching, - meta={ - **_COMMON_META, - "num_params": 5432948, - "recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo", - "_metrics": { - "Middlebury2014-train": { - # metrics for 10 refinement iterations and 1 cascade - "mae": 1.416, - "rmse": 3.53, - "1px": 0.777, - "3px": 0.896, - "5px": 0.933, - "relepe": 0.148, - "fl-all": 78.388, - "_detailed": { - # 1 is the number of cascades - 1: { - # 2 is the number of refinement iterations - 2: { - "mae": 2.363, - "rmse": 4.352, - "1px": 0.611, - "3px": 0.828, - "5px": 0.891, - "relepe": 0.176, - "fl-all": 64.511, - }, - 5: { - "mae": 1.618, - "rmse": 3.71, - "1px": 0.761, - "3px": 0.879, - "5px": 0.918, - "relepe": 0.154, - "fl-all": 77.128, - }, - 10: { - "mae": 1.416, - "rmse": 3.53, - "1px": 0.777, - "3px": 0.896, - "5px": 0.933, - "relepe": 0.148, - "fl-all": 78.388, - }, - 20: { - "mae": 1.448, - "rmse": 3.583, - "1px": 0.771, - "3px": 0.893, - "5px": 0.931, - "relepe": 0.145, - "fl-all": 77.7, - }, - }, - 2: { - 2: { - "mae": 1.972, - "rmse": 4.125, - "1px": 0.73, - "3px": 0.865, - "5px": 0.908, - "relepe": 0.169, - "fl-all": 74.396, - }, - 5: { - "mae": 1.403, - "rmse": 3.448, - "1px": 0.793, - "3px": 0.905, - "5px": 0.937, - "relepe": 0.151, - "fl-all": 80.186, - }, - 10: { - "mae": 1.312, - "rmse": 3.368, - "1px": 0.799, - "3px": 0.912, - "5px": 0.943, - "relepe": 0.148, - "fl-all": 80.379, - }, - 20: { - "mae": 1.376, - "rmse": 3.542, - "1px": 0.796, - "3px": 0.91, - "5px": 0.942, - "relepe": 0.149, - 
"fl-all": 80.054, - }, - }, - }, - } - }, - "_docs": """These weights were trained from scratch on - :class:`~torchvision.datasets._stereo_matching.CREStereo` + - :class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` + - :class:`~torchvision.datasets._stereo_matching.ETH3DStereo`.""", - }, - ) - - CRESTEREO_FINETUNE_MULTI_V1 = Weights( - # Weights ported from https://github.com/megvii-research/CREStereo - url="https://download.pytorch.org/models/crestereo-697c38f4.pth ", - transforms=StereoMatching, - meta={ - **_COMMON_META, - "num_params": 5432948, - "recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo", - "_metrics": { - "Middlebury2014-train": { - # metrics for 10 refinement iterations and 1 cascade - "mae": 1.038, - "rmse": 3.108, - "1px": 0.852, - "3px": 0.942, - "5px": 0.963, - "relepe": 0.129, - "fl-all": 85.522, - "_detailed": { - # 1 is the number of cascades - 1: { - # 2 is number of refininement iterations - 2: { - "mae": 1.85, - "rmse": 3.797, - "1px": 0.673, - "3px": 0.862, - "5px": 0.917, - "relepe": 0.171, - "fl-all": 69.736, - }, - 5: { - "mae": 1.111, - "rmse": 3.166, - "1px": 0.838, - "3px": 0.93, - "5px": 0.957, - "relepe": 0.134, - "fl-all": 84.596, - }, - 10: { - "mae": 1.02, - "rmse": 3.073, - "1px": 0.854, - "3px": 0.938, - "5px": 0.96, - "relepe": 0.129, - "fl-all": 86.042, - }, - 20: { - "mae": 0.993, - "rmse": 3.059, - "1px": 0.855, - "3px": 0.942, - "5px": 0.967, - "relepe": 0.126, - "fl-all": 85.784, - }, - }, - 2: { - 2: { - "mae": 1.667, - "rmse": 3.867, - "1px": 0.78, - "3px": 0.891, - "5px": 0.922, - "relepe": 0.165, - "fl-all": 78.89, - }, - 5: { - "mae": 1.158, - "rmse": 3.278, - "1px": 0.843, - "3px": 0.926, - "5px": 0.955, - "relepe": 0.135, - "fl-all": 84.556, - }, - 10: { - "mae": 1.046, - "rmse": 3.13, - "1px": 0.85, - "3px": 0.934, - "5px": 0.96, - "relepe": 0.13, - "fl-all": 85.464, - }, - 20: { - "mae": 1.021, - "rmse": 3.102, - "1px": 0.85, - "3px": 0.935, - "5px": 0.963, - "relepe": 0.129, - "fl-all": 85.417, - }, - }, - }, - }, - }, - "_docs": """These weights were finetuned on a mixture of - :class:`~torchvision.datasets._stereo_matching.CREStereo` + - :class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` + - :class:`~torchvision.datasets._stereo_matching.ETH3DStereo` + - :class:`~torchvision.datasets._stereo_matching.InStereo2k` + - :class:`~torchvision.datasets._stereo_matching.CarlaStereo` + - :class:`~torchvision.datasets._stereo_matching.SintelStereo` + - :class:`~torchvision.datasets._stereo_matching.FallingThingsStereo` + - .""", - }, - ) - - DEFAULT = MEGVII_V1 - - -@register_model() -@handle_legacy_interface(weights=("pretrained", CREStereo_Base_Weights.MEGVII_V1)) -def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress=True, **kwargs) -> CREStereo: - """CREStereo model from - `Practical Stereo Matching via Cascaded Recurrent Network - With Adaptive Correlation <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_. - - Please see the example below for a tutorial on how to use this model. - - Args: - weights(:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights` - below for more details, and possible values. By default, no - pre-trained weights are used. 
- progress (bool): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo`` - base class. Please refer to the `source code - <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/crestereo.py>`_ - for more details about this class. - - .. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights - :members: - """ - - weights = CREStereo_Base_Weights.verify(weights) - - return _crestereo( - weights=weights, - progress=progress, - # Feature encoder - feature_encoder_layers=(64, 64, 96, 128, 256), - feature_encoder_strides=(2, 1, 2, 1), - feature_encoder_block=partial(raft.ResidualBlock, always_project=True), - feature_encoder_norm_layer=nn.InstanceNorm2d, - # Average pooling pyramid - feature_downsample_rates=(2, 4), - # Motion encoder - motion_encoder_corr_layers=(256, 192), - motion_encoder_flow_layers=(128, 64), - motion_encoder_out_channels=128, - # Recurrent block - recurrent_block_hidden_state_size=128, - recurrent_block_kernel_size=((1, 5), (5, 1)), - recurrent_block_padding=((0, 2), (2, 0)), - # Flow head - flow_head_hidden_size=256, - # Transformer blocks - num_attention_heads=8, - num_self_attention_layers=1, - num_cross_attention_layers=1, - self_attention_module=LinearAttention, - cross_attention_module=LinearAttention, - # Adaptive Correlation layer - corr_groups=4, - corr_search_window_2d=(3, 3), - corr_search_dilate_2d=(1, 1), - corr_search_window_1d=(1, 9), - corr_search_dilate_1d=(1, 1), - ) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py deleted file mode 100644 index aca12948ca0..00000000000 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ /dev/null @@ -1,838 +0,0 @@ -from functools import partial -from typing import Callable, List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.models.optical_flow.raft as raft -from torch import Tensor -from torchvision.models._api import register_model, Weights, WeightsEnum -from torchvision.models._utils import handle_legacy_interface -from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow -from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock -from torchvision.ops import Conv2dNormActivation -from torchvision.prototype.transforms._presets import StereoMatching -from torchvision.utils import _log_api_usage_once - - -__all__ = ( - "RaftStereo", - "raft_stereo_base", - "raft_stereo_realtime", - "Raft_Stereo_Base_Weights", - "Raft_Stereo_Realtime_Weights", -) - - -class BaseEncoder(raft.FeatureEncoder): - """Base encoder for FeatureEncoder and ContextEncoder in which weight may be shared. - - See the Raft-Stereo paper section 4.6 on backbone part. 
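The weight sharing this docstring describes amounts to building one base trunk and handing the same instance to both the feature and the context encoder, so both paths update the same parameters. A minimal sketch of the idea with stand-in modules (these are illustrative placeholders, not the classes defined below):

import torch
import torch.nn as nn

shared_base = nn.Sequential(                        # stand-in for a shared BaseEncoder trunk
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.ReLU(),
)
feature_head = nn.Conv2d(64, 256, kernel_size=1)    # stand-in for the feature-encoder head
context_head = nn.Conv2d(64, 256, kernel_size=3, padding=1)

x = torch.randn(1, 3, 64, 64)
trunk = shared_base(x)          # the shared trunk is evaluated once per input
features = feature_head(trunk)  # both heads backpropagate into the same base parameters
context = context_head(trunk)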
- """ - - def __init__( - self, - *, - block: Callable[..., nn.Module] = ResidualBlock, - layers: Tuple[int, int, int, int] = (64, 64, 96, 128), - strides: Tuple[int, int, int, int] = (2, 1, 2, 2), - norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d, - ): - # We use layers + (256,) because raft.FeatureEncoder require 5 layers - # but here we will set the last conv layer to identity - super().__init__(block=block, layers=layers + (256,), strides=strides, norm_layer=norm_layer) - - # Base encoder don't have the last conv layer of feature encoder - self.conv = nn.Identity() - - self.output_dim = layers[3] - num_downsampling = sum([x - 1 for x in strides]) - self.downsampling_ratio = 2 ** (num_downsampling) - - -class FeatureEncoder(nn.Module): - """Feature Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Context Encoder. - - The FeatureEncoder takes concatenation of left and right image as input. It produces feature embedding that later - will be used to construct correlation volume. - """ - - def __init__( - self, - base_encoder: BaseEncoder, - output_dim: int = 256, - shared_base: bool = False, - block: Callable[..., nn.Module] = ResidualBlock, - ): - super().__init__() - self.base_encoder = base_encoder - self.base_downsampling_ratio = base_encoder.downsampling_ratio - base_dim = base_encoder.output_dim - - if not shared_base: - self.residual_block: nn.Module = nn.Identity() - self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=1) - else: - # If we share base encoder weight for Feature and Context Encoder - # we need to add residual block with InstanceNorm2d and change the kernel size for conv layer - # see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L35-L37 - self.residual_block = block(base_dim, base_dim, norm_layer=nn.InstanceNorm2d, stride=1) - self.conv = nn.Conv2d(base_dim, output_dim, kernel_size=3, padding=1) - - def forward(self, x: Tensor) -> Tensor: - x = self.base_encoder(x) - x = self.residual_block(x) - x = self.conv(x) - return x - - -class MultiLevelContextEncoder(nn.Module): - """Context Encoder for Raft-Stereo (see paper section 3.1) that may have shared weight with the Feature Encoder. - - The ContextEncoder takes left image as input, and it outputs concatenated hidden_states and contexts. - In Raft-Stereo we have multi level GRUs and this context encoder will also multi outputs (list of Tensor) - that correspond to each GRUs. - Take note that the length of "out_with_blocks" parameter represent the number of GRU's level. - args: - base_encoder (nn.Module): The base encoder part that can have a shared weight with feature_encoder's - base_encoder because they have same architecture. 
- out_with_blocks (List[bool]): The length represent the number of GRU's level (length of output), and - if the element is True then the output layer on that position will have additional block - output_dim (int): The dimension of output on each level (default: 256) - block (Callable[..., nn.Module]): The type of basic block used for downsampling and output layer - (default: ResidualBlock) - """ - - def __init__( - self, - base_encoder: nn.Module, - out_with_blocks: List[bool], - output_dim: int = 256, - block: Callable[..., nn.Module] = ResidualBlock, - ): - super().__init__() - self.num_level = len(out_with_blocks) - self.base_encoder = base_encoder - self.base_downsampling_ratio = base_encoder.downsampling_ratio - base_dim = base_encoder.output_dim - - self.downsample_and_out_layers = nn.ModuleList( - [ - nn.ModuleDict( - { - "downsampler": self._make_downsampler(block, base_dim, base_dim) if i > 0 else nn.Identity(), - "out_hidden_state": self._make_out_layer( - base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block - ), - "out_context": self._make_out_layer( - base_dim, output_dim // 2, with_block=out_with_blocks[i], block=block - ), - } - ) - for i in range(self.num_level) - ] - ) - - def _make_out_layer(self, in_channels, out_channels, with_block=True, block=ResidualBlock): - layers = [] - if with_block: - layers.append(block(in_channels, in_channels, norm_layer=nn.BatchNorm2d, stride=1)) - layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)) - return nn.Sequential(*layers) - - def _make_downsampler(self, block, in_channels, out_channels): - block1 = block(in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2) - block2 = block(out_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=1) - return nn.Sequential(block1, block2) - - def forward(self, x: Tensor) -> List[Tensor]: - x = self.base_encoder(x) - outs = [] - for layer_dict in self.downsample_and_out_layers: - x = layer_dict["downsampler"](x) - outs.append(torch.cat([layer_dict["out_hidden_state"](x), layer_dict["out_context"](x)], dim=1)) - return outs - - -class ConvGRU(raft.ConvGRU): - """Convolutional Gru unit.""" - - # Modified from raft.ConvGRU to accept pre-convolved contexts, - # see: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/update.py#L23 - def forward(self, h: Tensor, x: Tensor, context: List[Tensor]) -> Tensor: # type: ignore[override] - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz(hx) + context[0]) - r = torch.sigmoid(self.convr(hx) + context[1]) - q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)) + context[2]) - h = (1 - z) * h + z * q - return h - - -class MultiLevelUpdateBlock(nn.Module): - """The update block which contains the motion encoder and grus - - It must expose a ``hidden_dims`` attribute which is the hidden dimension size of its gru blocks - """ - - def __init__(self, *, motion_encoder: MotionEncoder, hidden_dims: List[int]): - super().__init__() - self.motion_encoder = motion_encoder - - # The GRU input size is the size of previous level hidden_dim plus next level hidden_dim - # if this is the first gru, then we replace previous level with motion_encoder output channels - # for the last GRU, we don't add the next level hidden_dim - gru_input_dims = [] - for i in range(len(hidden_dims)): - input_dim = hidden_dims[i - 1] if i > 0 else motion_encoder.out_channels - if i < len(hidden_dims) - 1: - input_dim += hidden_dims[i + 1] - gru_input_dims.append(input_dim) - - self.grus = nn.ModuleList( - [ - 
ConvGRU(input_size=gru_input_dims[i], hidden_size=hidden_dims[i], kernel_size=3, padding=1) - # Ideally we should reverse the direction during forward to use the gru with the smallest resolution - # first however currently there is no way to reverse a ModuleList that is jit script compatible - # hence we reverse the ordering of self.grus on the constructor instead - # see: https://github.com/pytorch/pytorch/issues/31772 - for i in reversed(list(range(len(hidden_dims)))) - ] - ) - - self.hidden_dims = hidden_dims - - def forward( - self, - hidden_states: List[Tensor], - contexts: List[List[Tensor]], - corr_features: Tensor, - disparity: Tensor, - level_processed: List[bool], - ) -> List[Tensor]: - # We call it reverse_i because it has a reversed ordering compared to hidden_states - # see self.grus on the constructor for more detail - for reverse_i, gru in enumerate(self.grus): - i = len(self.grus) - 1 - reverse_i - if level_processed[i]: - # X is concatenation of 2x downsampled hidden_dim (or motion_features if no bigger dim) with - # upsampled hidden_dim (or nothing if not exist). - if i == 0: - features = self.motion_encoder(disparity, corr_features) - else: - # 2x downsampled features from larger hidden states - features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1) - - if i < len(self.grus) - 1: - # Concat with 2x upsampled features from smaller hidden states - _, _, h, w = hidden_states[i + 1].shape - features = torch.cat( - [ - features, - F.interpolate( - hidden_states[i + 1], size=(2 * h, 2 * w), mode="bilinear", align_corners=True - ), - ], - dim=1, - ) - - hidden_states[i] = gru(hidden_states[i], features, contexts[i]) - - # NOTE: For slow-fast gru, we don't always want to calculate delta disparity for every call on UpdateBlock - # Hence we move the delta disparity calculation to the RAFT-Stereo main forward - - return hidden_states - - -class MaskPredictor(raft.MaskPredictor): - """Mask predictor to be used when upsampling the predicted disparity.""" - - # We add out_channels compared to raft.MaskPredictor - def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25): - super(raft.MaskPredictor, self).__init__() - self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) - self.conv = nn.Conv2d(hidden_size, out_channels, kernel_size=1, padding=0) - self.multiplier = multiplier - - -class CorrPyramid1d(nn.Module): - """Row-wise correlation pyramid. - - Create a row-wise correlation pyramid with ``num_levels`` level from the outputs of the feature encoder, - this correlation pyramid will later be used as index to create correlation features using CorrBlock1d. - """ - - def __init__(self, num_levels: int = 4): - super().__init__() - self.num_levels = num_levels - - def forward(self, fmap1: Tensor, fmap2: Tensor) -> List[Tensor]: - """Build the correlation pyramid from two feature maps. - - The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) on the same row. - The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions - to build the correlation pyramid. 
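The row-wise correlation volume this docstring describes reduces to a single einsum over the channel dimension, followed by repeated pooling of the last axis to form the pyramid; a compact sketch with illustrative shapes:

import torch
import torch.nn.functional as F

fmap1 = torch.randn(2, 256, 32, 64)   # B, C, H, W features of the left image
fmap2 = torch.randn(2, 256, 32, 64)   # same-shape features of the right image
B, C, H, W = fmap1.shape

# dot product between each left pixel and every right pixel on the same row
corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2) / C ** 0.5   # (B, H, W, W)
corr_volume = corr.reshape(B * H * W, 1, 1, W)                    # one row of scores per query pixel

pyramid = [corr_volume]
for _ in range(3):                    # num_levels - 1 additional, coarser levels
    corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2))
    pyramid.append(corr_volume)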
- """ - - torch._assert( - fmap1.shape == fmap2.shape, - f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)", - ) - - batch_size, num_channels, h, w = fmap1.shape - fmap1 = fmap1.view(batch_size, num_channels, h, w) - fmap2 = fmap2.view(batch_size, num_channels, h, w) - - corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2) - corr = corr.view(batch_size, h, w, 1, w) - corr_volume = corr / torch.sqrt(torch.tensor(num_channels, device=corr.device)) - - corr_volume = corr_volume.reshape(batch_size * h * w, 1, 1, w) - corr_pyramid = [corr_volume] - for _ in range(self.num_levels - 1): - corr_volume = F.avg_pool2d(corr_volume, kernel_size=(1, 2), stride=(1, 2)) - corr_pyramid.append(corr_volume) - - return corr_pyramid - - -class CorrBlock1d(nn.Module): - """The row-wise correlation block. - - Use indexes from correlation pyramid to create correlation features. - The "indexing" of a given centroid pixel x' is done by concatenating its surrounding row neighbours - within radius - """ - - def __init__(self, *, num_levels: int = 4, radius: int = 4): - super().__init__() - self.radius = radius - self.out_channels = num_levels * (2 * radius + 1) - - def forward(self, centroids_coords: Tensor, corr_pyramid: List[Tensor]) -> Tensor: - """Return correlation features by indexing from the pyramid.""" - neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels - di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, device=centroids_coords.device) - di = di.view(1, 1, neighborhood_side_len, 1).to(centroids_coords.device) - - batch_size, _, h, w = centroids_coords.shape # _ = 2 but we only use the first one - # We only consider 1d and take the first dim only - centroids_coords = centroids_coords[:, :1].permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 1) - - indexed_pyramid = [] - for corr_volume in corr_pyramid: - x0 = centroids_coords + di # end shape is (batch_size * h * w, 1, side_len, 1) - y0 = torch.zeros_like(x0) - sampling_coords = torch.cat([x0, y0], dim=-1) - indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view( - batch_size, h, w, -1 - ) - indexed_pyramid.append(indexed_corr_volume) - centroids_coords = centroids_coords / 2 - - corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() - - expected_output_shape = (batch_size, self.out_channels, h, w) - torch._assert( - corr_features.shape == expected_output_shape, - f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}", - ) - return corr_features - - -class RaftStereo(nn.Module): - def __init__( - self, - *, - feature_encoder: FeatureEncoder, - context_encoder: MultiLevelContextEncoder, - corr_pyramid: CorrPyramid1d, - corr_block: CorrBlock1d, - update_block: MultiLevelUpdateBlock, - disparity_head: nn.Module, - mask_predictor: Optional[nn.Module] = None, - slow_fast: bool = False, - ): - """RAFT-Stereo model from - `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_. - - args: - feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``. - context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``. 
- It has multi-level output and each level will have 2 parts: - - - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` - - one part will be used to initialize the hidden state of the recurrent unit of - the ``update_block`` - - corr_pyramid (CorrPyramid1d): Module to build the correlation pyramid from feature encoder output - corr_block (CorrBlock1d): The correlation block, which uses the correlation pyramid indexes - to create correlation features. It takes the coordinate of the centroid pixel and correlation pyramid - as input and returns the correlation features. - It must expose an ``out_channels`` attribute. - - update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit. - It takes as input the hidden state of its recurrent unit, the context, the correlation - features, and the current predicted disparity. It outputs an updated hidden state - disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity. - mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. - If ``None`` (default), the flow is upsampled using interpolation. - slow_fast (bool): A boolean that specify whether we should use slow-fast GRU or not. See RAFT-Stereo paper - on section 3.4 for more detail. - """ - super().__init__() - _log_api_usage_once(self) - - # This indicates that the disparity output will be only have 1 channel (represent horizontal axis). - # We need this because some stereo matching model like CREStereo might have 2 channel on the output - self.output_channels = 1 - - self.feature_encoder = feature_encoder - self.context_encoder = context_encoder - - self.base_downsampling_ratio = feature_encoder.base_downsampling_ratio - self.num_level = self.context_encoder.num_level - self.corr_pyramid = corr_pyramid - self.corr_block = corr_block - self.update_block = update_block - self.disparity_head = disparity_head - self.mask_predictor = mask_predictor - - hidden_dims = self.update_block.hidden_dims - # Follow the original implementation to do pre convolution on the context - # See: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L32 - self.context_convs = nn.ModuleList( - [nn.Conv2d(hidden_dims[i], hidden_dims[i] * 3, kernel_size=3, padding=1) for i in range(self.num_level)] - ) - self.slow_fast = slow_fast - - def forward( - self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12 - ) -> List[Tensor]: - """ - Return disparity predictions on every iteration as a list of Tensor. - args: - left_image (Tensor): The input left image with layout B, C, H, W - right_image (Tensor): The input right image with layout B, C, H, W - flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None - num_iters (int): Number of update block iteration on the largest resolution. 
Default: 12 - """ - batch_size, _, h, w = left_image.shape - torch._assert( - (h, w) == right_image.shape[-2:], - f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}", - ) - - torch._assert( - (h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0), - f"input image H and W should be divisible by {self.base_downsampling_ratio}, instead got H={h} and W={w}", - ) - - fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0)) - fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) - torch._assert( - fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio), - f"The feature encoder should downsample H and W by {self.base_downsampling_ratio}", - ) - - corr_pyramid = self.corr_pyramid(fmap1, fmap2) - - # Multi level contexts - context_outs = self.context_encoder(left_image) - - hidden_dims = self.update_block.hidden_dims - context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))] - hidden_states: List[Tensor] = [] - contexts: List[List[Tensor]] = [] - for i, context_conv in enumerate(self.context_convs): - # As in the original paper, the actual output of the context encoder is split in 2 parts: - # - one part is used to initialize the hidden state of the recurent units of the update block - # - the rest is the "actual" context. - hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1) - hidden_states.append(torch.tanh(hidden_state)) - contexts.append( - torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1) - ) - - _, Cf, Hf, Wf = fmap1.shape - coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) - coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) - - # We use flow_init for cascade inference - if flow_init is not None: - coords1 = coords1 + flow_init - - disparity_predictions = [] - for _ in range(num_iters): - coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper - corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid) - - disparity = coords1 - coords0 - if self.slow_fast: - # Using slow_fast GRU (see paper section 3.4). 
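The slow-fast behaviour referenced in this comment comes down to calling the update block several times per iteration with masks that enable progressively more pyramid levels, so the coarsest GRU is stepped most often. A small sketch of the masks produced for a three-level model, mirroring the loop below (which passes them as ``level_processed``):

num_level = 3
schedule = []
for i in range(1, num_level):
    # only the i coarsest resolutions are updated during the extra passes
    schedule.append([False] * (num_level - i) + [True] * i)
schedule.append([True] * num_level)    # the final pass updates every level

print(schedule)
# [[False, False, True], [False, True, True], [True, True, True]]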
The lower resolution are processed more often - for i in range(1, self.num_level): - # We only processed the smallest i levels - level_processed = [False] * (self.num_level - i) + [True] * i - hidden_states = self.update_block( - hidden_states, contexts, corr_features, disparity, level_processed=level_processed - ) - hidden_states = self.update_block( - hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level - ) - # Take the largest hidden_state to get the disparity - hidden_state = hidden_states[0] - delta_disparity = self.disparity_head(hidden_state) - # in stereo mode, project disparity onto epipolar - delta_disparity[:, 1] = 0.0 - - coords1 = coords1 + delta_disparity - up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) - upsampled_disparity = upsample_flow( - (coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio - ) - disparity_predictions.append(upsampled_disparity[:, :1]) - - return disparity_predictions - - -def _raft_stereo( - *, - weights: Optional[WeightsEnum], - progress: bool, - shared_encoder_weight: bool, - # Feature encoder - feature_encoder_layers: Tuple[int, int, int, int, int], - feature_encoder_strides: Tuple[int, int, int, int], - feature_encoder_block: Callable[..., nn.Module], - # Context encoder - context_encoder_layers: Tuple[int, int, int, int, int], - context_encoder_strides: Tuple[int, int, int, int], - # if the `out_with_blocks` param of the context_encoder is True, then - # the particular output on that level position will have additional `context_encoder_block` layer - context_encoder_out_with_blocks: List[bool], - context_encoder_block: Callable[..., nn.Module], - # Correlation block - corr_num_levels: int, - corr_radius: int, - # Motion encoder - motion_encoder_corr_layers: Tuple[int, int], - motion_encoder_flow_layers: Tuple[int, int], - motion_encoder_out_channels: int, - # Update block - update_block_hidden_dims: List[int], - # Flow Head - flow_head_hidden_size: int, - # Mask predictor - mask_predictor_hidden_size: int, - use_mask_predictor: bool, - slow_fast: bool, - **kwargs, -): - if len(context_encoder_out_with_blocks) != len(update_block_hidden_dims): - raise ValueError( - "Length of context_encoder_out_with_blocks and update_block_hidden_dims must be the same" - + "because both of them represent the number of GRUs level" - ) - if shared_encoder_weight: - if ( - feature_encoder_layers[:-1] != context_encoder_layers[:-1] - or feature_encoder_strides != context_encoder_strides - ): - raise ValueError( - "If shared_encoder_weight is True, then the feature_encoder_layers[:-1]" - + " and feature_encoder_strides must be the same with context_encoder_layers[:-1] and context_encoder_strides!" 
- ) - - base_encoder = kwargs.pop("base_encoder", None) or BaseEncoder( - block=context_encoder_block, - layers=context_encoder_layers[:-1], - strides=context_encoder_strides, - norm_layer=nn.BatchNorm2d, - ) - feature_base_encoder = base_encoder - context_base_encoder = base_encoder - else: - feature_base_encoder = BaseEncoder( - block=feature_encoder_block, - layers=feature_encoder_layers[:-1], - strides=feature_encoder_strides, - norm_layer=nn.InstanceNorm2d, - ) - context_base_encoder = BaseEncoder( - block=context_encoder_block, - layers=context_encoder_layers[:-1], - strides=context_encoder_strides, - norm_layer=nn.BatchNorm2d, - ) - feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder( - feature_base_encoder, - output_dim=feature_encoder_layers[-1], - shared_base=shared_encoder_weight, - block=feature_encoder_block, - ) - context_encoder = kwargs.pop("context_encoder", None) or MultiLevelContextEncoder( - context_base_encoder, - out_with_blocks=context_encoder_out_with_blocks, - output_dim=context_encoder_layers[-1], - block=context_encoder_block, - ) - - feature_downsampling_ratio = feature_encoder.base_downsampling_ratio - - corr_pyramid = kwargs.pop("corr_pyramid", None) or CorrPyramid1d(num_levels=corr_num_levels) - corr_block = kwargs.pop("corr_block", None) or CorrBlock1d(num_levels=corr_num_levels, radius=corr_radius) - - motion_encoder = kwargs.pop("motion_encoder", None) or MotionEncoder( - in_channels_corr=corr_block.out_channels, - corr_layers=motion_encoder_corr_layers, - flow_layers=motion_encoder_flow_layers, - out_channels=motion_encoder_out_channels, - ) - update_block = kwargs.pop("update_block", None) or MultiLevelUpdateBlock( - motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims - ) - - # We use the largest scale hidden_dims of update_block to get the predicted disparity - disparity_head = kwargs.pop("disparity_head", None) or FlowHead( - in_channels=update_block_hidden_dims[0], - hidden_size=flow_head_hidden_size, - ) - - mask_predictor = kwargs.pop("mask_predictor", None) - if use_mask_predictor: - mask_predictor = MaskPredictor( - in_channels=update_block.hidden_dims[0], - hidden_size=mask_predictor_hidden_size, - out_channels=9 * feature_downsampling_ratio * feature_downsampling_ratio, - ) - else: - mask_predictor = None - - model = RaftStereo( - feature_encoder=feature_encoder, - context_encoder=context_encoder, - corr_pyramid=corr_pyramid, - corr_block=corr_block, - update_block=update_block, - disparity_head=disparity_head, - mask_predictor=mask_predictor, - slow_fast=slow_fast, - **kwargs, # not really needed, all params should be consumed by now - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -class Raft_Stereo_Realtime_Weights(WeightsEnum): - SCENEFLOW_V1 = Weights( - # Weights ported from https://github.com/princeton-vl/RAFT-Stereo - url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth", - transforms=partial(StereoMatching, resize_size=(224, 224)), - meta={ - "num_params": 8077152, - "recipe": "https://github.com/princeton-vl/RAFT-Stereo", - "_metrics": { - # Following metrics from paper: https://arxiv.org/abs/2109.07547 - "Kitty2015": { - "3px": 0.9409, - } - }, - }, - ) - - DEFAULT = SCENEFLOW_V1 - - -class Raft_Stereo_Base_Weights(WeightsEnum): - SCENEFLOW_V1 = Weights( - # Weights ported from https://github.com/princeton-vl/RAFT-Stereo - url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth", - 
transforms=partial(StereoMatching, resize_size=(224, 224)), - meta={ - "num_params": 11116176, - "recipe": "https://github.com/princeton-vl/RAFT-Stereo", - "_metrics": { - # Following metrics from paper: https://arxiv.org/abs/2109.07547 - # Using standard metrics for each dataset - "Kitty2015": { - # Ratio of pixels with difference less than 3px from ground truth - "3px": 0.9426, - }, - # For middlebury, ratio of pixels with difference less than 2px from ground truth - # on full, half, and quarter image resolution - "Middlebury2014-val-full": { - "2px": 0.8167, - }, - "Middlebury2014-val-half": { - "2px": 0.8741, - }, - "Middlebury2014-val-quarter": { - "2px": 0.9064, - }, - "ETH3D-val": { - # Ratio of pixels with difference less than 1px from ground truth - "1px": 0.9672, - }, - }, - }, - ) - - MIDDLEBURY_V1 = Weights( - # Weights ported from https://github.com/princeton-vl/RAFT-Stereo - url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth", - transforms=partial(StereoMatching, resize_size=(224, 224)), - meta={ - "num_params": 11116176, - "recipe": "https://github.com/princeton-vl/RAFT-Stereo", - "_metrics": { - # Following metrics from paper: https://arxiv.org/abs/2109.07547 - "Middlebury-test": { - "mae": 1.27, - "1px": 0.9063, - "2px": 0.9526, - "5px": 0.9725, - } - }, - }, - ) - - ETH3D_V1 = Weights( - # Weights ported from https://github.com/princeton-vl/RAFT-Stereo - url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth", - transforms=partial(StereoMatching, resize_size=(224, 224)), - meta={ - "num_params": 11116176, - "recipe": "https://github.com/princeton-vl/RAFT-Stereo", - "_metrics": { - # Following metrics from paper: https://arxiv.org/abs/2109.07547 - "ETH3D-test": { - "mae": 0.18, - "1px": 0.9756, - "2px": 0.9956, - } - }, - }, - ) - - DEFAULT = MIDDLEBURY_V1 - - -@register_model() -@handle_legacy_interface(weights=("pretrained", None)) -def raft_stereo_realtime( - *, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs -) -> RaftStereo: - """RAFT-Stereo model from - `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_. - This is the realtime variant of the Raft-Stereo model that is described on the paper section 4.7. - - Please see the example below for a tutorial on how to use this model. - - Args: - weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights` - below for more details, and possible values. By default, no - pre-trained weights are used. - progress (bool): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo`` - base class. Please refer to the `source code - <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ - for more details about this class. - - .. 
autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Realtime_Weights - :members: - """ - - weights = Raft_Stereo_Realtime_Weights.verify(weights) - - return _raft_stereo( - weights=weights, - progress=progress, - shared_encoder_weight=True, - # Feature encoder - feature_encoder_layers=(64, 64, 96, 128, 256), - feature_encoder_strides=(2, 1, 2, 2), - feature_encoder_block=ResidualBlock, - # Context encoder - context_encoder_layers=(64, 64, 96, 128, 256), - context_encoder_strides=(2, 1, 2, 2), - context_encoder_out_with_blocks=[True, True], - context_encoder_block=ResidualBlock, - # Correlation block - corr_num_levels=4, - corr_radius=4, - # Motion encoder - motion_encoder_corr_layers=(64, 64), - motion_encoder_flow_layers=(64, 64), - motion_encoder_out_channels=128, - # Update block - update_block_hidden_dims=[128, 128], - # Flow head - flow_head_hidden_size=256, - # Mask predictor - mask_predictor_hidden_size=256, - use_mask_predictor=True, - slow_fast=True, - **kwargs, - ) - - -@register_model() -@handle_legacy_interface(weights=("pretrained", None)) -def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo: - """RAFT-Stereo model from - `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_. - - Please see the example below for a tutorial on how to use this model. - - Args: - weights(:class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights` - below for more details, and possible values. By default, no - pre-trained weights are used. - progress (bool): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo`` - base class. Please refer to the `source code - <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ - for more details about this class. - - .. 
autoclass:: torchvision.prototype.models.depth.stereo.Raft_Stereo_Base_Weights - :members: - """ - - weights = Raft_Stereo_Base_Weights.verify(weights) - - return _raft_stereo( - weights=weights, - progress=progress, - shared_encoder_weight=False, - # Feature encoder - feature_encoder_layers=(64, 64, 96, 128, 256), - feature_encoder_strides=(1, 1, 2, 2), - feature_encoder_block=ResidualBlock, - # Context encoder - context_encoder_layers=(64, 64, 96, 128, 256), - context_encoder_strides=(1, 1, 2, 2), - context_encoder_out_with_blocks=[True, True, False], - context_encoder_block=ResidualBlock, - # Correlation block - corr_num_levels=4, - corr_radius=4, - # Motion encoder - motion_encoder_corr_layers=(64, 64), - motion_encoder_flow_layers=(64, 64), - motion_encoder_out_channels=128, - # Update block - update_block_hidden_dims=[128, 128, 128], - # Flow head - flow_head_hidden_size=256, - # Mask predictor - mask_predictor_hidden_size=256, - use_mask_predictor=True, - slow_fast=False, - **kwargs, - ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py deleted file mode 100644 index 4f8fdef484c..00000000000 --- a/torchvision/prototype/transforms/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ._presets import StereoMatching # usort: skip - -from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste -from ._geometry import FixedSizeCrop -from ._misc import PermuteDimensions, TransposeDimensions -from ._type_conversion import LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py deleted file mode 100644 index d04baf739d1..00000000000 --- a/torchvision/prototype/transforms/_augment.py +++ /dev/null @@ -1,300 +0,0 @@ -import math -from typing import Any, cast, Dict, List, Optional, Tuple, Union - -import PIL.Image -import torch -from torch.utils._pytree import tree_flatten, tree_unflatten -from torchvision import datapoints -from torchvision.ops import masks_to_boxes -from torchvision.prototype import datapoints as proto_datapoints -from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform - -from torchvision.transforms.v2._transform import _RandomApplyTransform -from torchvision.transforms.v2.functional._geometry import _check_interpolation -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size - - -class _BaseMixupCutmix(_RandomApplyTransform): - def __init__(self, alpha: float, p: float = 0.5) -> None: - super().__init__(p=p) - self.alpha = alpha - self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if not ( - has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) - and has_any(flat_inputs, proto_datapoints.OneHotLabel) - ): - raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") - if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label): - raise TypeError( - f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." 
- ) - - def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel: - if inpt.ndim < 2: - raise ValueError("Need a batch of one hot labels") - output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - return proto_datapoints.OneHotLabel.wrap_like(inpt, output) - - -class RandomMixup(_BaseMixupCutmix): - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - lam = params["lam"] - if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): - expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 - if inpt.ndim < expected_ndim: - raise ValueError("The transform expects a batched input") - output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] - - return output - elif isinstance(inpt, proto_datapoints.OneHotLabel): - return self._mixup_onehotlabel(inpt, lam) - else: - return inpt - - -class RandomCutmix(_BaseMixupCutmix): - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - lam = float(self._dist.sample(())) # type: ignore[arg-type] - - H, W = query_spatial_size(flat_inputs) - - r_x = torch.randint(W, ()) - r_y = torch.randint(H, ()) - - r = 0.5 * math.sqrt(1.0 - lam) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - box = (x1, y1, x2, y2) - - lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - return dict(box=box, lam_adjusted=lam_adjusted) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): - box = params["box"] - expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 - if inpt.ndim < expected_ndim: - raise ValueError("The transform expects a batched input") - x1, y1, x2, y2 = box - rolled = inpt.roll(1, 0) - output = inpt.clone() - output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - - return output - elif isinstance(inpt, proto_datapoints.OneHotLabel): - lam_adjusted = params["lam_adjusted"] - return self._mixup_onehotlabel(inpt, lam_adjusted) - else: - return inpt - - -class SimpleCopyPaste(Transform): - def __init__( - self, - blending: bool = True, - resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, - ) -> None: - super().__init__() - self.resize_interpolation = _check_interpolation(resize_interpolation) - self.blending = blending - self.antialias = antialias - - def _copy_paste( - self, - image: datapoints._TensorImageType, - target: Dict[str, Any], - paste_image: datapoints._TensorImageType, - paste_target: Dict[str, Any], - random_selection: torch.Tensor, - blending: bool, - resize_interpolation: F.InterpolationMode, - antialias: Optional[bool], - ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]: - - paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) - paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], 
paste_target["boxes"][random_selection]) - paste_labels = paste_target["labels"].wrap_like( - paste_target["labels"], paste_target["labels"][random_selection] - ) - - masks = target["masks"] - - # We resize source and paste data if they have different sizes - # This is something different to TF implementation we introduced here as - # originally the algorithm works on equal-sized data - # (for example, coming from LSJ data augmentations) - size1 = cast(List[int], image.shape[-2:]) - size2 = paste_image.shape[-2:] - if size1 != size2: - paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias) - paste_masks = F.resize(paste_masks, size=size1) - paste_boxes = F.resize(paste_boxes, size=size1) - - paste_alpha_mask = paste_masks.sum(dim=0) > 0 - - if blending: - paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) - - inverse_paste_alpha_mask = paste_alpha_mask.logical_not() - # Copy-paste images: - image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask)) - - # Copy-paste masks: - masks = masks * inverse_paste_alpha_mask - non_all_zero_masks = masks.sum((-1, -2)) > 0 - masks = masks[non_all_zero_masks] - - # Do a shallow copy of the target dict - out_target = {k: v for k, v in target.items()} - - out_target["masks"] = torch.cat([masks, paste_masks]) - - # Copy-paste boxes and labels - bbox_format = target["boxes"].format - xyxy_boxes = masks_to_boxes(masks) - # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive - # we need to add +1 to x2y2. - # There is a similar +1 in other reference implementations: - # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422 - xyxy_boxes[:, 2:] += 1 - boxes = F.convert_format_bounding_box( - xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True - ) - out_target["boxes"] = torch.cat([boxes, paste_boxes]) - - labels = target["labels"][non_all_zero_masks] - out_target["labels"] = torch.cat([labels, paste_labels]) - - # Check for degenerated boxes and remove them - boxes = F.convert_format_bounding_box( - out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY - ) - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - valid_targets = ~degenerate_boxes.any(dim=1) - - out_target["boxes"] = boxes[valid_targets] - out_target["masks"] = out_target["masks"][valid_targets] - out_target["labels"] = out_target["labels"][valid_targets] - - return image, out_target - - def _extract_image_targets( - self, flat_sample: List[Any] - ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]: - # fetch all images, bboxes, masks and labels from unstructured input - # with List[image], List[BoundingBox], List[Mask], List[Label] - images, bboxes, masks, labels = [], [], [], [] - for obj in flat_sample: - if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): - images.append(obj) - elif isinstance(obj, PIL.Image.Image): - images.append(F.to_image_tensor(obj)) - elif isinstance(obj, datapoints.BoundingBox): - bboxes.append(obj) - elif isinstance(obj, datapoints.Mask): - masks.append(obj) - elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): - labels.append(obj) - - if not (len(images) == len(bboxes) == len(masks) == len(labels)): - raise TypeError( - f"{type(self).__name__}() requires input sample to contain equal sized 
list of Images, " - "BoundingBoxes, Masks and Labels or OneHotLabels." - ) - - targets = [] - for bbox, mask, label in zip(bboxes, masks, labels): - targets.append({"boxes": bbox, "masks": mask, "labels": label}) - - return images, targets - - def _insert_outputs( - self, - flat_sample: List[Any], - output_images: List[datapoints._TensorImageType], - output_targets: List[Dict[str, Any]], - ) -> None: - c0, c1, c2, c3 = 0, 0, 0, 0 - for i, obj in enumerate(flat_sample): - if isinstance(obj, datapoints.Image): - flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0]) - c0 += 1 - elif isinstance(obj, PIL.Image.Image): - flat_sample[i] = F.to_image_pil(output_images[c0]) - c0 += 1 - elif is_simple_tensor(obj): - flat_sample[i] = output_images[c0] - c0 += 1 - elif isinstance(obj, datapoints.BoundingBox): - flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) - c1 += 1 - elif isinstance(obj, datapoints.Mask): - flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) - c2 += 1 - elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): - flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] - c3 += 1 - - def forward(self, *inputs: Any) -> Any: - flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) - - images, targets = self._extract_image_targets(flat_inputs) - - # images = [t1, t2, ..., tN] - # Let's define paste_images as shifted list of input images - # paste_images = [t2, t3, ..., tN, t1] - # FYI: in TF they mix data on the dataset level - images_rolled = images[-1:] + images[:-1] - targets_rolled = targets[-1:] + targets[:-1] - - output_images, output_targets = [], [] - - for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): - - # Random paste targets selection: - num_masks = len(paste_target["masks"]) - - if num_masks < 1: - # Such degerante case with num_masks=0 can happen with LSJ - # Let's just return (image, target) - output_image, output_target = image, target - else: - random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device) - random_selection = torch.unique(random_selection) - - output_image, output_target = self._copy_paste( - image, - target, - paste_image, - paste_target, - random_selection=random_selection, - blending=self.blending, - resize_interpolation=self.resize_interpolation, - antialias=self.antialias, - ) - output_images.append(output_image) - output_targets.append(output_target) - - # Insert updated images and targets into input flat_sample - self._insert_outputs(flat_inputs, output_images, output_targets) - - return tree_unflatten(flat_inputs, spec) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py deleted file mode 100644 index 8d5cc24d25a..00000000000 --- a/torchvision/prototype/transforms/_geometry.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Any, Dict, List, Optional, Sequence, Type, Union - -import PIL.Image -import torch - -from torchvision import datapoints -from torchvision.prototype.datapoints import Label, OneHotLabel -from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_box, query_spatial_size - - -class FixedSizeCrop(Transform): - def __init__( - self, - size: Union[int, Sequence[int]], - fill: 
Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, - padding_mode: str = "constant", - ) -> None: - super().__init__() - size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) - self.crop_height = size[0] - self.crop_width = size[1] - - self.fill = fill - self._fill = _setup_fill_arg(fill) - - self.padding_mode = padding_mode - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if not has_any( - flat_inputs, - PIL.Image.Image, - datapoints.Image, - is_simple_tensor, - datapoints.Video, - ): - raise TypeError( - f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video." - ) - - if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(flat_inputs, Label, OneHotLabel): - raise TypeError( - f"If a BoundingBox is contained in the input sample, " - f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." - ) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) - new_height = min(height, self.crop_height) - new_width = min(width, self.crop_width) - - needs_crop = new_height != height or new_width != width - - offset_height = max(height - self.crop_height, 0) - offset_width = max(width - self.crop_width, 0) - - r = torch.rand(1) - top = int(offset_height * r) - left = int(offset_width * r) - - bounding_boxes: Optional[torch.Tensor] - try: - bounding_boxes = query_bounding_box(flat_inputs) - except ValueError: - bounding_boxes = None - - if needs_crop and bounding_boxes is not None: - format = bounding_boxes.format - bounding_boxes, spatial_size = F.crop_bounding_box( - bounding_boxes.as_subclass(torch.Tensor), - format=format, - top=top, - left=left, - height=new_height, - width=new_width, - ) - bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size) - height_and_width = F.convert_format_bounding_box( - bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH - )[..., 2:] - is_valid = torch.all(height_and_width > 0, dim=-1) - else: - is_valid = None - - pad_bottom = max(self.crop_height - new_height, 0) - pad_right = max(self.crop_width - new_width, 0) - - needs_pad = pad_bottom != 0 or pad_right != 0 - - return dict( - needs_crop=needs_crop, - top=top, - left=left, - height=new_height, - width=new_width, - is_valid=is_valid, - padding=[0, 0, pad_right, pad_bottom], - needs_pad=needs_pad, - ) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if params["needs_crop"]: - inpt = F.crop( - inpt, - top=params["top"], - left=params["left"], - height=params["height"], - width=params["width"], - ) - - if params["is_valid"] is not None: - if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)): - inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] - elif isinstance(inpt, datapoints.BoundingBox): - inpt = datapoints.BoundingBox.wrap_like( - inpt, - F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size), - ) - - if params["needs_pad"]: - fill = self._fill[type(inpt)] - inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) - - return inpt diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py deleted file mode 100644 index 3a4e6e956f3..00000000000 --- a/torchvision/prototype/transforms/_misc.py +++ /dev/null @@ -1,58 +0,0 @@ -import warnings -from typing import Any, Dict, Optional, Sequence, 
Tuple, Type, Union - -import torch - -from torchvision import datapoints -from torchvision.transforms.v2 import Transform - -from torchvision.transforms.v2._utils import _get_defaultdict -from torchvision.transforms.v2.utils import is_simple_tensor - - -class PermuteDimensions(Transform): - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) - - def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: - super().__init__() - if not isinstance(dims, dict): - dims = _get_defaultdict(dims) - if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): - warnings.warn( - "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " - "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " - "in case a `datapoints.Image` or `datapoints.Video` is present in the input." - ) - self.dims = dims - - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: - dims = self.dims[type(inpt)] - if dims is None: - return inpt.as_subclass(torch.Tensor) - return inpt.permute(*dims) - - -class TransposeDimensions(Transform): - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) - - def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: - super().__init__() - if not isinstance(dims, dict): - dims = _get_defaultdict(dims) - if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): - warnings.warn( - "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " - "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " - "in case a `datapoints.Image` or `datapoints.Video` is present in the input." - ) - self.dims = dims - - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: - dims = self.dims[type(inpt)] - if dims is None: - return inpt.as_subclass(torch.Tensor) - return inpt.transpose(*dims) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py deleted file mode 100644 index 25c39a90382..00000000000 --- a/torchvision/prototype/transforms/_presets.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -This file is part of the private API. Please do not use directly these classes as they will be modified on -future versions without warning. The classes should be accessed only via the transforms argument of Weights. -""" -from typing import List, Optional, Tuple, Union - -import PIL.Image - -import torch -from torch import Tensor - -from torchvision.transforms.v2 import functional as F, InterpolationMode - -from torchvision.transforms.v2.functional._geometry import _check_interpolation - -__all__ = ["StereoMatching"] - - -class StereoMatching(torch.nn.Module): - def __init__( - self, - *, - use_gray_scale: bool = False, - resize_size: Optional[Tuple[int, ...]], - mean: Tuple[float, ...] = (0.5, 0.5, 0.5), - std: Tuple[float, ...] 
= (0.5, 0.5, 0.5), - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - ) -> None: - super().__init__() - - # pacify mypy - self.resize_size: Union[None, List] - - if resize_size is not None: - self.resize_size = list(resize_size) - else: - self.resize_size = None - - self.mean = list(mean) - self.std = list(std) - self.interpolation = _check_interpolation(interpolation) - self.use_gray_scale = use_gray_scale - - def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]: - def _process_image(img: PIL.Image.Image) -> Tensor: - if not isinstance(img, Tensor): - img = F.pil_to_tensor(img) - if self.resize_size is not None: - # We hard-code antialias=False to preserve results after we changed - # its default from None to True (see - # https://github.com/pytorch/vision/pull/7160) - # TODO: we could re-train the stereo models with antialias=True? - img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=False) - if self.use_gray_scale is True: - img = F.rgb_to_grayscale(img) - img = F.convert_image_dtype(img, torch.float) - img = F.normalize(img, mean=self.mean, std=self.std) - img = img.contiguous() - return img - - left_image = _process_image(left_image) - right_image = _process_image(right_image) - return left_image, right_image - - def __repr__(self) -> str: - format_string = self.__class__.__name__ + "(" - format_string += f"\n resize_size={self.resize_size}" - format_string += f"\n mean={self.mean}" - format_string += f"\n std={self.std}" - format_string += f"\n interpolation={self.interpolation}" - format_string += "\n)" - return format_string - - def describe(self) -> str: - return ( - "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " - f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. " - f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and " - f"``std={self.std}``." - ) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py deleted file mode 100644 index 4cd3cf46871..00000000000 --- a/torchvision/prototype/transforms/_type_conversion.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Any, Dict - -import torch - -from torch.nn.functional import one_hot - -from torchvision.prototype import datapoints as proto_datapoints -from torchvision.transforms.v2 import Transform - - -class LabelToOneHot(Transform): - _transformed_types = (proto_datapoints.Label,) - - def __init__(self, num_categories: int = -1): - super().__init__() - self.num_categories = num_categories - - def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel: - num_categories = self.num_categories - if num_categories == -1 and inpt.categories is not None: - num_categories = len(inpt.categories) - output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) - return proto_datapoints.OneHotLabel(output, categories=inpt.categories) - - def extra_repr(self) -> str: - if self.num_categories == -1: - return "" - - return f"num_categories={self.num_categories}" diff --git a/torchvision/prototype/utils/__init__.py b/torchvision/prototype/utils/__init__.py deleted file mode 100644 index e85a582b483..00000000000 --- a/torchvision/prototype/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import _internal diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py deleted file mode 100644 index 3dee4b59a7a..00000000000 --- a/torchvision/prototype/utils/_internal.py +++ /dev/null @@ -1,126 +0,0 @@ -import collections.abc -import difflib -import io -import mmap -import platform -from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union - -import numpy as np -import torch -from torchvision._utils import sequence_to_str - - -__all__ = [ - "add_suggestion", - "fromfile", - "ReadOnlyTensorBuffer", -] - - -def add_suggestion( - msg: str, - *, - word: str, - possibilities: Collection[str], - close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?", - alternative_hint: Callable[ - [Sequence[str]], str - ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.", -) -> str: - if not isinstance(possibilities, collections.abc.Sequence): - possibilities = sorted(possibilities) - suggestions = difflib.get_close_matches(word, possibilities, 1) - hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities) - if not hint: - return msg - - return f"{msg.strip()} {hint}" - - -D = TypeVar("D") - - -def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: - # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable - return bytearray(file.read(-1 if count == -1 else count * item_size)) - - -def fromfile( - file: BinaryIO, - *, - dtype: torch.dtype, - byte_order: str, - count: int = -1, -) -> torch.Tensor: - """Construct a tensor from a binary file. - .. note:: - This function is similar to :func:`numpy.fromfile` with two notable differences: - 1. This function only accepts an open binary file, but not a path to it. - 2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that - concept. - .. note:: - If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as - long as the file is still open, inplace operations on the returned tensor will reflect back to the file. - Args: - file (IO): Open binary file. - dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor. - byte_order (str): Byte order of the data. Can be "little" or "big" endian. - count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file. - """ - byte_order = "<" if byte_order == "little" else ">" - char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u") - item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 - np_dtype = byte_order + char + str(item_size) - - buffer: Union[memoryview, bytearray] - if platform.system() != "Windows": - # PyTorch does not support tensors with underlying read-only memory. In case - # - the file has a .fileno(), - # - the file was opened for updating, i.e. 'r+b' or 'w+b', - # - the file is seekable - # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it - # to a mutable location afterwards. - try: - buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] - # Reading from the memoryview does not advance the file cursor, so we have to do it manually. 
- file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) - except (AttributeError, PermissionError, io.UnsupportedOperation): - buffer = _read_mutable_buffer_fallback(file, count, item_size) - else: - # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state - # so no data can be read afterwards. Thus, we simply ignore the possible speed-up. - buffer = _read_mutable_buffer_fallback(file, count, item_size) - - # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we - # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the - # successive .astype() call. - return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False)) - - -class ReadOnlyTensorBuffer: - def __init__(self, tensor: torch.Tensor) -> None: - self._memory = memoryview(tensor.numpy()) - self._cursor: int = 0 - - def tell(self) -> int: - return self._cursor - - def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: - if whence == io.SEEK_SET: - self._cursor = offset - elif whence == io.SEEK_CUR: - self._cursor += offset - pass - elif whence == io.SEEK_END: - self._cursor = len(self._memory) + offset - else: - raise ValueError( - f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, " - f"but got {repr(whence)} instead" - ) - return self.tell() - - def read(self, size: int = -1) -> bytes: - cursor = self.tell() - offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR) - return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()