diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 7f09ff245ca..3a2872a6388 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes `_. + + The Food-101 is a challenging data set of 101 food categories, with 101'000 images. + For each class, 250 manually reviewed test images are provided as well as 750 training images. + On purpose, the training images were not cleaned, and thus still contain some amount of noise. + This comes mostly in the form of intense colors and sometimes wrong labels. All images were + rescaled to have a maximum side length of 512 pixels. + + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" + _MD5 = "85eeb15f3717b99a5da872d97d918f87" + + def __init__( + self, + root: str, + split: str = "train", + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = Path(self.root) / "food-101" + self._meta_folder = self._base_folder / "meta" + self._images_folder = self._base_folder / "images" + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self._labels = [] + self._image_files = [] + with open(self._meta_folder / f"{split}.json", "r") as f: + metadata = json.loads(f.read()) + + self.classes = sorted(metadata.keys()) + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + + for class_label, im_rel_paths in metadata.items(): + self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths) + self._image_files += [ + self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths + ] + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def extra_repr(self) -> str: + return f"split={self._split}" + + def _check_exists(self) -> bool: + return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder)) + + def _download(self) -> None: + if self._check_exists(): + return + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)