Skip to content
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs
FlyingThings3D
Food101
FGVCAircraft
GTSRB
HD1K
HMDB51
Expand Down
46 changes: 46 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,52 @@ def inject_fake_data(self, tmpdir: str, config):
return len(sampled_classes * n_samples_per_class)


class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.FGVCAircraft`` backed by a small fake dataset on disk."""

    DATASET_CLASS = datasets.FGVCAircraft
    FEATURE_TYPES = (PIL.Image.Image, int)

    # Every supported split is exercised as its own configuration.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "trainval", "test"))

    def inject_fake_data(self, tmpdir: str, config):
        """Write a fake FGVC Aircraft layout below ``tmpdir`` for the requested split.

        The following is produced under ``<tmpdir>/fgvc-aircraft-2013b/data``:

        * an ``images`` folder populated through ``datasets_utils.create_image_folder``
          (presumably one ``<idx>.jpg`` per example -- confirm against that helper),
        * ``variants.txt`` with one variant name per line,
        * ``images_variant_<split>.txt`` with one ``"<idx> <variant>"`` entry per line.

        Args:
            tmpdir (str): Directory that acts as the dataset root.
            config: Test configuration; only ``config["split"]`` is read.

        Returns:
            int: Number of annotated samples written for the split.
        """
        split = config["split"]
        data_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b" / "data"

        num_images_per_class = 5
        # "Hawk T1" contains a space, so its annotation lines hold more than two
        # space-separated tokens.
        variants = ["707-320", "Hawk T1", "Tornado"]
        num_samples_per_class = 4 if split == "trainval" else 2

        datasets_utils.create_image_folder(
            data_folder,
            "images",
            file_name_fn=lambda idx: f"{idx}.jpg",
            num_examples=num_images_per_class * len(variants),
        )

        images_variants = []
        for i, variant in enumerate(variants):
            # Each variant owns a contiguous, non-overlapping range of image indices;
            # only a random subset of that range is annotated for the split.
            first_idx = i * num_images_per_class
            sampled = random.sample(range(first_idx, first_idx + num_images_per_class), num_samples_per_class)
            images_variants.extend(f"{idx} {variant}" for idx in sampled)

        with open(data_folder / "variants.txt", "w") as file:
            file.write("\n".join(variants))

        with open(data_folder / f"images_variant_{split}.txt", "w") as file:
            file.write("\n".join(images_variants))

        return len(variants) * num_samples_per_class


class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .food101 import Food101
Expand Down Expand Up @@ -91,4 +92,5 @@
"GTSRB",
"CLEVRClassification",
"OxfordIIITPet",
"FGVCAircraft",
)
106 changes: 106 additions & 0 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
from typing import Any, Callable, List, Optional, Tuple

import PIL.Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class FGVCAircraft(VisionDataset):
    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    The dataset contains 10,200 images of aircraft, with 100 images for each of 102
    different aircraft model variants, most of which are airplanes.

    Args:
        root (string): Root directory of the FGVC Aircraft dataset.
        split (string, optional): The dataset split, supports ``train``, ``val``,
            ``trainval`` and ``test``. Defaults to ``trainval``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If the ``fgvc-aircraft-2013b`` directory is not found under
            ``root`` (after the optional download).
    """

    # NOTE(review): the image/variant counts in the docstring are quoted from the
    # dataset homepage; the accompanying paper reportedly lists 10,000 images in
    # 100 variants -- TODO confirm against the extracted ``variants.txt``.

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/"
    _URL_FILE = "fgvc-aircraft-2013b.tar.gz"

    def __init__(
        self,
        root: str,
        split: str = "trainval",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))

        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Normalised variant names; a label is the position of its name in this list.
        self._label_names = self._get_label_names(self._data_path)
        self._label_name_to_idx = {name: idx for idx, name in enumerate(self._label_names)}

        # Directory that holds the image files of every split.
        self._image_folder = os.path.join(self._data_path, "data", "images")

        # Parallel lists, filled from the annotation file of the selected split.
        self._image_files = []
        self._labels = []
        self._read_fgvc_aircrafts_images_labels(self._data_path)

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample stored at position ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the file opened with PIL and
            converted to RGB, and ``label`` is the integer index of its variant. Both
            are passed through ``transform`` / ``target_transform`` when those are set.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        # Compare against None rather than relying on the truthiness of a callable.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def _download(self) -> None:
        """
        Download the FGVC Aircraft dataset archive and extract it under root.

        Does nothing when the dataset directory already exists.
        """
        if self._check_exists():
            return
        download_and_extract_archive(self._URL + self._URL_FILE, self.root)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset directory is present."""
        # isdir() is already False for a missing path, so no separate exists() check.
        return os.path.isdir(self._data_path)

    def _read_fgvc_aircrafts_images_labels(self, input_path: str) -> None:
        """Fill ``self._image_files`` and ``self._labels`` from the split's annotation file.

        Each non-blank line of ``<input_path>/data/images_variant_<split>.txt`` is read
        as ``"<image name> <variant name>"``; the variant name may itself contain spaces.

        Args:
            input_path (str): Path of the extracted ``fgvc-aircraft-2013b`` directory.

        Raises:
            KeyError: If a line names a variant that is not listed in ``variants.txt``.
        """
        image_data_folder = os.path.join(input_path, "data", "images")
        labels_path = os.path.join(input_path, "data", f"images_variant_{self._split}.txt")

        with open(labels_path, "r") as labels_file:
            for raw_line in labels_file:
                line = raw_line.strip()
                if not line:
                    # A blank line carries no sample; skip it instead of registering
                    # an empty image name / label.
                    continue
                # Only the first space separates the image name from the variant.
                image_name, _, variant = line.partition(" ")
                label_name = self._parse_aircraft_name(variant)
                self._labels.append(self._label_name_to_idx[label_name])
                self._image_files.append(os.path.join(image_data_folder, image_name + ".jpg"))

    def _get_label_names(self, input_path: str) -> List[str]:
        """Read the normalised variant names from ``<input_path>/data/variants.txt``.

        Blank lines are ignored so they cannot turn into an empty-named class.

        Args:
            input_path (str): Path of the extracted ``fgvc-aircraft-2013b`` directory.

        Returns:
            list: One normalised name per non-blank line, in file order.
        """
        variants_file = os.path.join(input_path, "data", "variants.txt")
        with open(variants_file, "r") as f:
            stripped = (line.strip() for line in f)
            return [self._parse_aircraft_name(name) for name in stripped if name]

    def _parse_aircraft_name(self, name: str) -> str:
        """Normalise a variant name by replacing every ``/`` and space with ``-``."""
        return name.replace("/", "-").replace(" ", "-")