Skip to content

Commit 8cae313

Browse files
authored
Merge branch 'main' into models/enhanced_meta
2 parents 93d9fc4 + 5dc61cb commit 8cae313

File tree

9 files changed

+404
-4
lines changed

9 files changed

+404
-4
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4848
FlyingChairs
4949
FlyingThings3D
5050
Food101
51+
GTSRB
5152
HD1K
5253
HMDB51
5354
ImageNet

test/test_datasets.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,5 +2275,87 @@ def inject_fake_data(self, tmpdir, config):
22752275
return num_samples
22762276

22772277

2278+
class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
    """Dataset test case for ``datasets.GTSRB``, run once per ``train`` setting."""

    DATASET_CLASS = datasets.GTSRB
    FEATURE_TYPES = (PIL.Image.Image, int)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

    def inject_fake_data(self, tmpdir: str, config):
        """Write a fake GTSRB tree (training and test parts) below ``tmpdir``.

        Both parts are always created, regardless of ``config``, and both hold
        the same number of samples.

        Returns:
            The number of samples written for each part.
        """
        dataset_dir = os.path.join(tmpdir, "GTSRB")
        os.makedirs(dataset_dir, exist_ok=True)

        # Training part: one sub-folder of ``.ppm`` images per class name.
        training_dir = os.path.join(dataset_dir, "Training")
        os.makedirs(training_dir, exist_ok=True)

        images_per_class = 3
        class_names = ("00000", "00042", "00012")
        for class_name in class_names:
            # The lambda is only called while create_image_folder runs, so it
            # sees the class name of the current iteration.
            datasets_utils.create_image_folder(
                training_dir,
                name=class_name,
                file_name_fn=lambda image_idx: f"{class_name}_{image_idx:05d}.ppm",
                num_examples=images_per_class,
            )

        sample_count = images_per_class * len(class_names)

        # Test part: a flat image folder plus a ';'-separated ground-truth CSV.
        images_dir = os.path.join(dataset_dir, "Final_Test", "Images")
        os.makedirs(images_dir, exist_ok=True)

        ground_truth_path = os.path.join(dataset_dir, "GT-final_test.csv")
        with open(ground_truth_path, "w") as ground_truth:
            ground_truth.write("Filename;Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n")

            for _ in range(sample_count):
                file_name = datasets_utils.create_random_string(5, string.digits) + ".ppm"
                datasets_utils.create_image_file(images_dir, file_name)

                # Width, Height and the four Roi.* columns: six random values in [1, 100).
                geometry = [torch.randint(1, 100, size=()).item() for _ in range(6)]
                class_id = torch.randint(0, 43, size=()).item()

                fields = [file_name, *geometry, class_id]
                ground_truth.write(";".join(str(field) for field in fields) + "\n")

        return sample_count
2326+
2327+
2328+
class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
    """Dataset test case for ``datasets.CLEVRClassification``, run once per split."""

    DATASET_CLASS = datasets.CLEVRClassification
    # The label slot may be ``None`` as well as ``int``: no scenes file is
    # written for the "test" split below.
    FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))

    def inject_fake_data(self, tmpdir, config):
        """Write fake CLEVR images (and a scenes file, except for "test") below ``tmpdir``.

        Returns:
            The number of image files created for ``config["split"]``.
        """
        split = config["split"]
        clevr_root = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"

        created_images = datasets_utils.create_image_folder(
            clevr_root / "images",
            split,
            lambda idx: f"CLEVR_{split}_{idx:06d}.png",
            num_examples=5,
        )

        # The scenes directory is created for every split, the JSON file is not.
        scenes_dir = clevr_root / "scenes"
        scenes_dir.mkdir()
        if split != "test":
            with open(scenes_dir / f"CLEVR_{split}_scenes.json", "w") as scenes_file:
                # One scene per image with a random number (0-9) of empty objects.
                scenes = [
                    {"image_filename": image.name, "objects": [{}] * int(torch.randint(10, ()))}
                    for image in created_images
                ]
                json.dump({"info": {}, "scenes": scenes}, scenes_file)

        return len(created_images)
2358+
2359+
22782360
if __name__ == "__main__":
22792361
unittest.main()

test/test_image.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,8 @@ def test_write_jpeg_reference(img_path, tmpdir):
478478
assert_equal(torch_bytes, pil_bytes)
479479

480480

481-
@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows"))
481+
# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
482+
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
482483
@pytest.mark.parametrize(
483484
"img_path",
484485
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
@@ -497,7 +498,8 @@ def test_encode_jpeg(img_path):
497498
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
498499

499500

500-
@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows"))
501+
# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
502+
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
501503
@pytest.mark.parametrize(
502504
"img_path",
503505
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],

torchvision/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
55
from .cityscapes import Cityscapes
6+
from .clevr import CLEVRClassification
67
from .coco import CocoCaptions, CocoDetection
78
from .dtd import DTD
89
from .fakedata import FakeData
910
from .fer2013 import FER2013
1011
from .flickr import Flickr8k, Flickr30k
1112
from .folder import ImageFolder, DatasetFolder
1213
from .food101 import Food101
14+
from .gtsrb import GTSRB
1315
from .hmdb51 import HMDB51
1416
from .imagenet import ImageNet
1517
from .inaturalist import INaturalist
@@ -83,4 +85,6 @@
8385
"Food101",
8486
"DTD",
8587
"FER2013",
88+
"GTSRB",
89+
"CLEVRClassification",
8690
)

torchvision/datasets/clevr.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import json
2+
import pathlib
3+
from typing import Any, Callable, Optional, Tuple, List
4+
from urllib.parse import urlparse
5+
6+
from PIL import Image
7+
8+
from .utils import download_and_extract_archive, verify_str_arg
9+
from .vision import VisionDataset
10+
11+
12+
class CLEVRClassification(VisionDataset):
    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene is used as the label. The ``"test"`` split has no scene
    annotations, so its labels are ``None``.

    Args:
        root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if
            download is set to True.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True (default), downloads the dataset from the internet and puts it in
            root directory. If dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)

        # The data folder is named after the archive file without its extension.
        archive_stem = pathlib.Path(urlparse(self._URL).path).stem
        self._base_folder = pathlib.Path(self.root) / "clevr"
        self._data_folder = self._base_folder / archive_stem

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        split_images = self._data_folder / "images" / self._split
        self._image_files = sorted(split_images.glob("*"))

        self._labels: List[Optional[int]]
        if self._split == "test":
            self._labels = [None] * len(self._image_files)
        else:
            scenes_path = self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json"
            with open(scenes_path) as scenes_file:
                scenes = json.load(scenes_file)["scenes"]
            # Map each image file name to the number of objects in its scene.
            object_counts = {scene["image_filename"]: len(scene["objects"]) for scene in scenes}
            self._labels = [object_counts[path.name] for path in self._image_files]

    def __len__(self) -> int:
        """Return the number of image files found for the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, label)`` pair at ``idx`` after applying the optional transforms."""
        image = Image.open(self._image_files[idx]).convert("RGB")
        label = self._labels[idx]

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        """Return True if the extracted data folder is present."""
        # ``is_dir()`` is already False for a path that does not exist.
        return self._data_folder.is_dir()

    def _download(self) -> None:
        """Download and extract the archive unless the data folder already exists."""
        if not self._check_exists():
            download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

    def extra_repr(self) -> str:
        """Return the split description string."""
        return f"split={self._split}"

torchvision/datasets/gtsrb.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import csv
2+
import os
3+
from typing import Any, Callable, Optional, Tuple
4+
5+
import PIL
6+
7+
from .folder import make_dataset
8+
from .utils import download_and_extract_archive
9+
from .vision import VisionDataset
10+
11+
12+
class GTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Training samples are read from one sub-folder per class below ``root/GTSRB/Training``. Test samples are
    listed in the ground-truth CSV ``root/GTSRB/GT-final_test.csv`` and their images live in
    ``root/GTSRB/Final_Test/Images``.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the required files of the selected split are missing after the optional download.
    """

    # Ground Truth for the test set
    _gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
    _gt_csv = "GT-final_test.csv"
    _gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"

    # URLs for the test and train set. The order matters: index 0 is the test archive and
    # index 1 the training archive (see ``download``). ``_md5s`` uses the same order.
    _urls = (
        "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
        "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
    )

    _md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        # The base class receives and stores both transforms, so they are not re-assigned here.
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.root = os.path.expanduser(root)
        self.train = train

        self._base_folder = os.path.join(self.root, type(self).__name__)
        if self.train:
            self._target_folder = os.path.join(self._base_folder, "Training")
        else:
            self._target_folder = os.path.join(self._base_folder, "Final_Test", "Images")
        # Only read for the test split, but always defined so the helpers can rely on it.
        self._gt_path = os.path.join(self._base_folder, self._gt_csv)

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if self.train:
            samples = make_dataset(self._target_folder, extensions=(".ppm",))
        else:
            # ``newline=""`` is what the csv module asks for when reading from a file object.
            with open(self._gt_path, newline="") as csv_file:
                reader = csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                samples = [
                    (os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"])) for row in reader
                ]

        self._samples = samples

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``index`` after applying the optional transforms."""
        path, target = self._samples[index]
        sample = PIL.Image.open(path).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        """Return True if everything the selected split needs is present on disk.

        The training split only needs its image folder. The test split also needs the
        ground-truth CSV, because the constructor reads the sample list from it.
        """
        if not os.path.isdir(self._target_folder):
            return False
        if self.train:
            return True
        return os.path.isfile(self._gt_path)

    def download(self) -> None:
        """Download and extract whatever is missing for the selected split.

        Does nothing if the split is already complete. For the test split the image archive
        and the ground-truth archive are handled separately, so a missing CSV next to an
        existing image folder is repaired without extracting the images again.
        """
        if self._check_exists():
            return

        if not os.path.isdir(self._target_folder):
            # Explicit index instead of indexing the tuples with the ``train`` flag itself.
            archive_idx = 1 if self.train else 0
            download_and_extract_archive(
                self._urls[archive_idx], download_root=self.root, md5=self._md5s[archive_idx]
            )

        if not self.train and not os.path.isfile(self._gt_path):
            # Download Ground Truth for the test set
            download_and_extract_archive(
                self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
            )

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .caltech import Caltech101, Caltech256
22
from .celeba import CelebA
33
from .cifar import Cifar10, Cifar100
4+
from .clevr import CLEVR
45
from .coco import Coco
56
from .dtd import DTD
67
from .fer2013 import FER2013

0 commit comments

Comments
 (0)