Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
EMNIST
FakeData
FashionMNIST
FER2013
Flickr8k
Flickr30k
FlyingChairs
Expand Down
34 changes: 34 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bz2
import contextlib
import csv
import io
import itertools
import json
Expand Down Expand Up @@ -2168,5 +2169,38 @@ def inject_fake_data(self, tmpdir, config):
return num_sequences * (num_examples_per_sequence - 1)


class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test harness for :class:`torchvision.datasets.FER2013`."""

    DATASET_CLASS = datasets.FER2013
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

    def inject_fake_data(self, tmpdir, config):
        # Lay out the structure the dataset expects: <root>/fer2013/<split>.csv
        root = os.path.join(tmpdir, "fer2013")
        os.makedirs(root)

        split = config["split"]
        is_train = split == "train"
        num_samples = 5

        with open(os.path.join(root, f"{split}.csv"), "w", newline="") as fh:
            # The test split has no labels, so its CSV carries only the pixels column.
            field_names = ("emotion", "pixels") if is_train else ("pixels",)
            writer = csv.DictWriter(fh, fieldnames=field_names, quoting=csv.QUOTE_NONNUMERIC, quotechar='"')
            writer.writeheader()

            for _ in range(num_samples):
                # Each image is a flattened 48x48 grid of space-separated pixel values.
                flat_pixels = datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
                record = {"pixels": " ".join(str(value) for value in flat_pixels)}
                if is_train:
                    record["emotion"] = str(int(torch.randint(0, 7, ())))
                writer.writerow(record)

        return num_samples


# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .cityscapes import Cityscapes
from .coco import CocoCaptions, CocoDetection
from .fakedata import FakeData
from .fer2013 import FER2013
from .flickr import Flickr8k, Flickr30k
from .folder import ImageFolder, DatasetFolder
from .hmdb51 import HMDB51
Expand Down Expand Up @@ -77,4 +78,5 @@
"FlyingChairs",
"FlyingThings3D",
"HD1K",
"FER2013",
)
78 changes: 78 additions & 0 deletions torchvision/datasets/fer2013.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import csv
import os
import os.path
from typing import Any, Callable, Optional, Tuple

import torch
from PIL import Image

from .utils import verify_str_arg, check_integrity
from .vision import VisionDataset


class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    The CSV files have to be downloaded manually from Kaggle (see the error
    message of :meth:`_verify_integrity`) and placed in ``root/fer2013``; there
    is no automatic download.

    Args:
        root (string): Root directory of dataset where the directory ``fer2013``
            (containing ``train.csv`` and/or ``test.csv``) exists.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # Per-split (file name, md5) pairs, consumed by _verify_integrity().
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
        super().__init__(root, transform=transform, target_transform=target_transform)

        # Each CSV row stores a flattened 48x48 grayscale image as a single
        # space-separated string; the train split additionally carries an
        # integer "emotion" label (absent from test.csv, hence the None target).
        with open(self._verify_integrity(), "r", newline="") as file:
            self._samples = [
                (
                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
                    int(row["emotion"]) if "emotion" in row else None,
                )
                for row in csv.DictReader(file)
            ]

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)``; ``target`` is ``None`` for the test split."""
        image_tensor, target = self._samples[idx]
        # uint8 HxW tensor -> 48x48 grayscale PIL image.
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def _verify_integrity(self) -> str:
        """Return the path of this split's CSV file, raising if it is missing or corrupted."""
        base_folder = os.path.join(self.root, type(self).__name__.lower())
        file_name, md5 = self._RESOURCES[self._split]
        file = os.path.join(base_folder, file_name)
        if not check_integrity(file, md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )
        return file

    def extra_repr(self) -> str:
        return f"split={self._split}"
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .coco import Coco
from .fer2013 import FER2013
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .sbd import SBD
Expand Down
80 changes: 80 additions & 0 deletions torchvision/prototype/datasets/_builtin/fer2013.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Union, cast

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
KaggleDownloadResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image


class FER2013(Dataset):
    """FER2013 facial-expression dataset (prototype datasets API)."""

    # Per-split sha256 of the zipped CSV downloaded from Kaggle.
    _CHECKSUMS = {
        "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
        "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
    }

    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "fer2013",
            type=DatasetType.RAW,
            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
            valid_options={"split": ("train", "test")},
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        # Kaggle offers no anonymous download, so this resource only renders
        # manual-download instructions.
        return [
            KaggleDownloadResource(
                cast(str, self.info.homepage),
                file_name=f"{config.split}.csv.zip",
                sha256=self._CHECKSUMS[config.split],
            )
        ]

    def _collate_and_decode_sample(
        self,
        data: Dict[str, Any],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        # The pixels column is one space-separated string of 48*48 grayscale values.
        pixel_values = [int(idx) for idx in data["pixels"].split()]
        raw_image = torch.tensor(pixel_values, dtype=torch.uint8).reshape(48, 48)

        # The test split has no "emotion" column, so the label may be absent.
        raw_label = data.get("emotion")
        label_idx = None if raw_label is None else int(raw_label)

        image: Union[Image, io.BytesIO]
        if decoder is raw:
            image = Image(raw_image)
        elif decoder:
            image = decoder(image_buffer_from_array(raw_image.numpy()))  # type: ignore[assignment]
        else:
            image = image_buffer_from_array(raw_image.numpy())

        if label_idx is None:
            label = None
        else:
            label = Label(label_idx, category=self.info.categories[label_idx])

        return {"image": image, "label": label}

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        samples = CSVDictParser(resource_dps[0])
        samples = hint_sharding(samples)
        samples = hint_shuffling(samples)
        return Mapper(samples, functools.partial(self._collate_and_decode_sample, decoder=decoder))
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
14 changes: 14 additions & 0 deletions torchvision/prototype/datasets/utils/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,17 @@ def _download(self, root: pathlib.Path) -> NoReturn:
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)


class KaggleDownloadResource(ManualDownloadResource):
    """Manual-download resource for files hosted on Kaggle.

    Kaggle requires an authenticated account that has joined the competition,
    so the file cannot be fetched programmatically; instead we hand the user
    step-by-step download instructions.
    """

    def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
        steps = (
            "1. Register and login at https://www.kaggle.com",
            f"2. Navigate to {challenge_url}",
            "3. Click 'Join Competition' and follow the instructions there",
            "4. Navigate to the 'Data' tab",
            f"5. Select {file_name} in the 'Data Explorer' and click the download button",
        )
        super().__init__("\n".join(steps), file_name=file_name, **kwargs)