diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 00000000..b698370f --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,180 @@ +# Datasets + +Datasets classes give you a way to automatically download a dataset and transform it into a PyTorch dataset. + +All implemented datasets have disjoint train-test splits, ideal for benchmarking on image retrieval and one-shot/few-shot classification tasks. + +## BaseDataset + +All dataset classes extend this class and therefore inherit its ```__init__``` parameters. + +```python +datasets.base_dataset.BaseDataset( + root, + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +**Parameters**: + +* **root**: The path where the dataset files are saved. +* **split**: A string that determines which split of the dataset is loaded. +* **transform**: A `torchvision.transforms` object which will be used on the input images. +* **target_transform**: A `torchvision.transforms` object which will be used on the labels. +* **download**: Whether to download the dataset or not. Setting this as False, but not having the dataset on the disk will raise a ValueError. + +**Required Implementations**: +```python + @abstractmethod + def download_and_remove(): + raise NotImplementedError + + @abstractmethod + def generate_split(): + raise NotImplementedError +``` + +## CUB-200-2011 + +```python +datasets.cub.CUB(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 5864 examples, taken from classes 1 to 100. +- `test` - Consists of 5924 examples, taken from classes 101 to 200. +- `train+test` - Consists 11788 of examples, taken from all classes. + +**Loading different dataset splits** +```python +train_dataset = CUB(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downladed +test_dataset = CUB(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = CUB(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +## Cars196 + +```python +datasets.cars196.Cars196(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 8054 examples, taken from classes 1 to 99. +- `test` - Consists of 8131 examples, taken from classes 99 to 197. +- `train+test` - Consists of 16185 examples, taken from all classes. + +**Loading different dataset splits** +```python +train_dataset = Cars196(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downladed +test_dataset = Cars196(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = Cars196(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +## INaturalist2018 + +```python +datasets.inaturalist2018.INaturalist2018(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 325 846 examples. +- `test` - Consists of 136 093 examples. +- `train+test` - Consists of 461 939 examples. + +**Loading different dataset splits** +```python +# The download takes a while - the dataset is very large +train_dataset = INaturalist2018(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downladed +test_dataset = INaturalist2018(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = INaturalist2018(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +## StanfordOnlineProducts + +```python +datasets.sop.StanfordOnlineProducts(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 59551 examples. +- `test` - Consists of 60502 examples. +- `train+test` - Consists of 120 053 examples. + +**Loading different dataset splits** +```python +# The download takes a while - the dataset is very large +train_dataset = StanfordOnlineProducts(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downladed +test_dataset = StanfordOnlineProducts(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = StanfordOnlineProducts(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` diff --git a/docs/extend/datasets.md b/docs/extend/datasets.md new file mode 100644 index 00000000..606fecab --- /dev/null +++ b/docs/extend/datasets.md @@ -0,0 +1,34 @@ +# How to write custom datasets + +1. Subclass the ```datasets.base_dataset.BaseDataset``` class +2. Add implementations for abstract methods from the base class: + - ```download_and_remove()``` + - ```generate_split()``` + + +```python +from pytorch_metric_learning.datasets.base_dataset import BaseDataset + +class MyDataset(BaseDataset): + + def __init__(self, my_parameter, *args, **kwargs): + super().__init__(*args, **kwargs) + self.my_parameter = self.my_parameter + + def download_and_remove(self): + # Downloads the dataset files needed + # + # If you're using a dataset that you've already downloaded elsewhere, + # just use an empty implementation + pass + + def generate_split(self): + # Creates a list of image paths, and saves them into self.paths + # Creates a list of labels for the images, and saves them into self.labels + # + # The default training splits that need to be covered are `train`, `test`, and `train+test` + # If you need a different split setup, override `get_available_splits(self)` to return + # the split names you want + pass + +``` diff --git a/mkdocs.yml b/mkdocs.yml index a05c6d4c..731fc84c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,7 @@ site_name: PyTorch Metric Learning nav: - Home: index.md + - Datasets: datasets.md - Distances: distances.md - Losses: losses.md - Miners: miners.md @@ -16,6 +17,7 @@ nav: - Common Functions: common_functions.md - Distributed: distributed.md - How to extend this library: + - Custom datasets: extend/datasets.md - Custom losses: extend/losses.md - Custom miners: extend/miners.md - Frequently Asked Questions: faq.md diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 2614ce9d..892994aa 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.7.0" +__version__ = "2.8.0" diff --git a/src/pytorch_metric_learning/datasets/base_dataset.py b/src/pytorch_metric_learning/datasets/base_dataset.py new file mode 100644 index 00000000..c9613b99 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/base_dataset.py @@ -0,0 +1,71 @@ +import os +from abc import ABC, abstractmethod + +from PIL import Image +from torch.utils.data import Dataset + + +class BaseDataset(ABC, Dataset): + + def __init__( + self, + root, + split="train+test", + transform=None, + target_transform=None, + download=False, + ): + self.root = root + + if download: + if not os.path.isdir(self.root): + os.makedirs(self.root, exist_ok=False) + self.download_and_remove() + elif os.listdir(self.root) == []: + self.download_and_remove() + else: + # The given directory does not exist so the user should be aware of downloading it + # Otherwise proceed as usual + if not os.path.isdir(self.root): + raise ValueError( + "The given path does not exist. " + "You should probably initialize the dataset with download=True." + ) + + self.transform = transform + self.target_transform = target_transform + + if split not in self.get_available_splits(): + raise ValueError( + f"Supported splits are: {', '.join(self.get_available_splits())}" + ) + + self.split = split + + self.generate_split() + + @abstractmethod + def generate_split(): + raise NotImplementedError + + @abstractmethod + def download_and_remove(): + raise NotImplementedError + + def get_available_splits(self): + return ["train", "test", "train+test"] + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + img = Image.open(self.paths[idx]) + label = self.labels[idx] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + label = self.target_transform(label) + + return (img, label) diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py new file mode 100644 index 00000000..9e71e6aa --- /dev/null +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -0,0 +1,67 @@ +import os +import zipfile + +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + +class Cars196(BaseDataset): + + DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder" + + def generate_split(self): + # Training set is first 99 classes, test is other classes + if self.split == "train": + classes = set(range(1, 99)) + elif self.split == "test": + classes = set(range(99, 197)) + else: + classes = set(range(1, 197)) + + with open(os.path.join(self.root, "names.csv"), "r") as f: + names = [x.strip() for x in f.readlines()] + + paths_train, labels_train = self._load_csv( + os.path.join(self.root, "anno_train.csv"), names, split="train" + ) + paths_test, labels_test = self._load_csv( + os.path.join(self.root, "anno_test.csv"), names, split="test" + ) + paths = paths_train + paths_test + labels = labels_train + labels_test + + self.paths, self.labels = [], [] + for p, l in zip(paths, labels): + if l in classes: + self.paths.append(p) + self.labels.append(l) + + def _load_csv(self, path, names, split): + all_paths, all_labels = [], [] + with open(path, "r") as f: + for l in f: + path_annos = l.split(",") + curr_path = path_annos[0] + curr_label = path_annos[-1] + all_paths.append( + os.path.join( + self.root, + "car_data", + "car_data", + split, + names[int(curr_label) - 1].replace("/", "-"), + curr_path, + ) + ) + all_labels.append(int(curr_label)) + return all_paths, all_labels + + def download_and_remove(self): + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join( + self.root, Cars196.DOWNLOAD_URL.split("/")[-1] + ) + _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path) + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py new file mode 100644 index 00000000..61f0d664 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -0,0 +1,49 @@ +import os +import tarfile + +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + +class CUB(BaseDataset): + + DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" + + def generate_split(self): + dir_name = CUB.DOWNLOAD_URL.split("/")[-1].replace(".tgz", "") + + # Training split is first 100 classes, other 100 is test + if self.split == "train": + classes = set(range(1, 101)) + elif self.split == "test": + classes = set(range(101, 201)) + else: + classes = set(range(1, 201)) + + # Find ids which correspond to the classes in the split + self.paths, self.labels = [], [] + with open(os.path.join(self.root, dir_name, "image_class_labels.txt")) as f1: + with open(os.path.join(self.root, dir_name, "images.txt")) as f2: + for l1, l2 in zip(f1, f2): + img_idx1, class_idx = list(map(int, l1.split())) + + if class_idx not in classes: + continue + + img_idx2, img_path = l2.split() + img_idx2 = int(img_idx2) + + # If the image ids correspond it's a match + if img_idx1 == img_idx2: + self.paths.append( + os.path.join(self.root, dir_name, "images", img_path) + ) + self.labels.append(class_idx) + + def download_and_remove(self): + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split("/")[-1]) + _urlretrieve(url=CUB.DOWNLOAD_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py new file mode 100644 index 00000000..478d8a20 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -0,0 +1,100 @@ +import json +import os +import tarfile +import zipfile + +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + +class INaturalist2018(BaseDataset): + + IMG_DOWNLOAD_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz" + TRAIN_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train2018.json.tar.gz" + VAL_ANN_URL = ( + "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz" + ) + + # Link from https://github.com/Andrew-Brown1/Smooth_AP?tab=readme-ov-file#data + SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98" + + def generate_split(self): + with open(os.path.join(self.root, "train2018.json"), "r") as train_f: + train_json = json.load(train_f) + + with open(os.path.join(self.root, "val2018.json"), "r") as val_f: + val_json = json.load(val_f) + + val_imgs, val_anns = val_json["images"], val_json["annotations"] + train_imgs, train_anns = train_json["images"], train_json["annotations"] + + imgs, anns = val_imgs + train_imgs, val_anns + train_anns + + path2id = {x["file_name"]: x["id"] for x in imgs} + id2label = {x["image_id"]: x["category_id"] for x in anns} + + if self.split in ["train", "test"]: + paths = self._load_split_txt(self.split) + ids = [path2id[p] for p in paths] + labels = [id2label[i] for i in ids] + + elif self.split == "train+test": + train_paths = self._load_split_txt("train") + train_ids = [path2id[p] for p in train_paths] + train_labels = [id2label[i] for i in train_ids] + + test_paths = self._load_split_txt("test") + test_ids = [path2id[p] for p in test_paths] + test_labels = [id2label[i] for i in test_ids] + + paths = train_paths + test_paths + labels = train_labels + test_labels + + self.paths = [os.path.join(self.root, p) for p in paths] + self.labels = labels + + def _load_split_txt(self, split): + paths = [] + with open( + os.path.join( + self.root, "Inat_dataset_splits", f"Inaturalist_{split}_set1.txt" + ) + ) as f: + for l in f: + paths.append(l.strip()) + return paths + + def download_and_remove(self): + download_folder_path = os.path.join( + self.root, INaturalist2018.IMG_DOWNLOAD_URL.split("/")[-1] + ) + _urlretrieve( + url=INaturalist2018.IMG_DOWNLOAD_URL, filename=download_folder_path + ) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join( + self.root, INaturalist2018.TRAIN_ANN_URL.split("/")[-1] + ) + _urlretrieve(url=INaturalist2018.TRAIN_ANN_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join( + self.root, INaturalist2018.VAL_ANN_URL.split("/")[-1] + ) + _urlretrieve(url=INaturalist2018.VAL_ANN_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join( + self.root, INaturalist2018.SPLITS_URL.split("/")[-1] + ) + _urlretrieve(url=INaturalist2018.SPLITS_URL, filename=download_folder_path) + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) diff --git a/src/pytorch_metric_learning/datasets/sop.py b/src/pytorch_metric_learning/datasets/sop.py new file mode 100644 index 00000000..cf298a54 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/sop.py @@ -0,0 +1,55 @@ +import os +import zipfile + +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + +class StanfordOnlineProducts(BaseDataset): + + # Link from https://github.com/rksltnl/Deep-Metric-Learning-CVPR16?tab=readme-ov-file#stanford-online-products-dataset + DOWNLOAD_URL = "https://drive.usercontent.google.com/download?id=1TclrpQOF_ullUP99wk_gjGN8pKvtErG8&export=download&authuser=0&confirm=t" + + def generate_split(self): + if self.split in ["train", "test"]: + paths, labels = self._load_split_txt(self.split) + elif self.split == "train+test": + train_paths, train_labels = self._load_split_txt("train") + test_paths, test_labels = self._load_split_txt("test") + paths = train_paths + test_paths + labels = train_labels + test_labels + + self.paths = paths + self.labels = labels + + def _load_split_txt(self, split): + paths, labels = [], [] + with open( + os.path.join(self.root, "Stanford_Online_Products", f"Ebay_{split}.txt") + ) as f: + for i, l in enumerate(f): + if i == 0: + continue + l_split = l.strip().split() + label, path = int(l_split[1]), l_split[3] + paths.append(os.path.join(self.root, "Stanford_Online_Products", path)) + labels.append(label) + return paths, labels + + def download_and_remove(self): + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join( + self.root, StanfordOnlineProducts.DOWNLOAD_URL.split("/")[-1] + ) + _urlretrieve( + url=StanfordOnlineProducts.DOWNLOAD_URL, filename=download_folder_path + ) + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) + + +# if __name__ == "__main__": +# train_dataset = StanfordOnlineProducts(root="data_sop", split="train", download=True) +# train_dataset = StanfordOnlineProducts(root="data_sop", split="test", download=True) +# train_dataset = StanfordOnlineProducts(root="data_sop", split="train+test", download=True) diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py index cb95ebf7..424f00a3 100644 --- a/src/pytorch_metric_learning/utils/common_functions.py +++ b/src/pytorch_metric_learning/utils/common_functions.py @@ -3,10 +3,12 @@ import logging import os import re +import urllib import numpy as np import scipy.stats import torch +from tqdm import tqdm LOGGER_NAME = "PML" LOGGER = logging.getLogger(LOGGER_NAME) @@ -14,6 +16,18 @@ COLLECT_STATS = False +# taken from: +# https://github.com/pytorch/vision/blob/main/torchvision/datasets/utils.py#L27 +def _urlretrieve(url, filename, chunk_size=1024 * 32): + with urllib.request.urlopen(urllib.request.Request(url)) as response: + with open(filename, "wb") as fh, tqdm( + total=response.length, unit="B", unit_scale=True + ) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) + + def set_logger_name(name): global LOGGER_NAME global LOGGER diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py new file mode 100644 index 00000000..43d1b4d9 --- /dev/null +++ b/tests/datasets/test_cars196.py @@ -0,0 +1,54 @@ +import os +import shutil +import unittest + +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from pytorch_metric_learning.datasets.cars196 import Cars196 + + +class TestCars196(unittest.TestCase): + + CARS_196_ROOT = "test_cars196" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.CARS_196_ROOT): + cls.ALREADY_EXISTS = True + + def test_Cars196(self): + train_test_data = Cars196( + root=TestCars196.CARS_196_ROOT, split="train+test", download=True + ) + train_data = Cars196( + root=TestCars196.CARS_196_ROOT, split="train", download=True + ) + test_data = Cars196( + root=TestCars196.CARS_196_ROOT, split="test", download=False + ) + + self.assertTrue(len(train_test_data) == 16185) + self.assertTrue(len(train_data) == 8054) + self.assertTrue(len(test_data) == 8131) + + def test_CARS_196_dataloader(self): + test_data = Cars196( + root=TestCars196.CARS_196_ROOT, + transform=transforms.Compose( + [transforms.Resize(size=(224, 224)), transforms.PILToTensor()] + ), + split="test", + download=True, + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.CARS_196_ROOT) diff --git a/tests/datasets/test_cub.py b/tests/datasets/test_cub.py new file mode 100644 index 00000000..625e97d2 --- /dev/null +++ b/tests/datasets/test_cub.py @@ -0,0 +1,48 @@ +import os +import shutil +import unittest + +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from pytorch_metric_learning.datasets.cub import CUB + + +class TestCUB(unittest.TestCase): + + CUB_ROOT = "test_cub" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.CUB_ROOT): + cls.ALREADY_EXISTS = True + + def test_CUB(self): + train_test_data = CUB(root=TestCUB.CUB_ROOT, split="train+test", download=True) + train_data = CUB(root=TestCUB.CUB_ROOT, split="train", download=True) + test_data = CUB(root=TestCUB.CUB_ROOT, split="test", download=False) + + self.assertTrue(len(train_test_data) == 11788) + self.assertTrue(len(train_data) == 5864) + self.assertTrue(len(test_data) == 5924) + + def test_CUB_dataloader(self): + test_data = CUB( + root=TestCUB.CUB_ROOT, + transform=transforms.Compose( + [transforms.Resize(size=(224, 224)), transforms.PILToTensor()] + ), + split="test", + download=True, + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.CUB_ROOT) diff --git a/tests/datasets/test_inaturalist2018.py b/tests/datasets/test_inaturalist2018.py new file mode 100644 index 00000000..123dfd64 --- /dev/null +++ b/tests/datasets/test_inaturalist2018.py @@ -0,0 +1,56 @@ +import os +import shutil +import unittest + +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018 + + +class TestINaturalist2018(unittest.TestCase): + + INATURALIST2018_ROOT = "data" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.INATURALIST2018_ROOT): + cls.ALREADY_EXISTS = True + + def test_INaturalist2018(self): + train_test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, + split="train+test", + download=True, + ) + train_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, split="train", download=True + ) + test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, split="test", download=False + ) + + self.assertTrue(len(train_test_data) == 461939) + self.assertTrue(len(train_data) == 325846) + self.assertTrue(len(test_data) == 136093) + + def test_INaturalist2018_dataloader(self): + test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, + transform=transforms.Compose( + [transforms.Resize(size=(224, 224)), transforms.PILToTensor()] + ), + split="test", + download=True, + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.INATURALIST2018_ROOT) diff --git a/tests/datasets/test_sop.py b/tests/datasets/test_sop.py new file mode 100644 index 00000000..77142251 --- /dev/null +++ b/tests/datasets/test_sop.py @@ -0,0 +1,54 @@ +import os +import shutil +import unittest + +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts + + +class TestStanfordOnlineProducts(unittest.TestCase): + + SOP_ROOT = "test_sop" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.SOP_ROOT): + cls.ALREADY_EXISTS = True + + def test_SOP(self): + train_test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True + ) + train_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="train", download=True + ) + test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="test", download=False + ) + + self.assertTrue(len(train_test_data) == 120053) + self.assertTrue(len(train_data) == 59551) + self.assertTrue(len(test_data) == 60502) + + def test_SOP_dataloader(self): + test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, + transform=transforms.Compose( + [transforms.Resize(size=(224, 224)), transforms.PILToTensor()] + ), + split="test", + download=True, + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.SOP_ROOT)