|
| 1 | +import csv |
| 2 | +import pathlib |
| 3 | +from typing import Any, Callable, Optional, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +from PIL import Image |
| 7 | + |
| 8 | +from .utils import verify_str_arg, check_integrity |
| 9 | +from .vision import VisionDataset |
| 10 | + |
| 11 | + |
| 12 | +class FER2013(VisionDataset): |
| 13 | + """`FER2013 |
| 14 | + <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset. |
| 15 | +
|
| 16 | + Args: |
| 17 | + root (string): Root directory of dataset where directory |
| 18 | + ``root/fer2013`` exists. |
| 19 | + split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. |
| 20 | + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed |
| 21 | + version. E.g, ``transforms.RandomCrop`` |
| 22 | + target_transform (callable, optional): A function/transform that takes in the target and transforms it. |
| 23 | + """ |
| 24 | + |
| 25 | + _RESOURCES = { |
| 26 | + "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"), |
| 27 | + "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"), |
| 28 | + } |
| 29 | + |
| 30 | + def __init__( |
| 31 | + self, |
| 32 | + root: str, |
| 33 | + split: str = "train", |
| 34 | + transform: Optional[Callable] = None, |
| 35 | + target_transform: Optional[Callable] = None, |
| 36 | + ) -> None: |
| 37 | + self._split = verify_str_arg(split, "split", self._RESOURCES.keys()) |
| 38 | + super().__init__(root, transform=transform, target_transform=target_transform) |
| 39 | + |
| 40 | + base_folder = pathlib.Path(self.root) / "fer2013" |
| 41 | + file_name, md5 = self._RESOURCES[self._split] |
| 42 | + data_file = base_folder / file_name |
| 43 | + if not check_integrity(str(data_file), md5=md5): |
| 44 | + raise RuntimeError( |
| 45 | + f"{file_name} not found in {base_folder} or corrupted. " |
| 46 | + f"You can download it from " |
| 47 | + f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" |
| 48 | + ) |
| 49 | + |
| 50 | + with open(data_file, "r", newline="") as file: |
| 51 | + self._samples = [ |
| 52 | + ( |
| 53 | + torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48), |
| 54 | + int(row["emotion"]) if "emotion" in row else None, |
| 55 | + ) |
| 56 | + for row in csv.DictReader(file) |
| 57 | + ] |
| 58 | + |
| 59 | + def __len__(self) -> int: |
| 60 | + return len(self._samples) |
| 61 | + |
| 62 | + def __getitem__(self, idx: int) -> Tuple[Any, Any]: |
| 63 | + image_tensor, target = self._samples[idx] |
| 64 | + image = Image.fromarray(image_tensor.numpy()) |
| 65 | + |
| 66 | + if self.transform is not None: |
| 67 | + image = self.transform(image) |
| 68 | + |
| 69 | + if self.target_transform is not None: |
| 70 | + target = self.target_transform(target) |
| 71 | + |
| 72 | + return image, target |
| 73 | + |
| 74 | + def extra_repr(self) -> str: |
| 75 | + return f"split={self._split}" |
0 commit comments