Skip to content

Commit 2373376

Browse files
prabhat00155 and NicolasHug
authored and committed
[fbsync] FER2013 dataset (#5120)
Summary: * add prototype dataset * add old style dataset * Apply suggestions from code review * refactor integrity check Reviewed By: sallysyw Differential Revision: D33479268 fbshipit-source-id: 239a0efb550b21ce0c39ed94caf7436d89813f65 Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent fc48f9e commit 2373376

File tree

8 files changed

+208
-1
lines changed

8 files changed

+208
-1
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4242
EMNIST
4343
FakeData
4444
FashionMNIST
45+
FER2013
4546
Flickr8k
4647
Flickr30k
4748
FlyingChairs

test/test_datasets.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import bz2
22
import contextlib
3+
import csv
34
import io
45
import itertools
56
import json
@@ -2241,5 +2242,38 @@ def inject_fake_data(self, tmpdir: str, config):
22412242
return len(image_ids_in_config)
22422243

22432244

2245+
class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test case for ``datasets.FER2013``."""

    DATASET_CLASS = datasets.FER2013
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

    def inject_fake_data(self, tmpdir, config):
        # Lay out root/fer2013/{split}.csv the way the real dataset expects it.
        base_folder = os.path.join(tmpdir, "fer2013")
        os.makedirs(base_folder)

        split = config["split"]
        is_train = split == "train"
        # The test split ships without labels, so it has no "emotion" column.
        fieldnames = ("emotion", "pixels") if is_train else ("pixels",)

        num_samples = 5
        csv_path = os.path.join(base_folder, f"{split}.csv")
        with open(csv_path, "w", newline="") as csv_file:
            csv_writer = csv.DictWriter(
                csv_file,
                fieldnames=fieldnames,
                quoting=csv.QUOTE_NONNUMERIC,
                quotechar='"',
            )
            csv_writer.writeheader()
            for _ in range(num_samples):
                # A 48x48 fake image flattened into a space-separated pixel string.
                pixels = datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
                row = {"pixels": " ".join(str(pixel) for pixel in pixels)}
                if is_train:
                    row["emotion"] = str(int(torch.randint(0, 7, ())))
                csv_writer.writerow(row)

        return num_samples
2276+
2277+
22442278
if __name__ == "__main__":
22452279
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .coco import CocoCaptions, CocoDetection
77
from .dtd import DTD
88
from .fakedata import FakeData
9+
from .fer2013 import FER2013
910
from .flickr import Flickr8k, Flickr30k
1011
from .folder import ImageFolder, DatasetFolder
1112
from .food101 import Food101
@@ -81,4 +82,5 @@
8182
"HD1K",
8283
"Food101",
8384
"DTD",
85+
"FER2013",
8486
)

torchvision/datasets/fer2013.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import csv
2+
import pathlib
3+
from typing import Any, Callable, Optional, Tuple
4+
5+
import torch
6+
from PIL import Image
7+
8+
from .utils import verify_str_arg, check_integrity
9+
from .vision import VisionDataset
10+
11+
12+
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``root/fer2013`` exists.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # (file name, md5) per split, used for the integrity check in __init__.
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate the split before doing any other work so a bad value fails fast.
        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
        super().__init__(root, transform=transform, target_transform=target_transform)

        base_folder = pathlib.Path(self.root) / "fer2013"
        file_name, md5 = self._RESOURCES[self._split]
        data_file = base_folder / file_name
        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # Each CSV row stores a 48x48 grayscale image as a space-separated pixel
        # string; the test split has no "emotion" column, so its target is None.
        # An explicit encoding is used so parsing does not depend on the
        # platform's locale default.
        with open(data_file, "r", newline="", encoding="utf-8") as file:
            self._samples = [
                (
                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
                    int(row["emotion"]) if "emotion" in row else None,
                )
                for row in csv.DictReader(file)
            ]

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_tensor, target = self._samples[idx]
        # Rebuild a PIL image from the stored uint8 tensor.
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        return f"split={self._split}"

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .cifar import Cifar10, Cifar100
44
from .coco import Coco
55
from .dtd import DTD
6+
from .fer2013 import FER2013
67
from .imagenet import ImageNet
78
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
89
from .sbd import SBD
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import functools
2+
import io
3+
from typing import Any, Callable, Dict, List, Optional, Union, cast
4+
5+
import torch
6+
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
7+
from torchvision.prototype.datasets.decoder import raw
8+
from torchvision.prototype.datasets.utils import (
9+
Dataset,
10+
DatasetConfig,
11+
DatasetInfo,
12+
OnlineResource,
13+
DatasetType,
14+
KaggleDownloadResource,
15+
)
16+
from torchvision.prototype.datasets.utils._internal import (
17+
hint_sharding,
18+
hint_shuffling,
19+
image_buffer_from_array,
20+
)
21+
from torchvision.prototype.features import Label, Image
22+
23+
24+
class FER2013(Dataset):
    """Prototype-API implementation of the FER2013 facial-expression dataset."""

    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "fer2013",
            type=DatasetType.RAW,
            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
            valid_options=dict(split=("train", "test")),
        )

    # Per-split sha256 of the downloaded archive.
    _CHECKSUMS = {
        "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
        "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        # FER2013 is only distributed through Kaggle, so the archive has to be
        # downloaded manually by the user.
        return [
            KaggleDownloadResource(
                cast(str, self.info.homepage),
                file_name=f"{config.split}.csv.zip",
                sha256=self._CHECKSUMS[config.split],
            )
        ]

    def _collate_and_decode_sample(
        self,
        data: Dict[str, Any],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        # The "pixels" field is a space-separated string of 48*48 uint8 values.
        pixel_values = [int(idx) for idx in data["pixels"].split()]
        raw_image = torch.tensor(pixel_values, dtype=torch.uint8).reshape(48, 48)

        # The test split carries no "emotion" field.
        label_id = data.get("emotion")
        label_idx = None if label_id is None else int(label_id)

        image: Union[Image, io.BytesIO]
        if decoder is raw:
            image = Image(raw_image)
        else:
            image_buffer = image_buffer_from_array(raw_image.numpy())
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

        if label_idx is None:
            label = None
        else:
            label = Label(label_idx, category=self.info.categories[label_idx])
        return dict(image=image, label=label)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        csv_dp = CSVDictParser(resource_dps[0])
        csv_dp = hint_sharding(csv_dp)
        csv_dp = hint_shuffling(csv_dp)
        return Mapper(csv_dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from . import _internal
22
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
33
from ._query import SampleQuery
4-
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
4+
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,17 @@ def _download(self, root: pathlib.Path) -> NoReturn:
176176
f"Please follow the instructions below and place it in {root}\n\n"
177177
f"{self.instructions}"
178178
)
179+
180+
181+
class KaggleDownloadResource(ManualDownloadResource):
    """Resource hosted on Kaggle, which only permits manual, logged-in downloads."""

    def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
        # Step-by-step manual download instructions shown to the user when the
        # file is missing.
        steps = [
            "1. Register and login at https://www.kaggle.com",
            f"2. Navigate to {challenge_url}",
            "3. Click 'Join Competition' and follow the instructions there",
            "4. Navigate to the 'Data' tab",
            f"5. Select {file_name} in the 'Data Explorer' and click the download button",
        ]
        super().__init__("\n".join(steps), file_name=file_name, **kwargs)

0 commit comments

Comments
 (0)