From e8a8466ad53559e69295d4f90f879f8b62abd658 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sun, 20 Oct 2024 19:08:46 +0200 Subject: [PATCH 01/20] add cub --- src/pytorch_metric_learning/datasets/cub.py | 77 +++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 src/pytorch_metric_learning/datasets/cub.py diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py new file mode 100644 index 00000000..0f303a76 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -0,0 +1,77 @@ +from PIL import Image +from torch.utils.data import Dataset +import os + +class CUB(Dataset): + + SPLITS = ["train", "test", "train+test"] + DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" + + def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False): + self.root = root + + if download and not os.path.isdir(self.root): + archive_name = CUB.DOWNLOAD_URL.split('/')[-1] + os.makedirs(self.root, exist_ok=False) + os.system(f"wget -P {self.root} {CUB.DOWNLOAD_URL}") + os.system(f"cd {self.root} && tar -xzvf {archive_name}") + os.system(f"rm {os.path.join(self.root, archive_name)}") + else: + # The given directory does not exist so the user should be aware of downloading it + # Otherwise proceed as usual + if not os.path.isdir(self.root): + raise ValueError("The given path does not exist. " + "You should probably initialize the dataset with download=True." + ) + + self.transform = transform + self.target_transform = target_transform + + if split not in CUB.SPLITS: + raise ValueError(f"Supported splits are: {', '.join(CUB.SPLITS)}") + + self.split = split + + dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "") + + # Training split is first 100 classes, other 100 is test + if self.split == "train": + classes = set(range(1, 101)) + elif self.split == "test": + classes = set(range(101, 201)) + else: + classes = set(range(1, 201)) + + # Find ids which correspond to the classes in the split + self.paths, self.labels = [], [] + with open(os.path.join(self.root, dir_name, "image_class_labels.txt")) as f1: + with open(os.path.join(self.root, dir_name, "images.txt")) as f2: + for l1, l2 in zip(f1, f2): + img_idx1, class_idx = list(map(int, l1.split())) + img_idx2, img_path = l2.split() + img_idx2 = int(img_idx2) + + # If the image ids correspond it's a match + if img_idx1 == img_idx2: + self.paths.append(img_path) + self.labels.append(class_idx) + + assert len(self.paths) == len(self.labels) == 11788 + + # Normalize labels to start from 0 + self.labels = [x - min(self.labels) for x in self.labels] + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + img = Image.open(self.paths[idx]) + label = self.labels[idx] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + label = self.target_transform(label) + + return (img, label) From 98b58973d16f7f615114be3f55dbd841bedaca66 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sun, 20 Oct 2024 19:52:10 +0200 Subject: [PATCH 02/20] convert to base and cub --- .../datasets/base_dataset.py | 58 +++++++++++++++++ src/pytorch_metric_learning/datasets/cub.py | 63 ++++++------------- 2 files changed, 78 insertions(+), 43 deletions(-) create mode 100644 src/pytorch_metric_learning/datasets/base_dataset.py diff --git a/src/pytorch_metric_learning/datasets/base_dataset.py b/src/pytorch_metric_learning/datasets/base_dataset.py new file mode 100644 index 
00000000..5a4846a4 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/base_dataset.py @@ -0,0 +1,58 @@ +from PIL import Image +from torch.utils.data import Dataset +import os +from abc import ABC, abstractmethod + +class BaseDataset(ABC, Dataset): + + def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False): + self.root = root + + if download and not os.path.isdir(self.root): + os.makedirs(self.root, exist_ok=False) + self.download_and_remove() + else: + # The given directory does not exist so the user should be aware of downloading it + # Otherwise proceed as usual + if not os.path.isdir(self.root): + raise ValueError("The given path does not exist. " + "You should probably initialize the dataset with download=True." + ) + + self.transform = transform + self.target_transform = target_transform + + if split not in self.get_available_splits(): + raise ValueError(f"Supported splits are: {', '.join(self.get_available_splits())}") + + self.split = split + + @staticmethod + @abstractmethod + def download_and_remove(): + pass + + @staticmethod + @abstractmethod + def get_available_splits(): + pass + + @staticmethod + @abstractmethod + def get_download_url(): + pass + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + img = Image.open(self.paths[idx]) + label = self.labels[idx] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + label = self.target_transform(label) + + return (img, label) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py index 0f303a76..b1ba987f 100644 --- a/src/pytorch_metric_learning/datasets/cub.py +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -1,38 +1,12 @@ -from PIL import Image -from torch.utils.data import Dataset +from ..datasets.base_dataset import BaseDataset import os -class CUB(Dataset): +class CUB(BaseDataset): - SPLITS = ["train", "test", "train+test"] - DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False): - self.root = root - - if download and not os.path.isdir(self.root): - archive_name = CUB.DOWNLOAD_URL.split('/')[-1] - os.makedirs(self.root, exist_ok=False) - os.system(f"wget -P {self.root} {CUB.DOWNLOAD_URL}") - os.system(f"cd {self.root} && tar -xzvf {archive_name}") - os.system(f"rm {os.path.join(self.root, archive_name)}") - else: - # The given directory does not exist so the user should be aware of downloading it - # Otherwise proceed as usual - if not os.path.isdir(self.root): - raise ValueError("The given path does not exist. " - "You should probably initialize the dataset with download=True." 
- ) - - self.transform = transform - self.target_transform = target_transform - - if split not in CUB.SPLITS: - raise ValueError(f"Supported splits are: {', '.join(CUB.SPLITS)}") - - self.split = split - - dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "") + dir_name = self.get_download_url().split('/')[-1].replace(".tgz", "") # Training split is first 100 classes, other 100 is test if self.split == "train": @@ -48,6 +22,10 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No with open(os.path.join(self.root, dir_name, "images.txt")) as f2: for l1, l2 in zip(f1, f2): img_idx1, class_idx = list(map(int, l1.split())) + + if class_idx not in classes: + continue + img_idx2, img_path = l2.split() img_idx2 = int(img_idx2) @@ -61,17 +39,16 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No self.paths.append(img_path) self.labels.append(class_idx) - assert len(self.paths) == len(self.labels) == 11788 - # Normalize labels to start from 0 self.labels = [x - min(self.labels) for x in self.labels] - def __len__(self): - return len(self.labels) - - def __getitem__(self, idx): - img = Image.open(self.paths[idx]) - label = self.labels[idx] - - if self.transform is not None: - img = self.transform(img) + def download_and_remove(self): + archive_name = self.get_download_url().split('/')[-1] + os.system(f"wget -P {self.root} {self.get_download_url()}") + os.system(f"cd {self.root} && tar -xzvf {archive_name}") + os.system(f"rm {os.path.join(self.root, archive_name)}") - if self.target_transform is not None: - label = self.target_transform(label) + @staticmethod + def get_available_splits(): + return ["train", "test", "train+test"] - return (img, label) + @staticmethod + def get_download_url(): + return "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" \ No newline at end of file From 4939740dde740883040877ab02e2848e530440d5 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sun, 20 Oct 2024 20:37:47 +0200 Subject: [PATCH 03/20] add cars with disjoint split --- .../datasets/cars196.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/pytorch_metric_learning/datasets/cars196.py diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py new file mode 100644 index 00000000..b0c6696f --- /dev/null +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -0,0 +1,57 @@ +from ..datasets.base_dataset import BaseDataset +import os + +class Cars196(BaseDataset): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Training split is classes 1-98, test split is classes 99-196 + if self.split == "train": + classes = set(range(1, 99)) + elif self.split == "test": + classes = set(range(99, 197)) + else: + classes = set(range(1, 197)) + + paths_train, labels_train = self._load_csv( + os.path.join(self.root, "anno_train.csv"), split="train" + ) + paths_test, labels_test = self._load_csv( + os.path.join(self.root, "anno_test.csv"), split="test" + ) + paths = paths_train + paths_test + labels = labels_train + labels_test + + self.paths, self.labels = [], [] + for p, l in zip(paths, labels): + if l in classes: + self.paths.append(p) + self.labels.append(l) + + def _load_csv(self, path, split): + all_paths, all_labels = [], [] + with open(path, "r") as f: + for l in f: + path_annos = l.split(",") + curr_path = path_annos[0] + curr_label = path_annos[-1] + all_paths.append( + os.path.join(self.root, "car_data", "car_data", split, curr_path) + ) + all_labels.append(int(curr_label)) + return all_paths, all_labels + + def download_and_remove(self): +
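# Download the Kaggle archive into self.root, unzip it, then delete the archive to free disk space +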
archive_name = self.get_download_url().split('/')[-1] + os.system(f"wget -P {self.root} {self.get_download_url()}") + os.system(f"cd {self.root} && unzip {archive_name}") + os.system(f"rm {os.path.join(self.root, archive_name)}") + + @staticmethod + def get_available_splits(): + return ["train", "test", "train+test"] + + @staticmethod + def get_download_url(): + return "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder" \ No newline at end of file From 3f14aae5f40dd255da9c7e1486674ce90c4706a7 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 00:28:32 +0200 Subject: [PATCH 04/20] added datasets docs page --- docs/datasets.md | 109 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 docs/datasets.md diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 00000000..17d89e02 --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,109 @@ +# Datasets + +Dataset classes give you a way to automatically download a dataset and transform it into a PyTorch dataset. + +## BaseDataset + +All dataset classes extend this class and therefore inherit its `__init__` parameters. + +```python +datasets.base_dataset.BaseDataset( + root, + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +**Parameters**: + +* **root**: The path where the dataset files are saved. +* **split**: A string that determines which split of the dataset is loaded. +* **transform**: A `torchvision.transforms` object which will be used on the input images. +* **target_transform**: A `torchvision.transforms` object which will be used on the labels. +* **download**: Whether to download the dataset or not. Setting this to False when the dataset is not on disk will raise a ValueError. + + +**Required Implementations**: ```python + @staticmethod + @abstractmethod + def download_and_remove(): + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_available_splits(): + raise NotImplementedError +``` + +## CUB-200-2011 + +```python +datasets.cub.CUB(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of examples from classes 1 to 100. +- `test` - Consists of examples from classes 101 to 200. +- `train+test` - Consists of examples from all classes. + +**Loading different dataset splits** +```python +train_dataset = CUB(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downloaded +test_dataset = CUB(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = CUB(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` + +## Cars196 + +```python +datasets.cars196.Cars196(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of examples from classes 1 to 98. +- `test` - Consists of examples from classes 99 to 196. +- `train+test` - Consists of examples from all classes.
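+
+As a quick sanity check (a minimal sketch, assuming `Cars196` is imported from `pytorch_metric_learning.datasets.cars196` and relying on the `labels` list that each dataset populates), the train and test splits should share no classes:
+
+```python
+from pytorch_metric_learning.datasets.cars196 import Cars196
+
+train = Cars196(root="data", split="train", download=True)
+test = Cars196(root="data", split="test", download=False)
+# Disjoint splits: no class label appears in both train and test
+assert set(train.labels).isdisjoint(test.labels)
+```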
+ +**Loading different dataset splits** +```python +train_dataset = Cars196(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downloaded +test_dataset = Cars196(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = Cars196(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) From c18aeb0e7fab5d3d4769d126091771331313a62e Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 00:36:57 +0200 Subject: [PATCH 05/20] add info on creating custom dataset --- docs/datasets.md | 4 ++-- docs/extend/datasets.md | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 docs/extend/datasets.md diff --git a/docs/datasets.md b/docs/datasets.md index 17d89e02..f0886d11 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -4,7 +4,7 @@ Dataset classes give you a way to automatically download a dataset and transfor -All dataset classes extend this class and therefore inherit its `__init__` parameters. +All dataset classes extend this class and therefore inherit its ```__init__``` parameters. @@ -106,4 +106,4 @@ train_and_test_dataset = Cars196(root="data", target_transform=None, download=False ) - +``` diff --git a/docs/extend/datasets.md b/docs/extend/datasets.md new file mode 100644 index 00000000..7ddd9548 --- /dev/null +++ b/docs/extend/datasets.md @@ -0,0 +1,31 @@ +# How to write custom datasets + +1. Subclass the ```datasets.base_dataset.BaseDatset``` class +2. Add implementations for abstract static methods from the base class: + - ```download_and_remove()``` + - ```get_available_splits()``` + + +```python +from pytorch_metric_learning.datasets.base_dataset import BaseDataset + +class MyDataset(BaseDataset): + + def __init__(self, my_parameter, *args, **kwargs): + super().__init__(*args, **kwargs) + self.my_parameter = self.parameter + + @staticmethod + def download_and_remove(): + # Downloads the dataset files needed + # + # If you're using a dataset that you've already downloaded elsewhere, + # just use an empty implementation + pass + + @staticmethod + def get_available_splits(): + # Returns the string names of the available splits + return ["my_split1", "my_split2"] + +``` From cdffe5022d1f8d9398a759bdaf89f55aa4b634f0 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 04:18:46 +0200 Subject: [PATCH 06/20] refactor --- .../datasets/base_dataset.py | 31 +++++++++-------- .../datasets/cars196.py | 29 ++++++++-------- src/pytorch_metric_learning/datasets/cub.py | 33 +++++++++---------- 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/pytorch_metric_learning/datasets/base_dataset.py b/src/pytorch_metric_learning/datasets/base_dataset.py index 5a4846a4..d91eb4c0 100644 --- a/src/pytorch_metric_learning/datasets/base_dataset.py +++ b/src/pytorch_metric_learning/datasets/base_dataset.py @@ -8,9 +8,14 @@ class BaseDataset(ABC, Dataset): def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False): self.root = root - if download and not os.path.isdir(self.root): - os.makedirs(self.root, exist_ok=False) - self.download_and_remove() + if download: + if not os.path.isdir(self.root): + os.makedirs(self.root, exist_ok=False) + self.download_and_remove() + elif os.listdir(self.root) == []: + self.download_and_remove() + else: +
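# download=True was passed but self.root already exists and is not empty, so refuse to overwrite it +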
raise ValueError("The given directory exists and is not empty.") else: # The given directory does not exist so the user should be aware of downloading it # Otherwise proceed as usual @@ -27,20 +32,18 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No self.split = split - @staticmethod - @abstractmethod - def download_and_remove(): - pass + self.generate_split() - @staticmethod @abstractmethod - def get_available_splits(): - pass - - @staticmethod + def generate_split(): + raise NotImplementedError + @abstractmethod - def get_download_url(): - pass + def download_and_remove(): + raise NotImplementedError + + def get_available_splits(self): + return ["train", "test", "train+test"] def __len__(self): return len(self.labels) diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py index b0c6696f..5123b799 100644 --- a/src/pytorch_metric_learning/datasets/cars196.py +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -1,11 +1,13 @@ from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve import os +import zipfile class Cars196(BaseDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder" + def generate_split(self): # Training set is first 99 classes, test is other classes if self.split == "train": classes = set(range(1, 99)) @@ -43,15 +45,14 @@ def _load_csv(self, path, split): return all_paths, all_labels def download_and_remove(self): - archive_name = self.get_download_url().split('/')[-1] - os.system(f"wget -P {self.root} {self.get_download_url()}") - os.system(f"cd {self.root} && unzip {archive_name}") - os.system(f"rm {os.path.join(self.root, archive_name)}") - - @staticmethod - def get_available_splits(): - return ["train", "test", "train+test"] - - @staticmethod - def get_download_url(): - return "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder" \ No newline at end of file + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join(self.root, Cars196.DOWNLOAD_URL.split('/')[-1]) + _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path) + with zipfile.ZipFile(download_folder_path, 'r') as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) + +# if __name__ == "__main__": +# train_dataset = Cars196(root="data_cars", split="train+test", download=True) +# test_dataset = Cars196(root="data_cars", split="test", download=True) +# print(len(train_dataset), len(test_dataset)) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py index b1ba987f..708c3a93 100644 --- a/src/pytorch_metric_learning/datasets/cub.py +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -1,12 +1,14 @@ from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve import os +import tarfile class CUB(BaseDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" - dir_name = self.get_download_url().split('/')[-1].replace(".tgz", "") + def generate_split(self): + dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "") # Training split is first 100 classes, other 100 is test if self.split == "train": @@ -34,21 +36,18 @@ def 
__init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" - dir_name = self.get_download_url().split('/')[-1].replace(".tgz", "") + def generate_split(self): + dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "") # Training split is first 100 classes, other 100 is test if self.split == "train": @@ -34,9 +36,6 @@ def generate_split(self): self.paths.append(img_path) self.labels.append(class_idx) - assert len(self.paths) == len(self.labels) == 11788 - # Normalize labels to start from 0 self.labels = [x - min(self.labels) for x in self.labels] def download_and_remove(self): - archive_name = self.get_download_url().split('/')[-1] - os.system(f"wget -P {self.root} {self.get_download_url()}") - os.system(f"cd {self.root} && tar -xzvf {archive_name}") - os.system(f"rm {os.path.join(self.root, archive_name)}") - - @staticmethod - def get_available_splits(): - return ["train", "test", "train+test"] + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split('/')[-1]) + _urlretrieve(url=CUB.DOWNLOAD_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) - @staticmethod - def get_download_url(): - return "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" \ No newline at end of file + + +# if __name__ == "__main__": +# train_dataset = CUB(root="data_cub", split="test", download=True) +# print(len(train_dataset)) \ No newline at end of file From 8808fb2ba67d5a3d6e1b4d5fb609dc9916444aae Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 04:19:36 +0200 Subject: [PATCH 07/20] add pretty download function --- src/pytorch_metric_learning/utils/common_functions.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py index cb95ebf7..3b07803f 100644 --- a/src/pytorch_metric_learning/utils/common_functions.py +++ b/src/pytorch_metric_learning/utils/common_functions.py @@ -3,16 +3,26 @@ import logging import os import re +import urllib.request import numpy as np import scipy.stats import torch +from tqdm import tqdm LOGGER_NAME = "PML" LOGGER = logging.getLogger(LOGGER_NAME) NUMPY_RANDOM = np.random COLLECT_STATS = False +# taken from: +# https://github.com/pytorch/vision/blob/main/torchvision/datasets/utils.py#L27 +def _urlretrieve(url, filename, chunk_size=1024 * 32): + with urllib.request.urlopen(urllib.request.Request(url)) as response: + with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) def set_logger_name(name): global LOGGER_NAME From 324622d7fb6ea62bf941745eaf1c6072703583e5 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 04:19:55 +0200 Subject: [PATCH 08/20] add inaturalist --- .../datasets/inaturalist2018.py | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/pytorch_metric_learning/datasets/inaturalist2018.py diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py new file mode 100644 index 00000000..9a56b4fb --- /dev/null +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -0,0 +1,81 @@ +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve +import os +import tarfile +import zipfile +import json + +class INaturalist2018(BaseDataset): + + IMG_DOWNLOAD_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz" + TRAIN_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train2018.json.tar.gz" + VAL_ANN_URL =
"https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz" + SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98" + + def generate_split(self): + train_json = json.load(open(os.path.join(self.root, "train2018.json"))) + val_json = json.load(open(os.path.join(self.root, "val2018.json"))) + + val_imgs, val_anns = val_json["images"], val_json["annotations"] + train_imgs, train_anns = train_json["images"], train_json["annotations"] + + imgs, anns = val_imgs + train_imgs, val_anns + train_anns + + path2id = {x["file_name"]:x["id"] for x in imgs} + id2label = {x["image_id"]:x["category_id"] for x in anns} + + if self.split in ["train", "test"]: + paths = self._load_split_txt(self.split) + ids = [path2id[p] for p in paths] + labels = [id2label[i] for i in ids] + + elif self.split == "train+test": + train_paths = self._load_split_txt("train") + train_ids = [path2id[p] for p in train_paths] + train_labels = [id2label[i] for i in train_ids] + + test_paths = self._load_split_txt("test") + test_ids = [path2id[p] for p in test_paths] + test_labels = [id2label[i] for i in test_ids] + + paths = train_paths + test_paths + labels = train_labels + test_labels + + self.paths = paths + self.labels = labels + + def _load_split_txt(self, split): + paths = [] + with open(os.path.join(self.root, "Inat_dataset_splits", f"Inaturalist_{split}_set1.txt")) as f: + for l in f: + paths.append(l.strip()) + return paths + + def download_and_remove(self): + download_folder_path = os.path.join(self.root, INaturalist2018.IMG_DOWNLOAD_URL.split('/')[-1]) + _urlretrieve(url=INaturalist2018.IMG_DOWNLOAD_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join(self.root, INaturalist2018.TRAIN_ANN_URL.split('/')[-1]) + _urlretrieve(url=INaturalist2018.TRAIN_ANN_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join(self.root, INaturalist2018.VAL_ANN_URL.split('/')[-1]) + _urlretrieve(url=INaturalist2018.VAL_ANN_URL, filename=download_folder_path) + with tarfile.open(download_folder_path, "r:gz") as tar: + tar.extractall(self.root) + os.remove(download_folder_path) + + download_folder_path = os.path.join(self.root, INaturalist2018.SPLITS_URL.split('/')[-1]) + _urlretrieve(url=INaturalist2018.SPLITS_URL, filename=download_folder_path) + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) + + +# if __name__ == "__main__": +# train_dataset = INaturalist2018(root="data", split="train+test", download=False) \ No newline at end of file From 0c713603404bdbb3c5fd7c9f1f55fc3ab6d28e59 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 04:37:22 +0200 Subject: [PATCH 09/20] update docs --- docs/datasets.md | 55 +++++++++++++++---- docs/extend/datasets.md | 23 ++++---- .../datasets/cars196.py | 7 ++- .../datasets/inaturalist2018.py | 8 ++- 4 files changed, 67 insertions(+), 26 deletions(-) diff --git a/docs/datasets.md b/docs/datasets.md index f0886d11..31a22559 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -2,6 +2,8 @@ Datasets classes give you a way to automatically download a dataset and transform it into a PyTorch dataset. 
+All implemented datasets have disjoint train-test splits, ideal for benchmarking on image retrieval and one-shot/few-shot classification tasks. + ## BaseDataset All dataset classes extend this class and therefore inherit its ```__init__``` parameters. @@ -24,17 +26,14 @@ datasets.base_dataset.BaseDataset( * **target_transform**: A `torchvision.transforms` object which will be used on the labels. * **download**: Whether to download the dataset or not. Setting this to False when the dataset is not on disk will raise a ValueError. - **Required Implementations**: ```python - @staticmethod @abstractmethod - def download_and_remove(): + def download_and_remove(self): raise NotImplementedError - @staticmethod @abstractmethod - def get_available_splits(): + def generate_split(self): raise NotImplementedError ``` @@ -46,9 +45,9 @@ datasets.cub.CUB(*args, **kwargs) **Defined splits**: -- `train` - Consists of examples from classes 1 to 100. -- `test` - Consists of examples from classes 101 to 200. -- `train+test` - Consists of examples from all classes. +- `train` - Consists of 5864 examples, taken from classes 1 to 100. +- `test` - Consists of 5924 examples, taken from classes 101 to 200. +- `train+test` - Consists of 11788 examples, taken from all classes. **Loading different dataset splits** ```python @@ -81,9 +80,9 @@ datasets.cars196.Cars196(*args, **kwargs) **Defined splits**: -- `train` - Consists of examples from classes 1 to 98. -- `test` - Consists of examples from classes 99 to 196. -- `train+test` - Consists of examples from all classes. +- `train` - Consists of 8054 examples, taken from classes 1 to 98. +- `test` - Consists of 8131 examples, taken from classes 99 to 196. +- `train+test` - Consists of 16185 examples, taken from all classes. **Loading different dataset splits** ```python @@ -107,3 +106,39 @@ train_and_test_dataset = Cars196(root="data", download=False ) ``` + +## INaturalist2018 + +```python +datasets.inaturalist2018.INaturalist2018(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 325846 examples. +- `test` - Consists of 136093 examples. +- `train+test` - Consists of 461939 examples. + +**Loading different dataset splits** +```python +# The download takes a while - the dataset is very large +train_dataset = INaturalist2018(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downloaded +test_dataset = INaturalist2018(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = INaturalist2018(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` diff --git a/docs/extend/datasets.md b/docs/extend/datasets.md index 7ddd9548..606fecab 100644 --- a/docs/extend/datasets.md +++ b/docs/extend/datasets.md @@ -1,9 +1,9 @@ # How to write custom datasets -1. Subclass the ```datasets.base_dataset.BaseDatset``` class -2. Add implementations for abstract static methods from the base class: +1. Subclass the ```datasets.base_dataset.BaseDataset``` class +2.
Add implementations for abstract methods from the base class: - ```download_and_remove()``` - ```generate_split()``` ```python @@ -13,19 +13,22 @@ class MyDataset(BaseDataset): def __init__(self, my_parameter, *args, **kwargs): super().__init__(*args, **kwargs) - self.my_parameter = self.parameter + self.my_parameter = my_parameter - @staticmethod - def download_and_remove(): + def download_and_remove(self): # Downloads the dataset files needed # # If you're using a dataset that you've already downloaded elsewhere, # just use an empty implementation pass - @staticmethod - def get_available_splits(): - # Returns the string names of the available splits - return ["my_split1", "my_split2"] + def generate_split(self): + # Creates a list of image paths, and saves them into self.paths + # Creates a list of labels for the images, and saves them into self.labels + # + # The default training splits that need to be covered are `train`, `test`, and `train+test` + # If you need a different split setup, override `get_available_splits(self)` to return + # the split names you want + pass ``` diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py index 5123b799..ce1da23c 100644 --- a/src/pytorch_metric_learning/datasets/cars196.py +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -53,6 +53,7 @@ def download_and_remove(self): os.remove(download_folder_path) # if __name__ == "__main__": -# train_dataset = Cars196(root="data_cars", split="train+test", download=True) -# test_dataset = Cars196(root="data_cars", split="test", download=True) -# print(len(train_dataset), len(test_dataset)) \ No newline at end of file +# train_dataset = Cars196(root="data_cars", split="train", download=False) +# test_dataset = Cars196(root="data_cars", split="test", download=False) +# train_test_dataset = Cars196(root="data_cars", split="train+test", download=False) +# print(len(train_dataset), len(test_dataset), len(train_test_dataset)) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py index 9a56b4fb..17d897e6 100644 --- a/src/pytorch_metric_learning/datasets/inaturalist2018.py +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -28,7 +28,7 @@ def generate_split(self): paths = self._load_split_txt(self.split) ids = [path2id[p] for p in paths] labels = [id2label[i] for i in ids] - + elif self.split == "train+test": train_paths = self._load_split_txt("train") train_ids = [path2id[p] for p in train_paths] @@ -76,6 +76,8 @@ def download_and_remove(self): zip_ref.extractall(self.root) os.remove(download_folder_path) - # if __name__ == "__main__": -# train_dataset = INaturalist2018(root="data", split="train+test", download=False) \ No newline at end of file +# train_test_dataset = INaturalist2018(root="data", split="train+test", download=False) +# train_dataset = INaturalist2018(root="data", split="train", download=False) +# test_dataset = INaturalist2018(root="data", split="test", download=False) +# print(len(train_test_dataset), len(train_dataset), len(test_dataset)) \ No newline at end of file From 5e8237925ab2bf10a2c2eacbc1cfcdaabeed20bf Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 20:33:43 +0200 Subject: [PATCH 10/20] add stanford online products --- docs/datasets.md | 36 ++++++++++++++ .../datasets/base_dataset.py | 2 - .../datasets/cars196.py | 4 +- src/pytorch_metric_learning/datasets/cub.py | 3
-- .../datasets/inaturalist2018.py | 6 +-- src/pytorch_metric_learning/datasets/sop.py | 48 +++++++++++++++++++ 6 files changed, 89 insertions(+), 10 deletions(-) create mode 100644 src/pytorch_metric_learning/datasets/sop.py diff --git a/docs/datasets.md b/docs/datasets.md index 31a22559..b698370f 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -142,3 +142,39 @@ train_and_test_dataset = INaturalist2018(root="data", download=False ) ``` + +## StanfordOnlineProducts + +```python +datasets.sop.StanfordOnlineProducts(*args, **kwargs) +``` + +**Defined splits**: + +- `train` - Consists of 59551 examples. +- `test` - Consists of 60502 examples. +- `train+test` - Consists of 120053 examples. + +**Loading different dataset splits** +```python +# The download takes a while - the dataset is very large +train_dataset = StanfordOnlineProducts(root="data", + split="train", + transform=None, + target_transform=None, + download=True +) +# No need to download the dataset after it is already downloaded +test_dataset = StanfordOnlineProducts(root="data", + split="test", + transform=None, + target_transform=None, + download=False +) +train_and_test_dataset = StanfordOnlineProducts(root="data", + split="train+test", + transform=None, + target_transform=None, + download=False +) +``` diff --git a/src/pytorch_metric_learning/datasets/base_dataset.py b/src/pytorch_metric_learning/datasets/base_dataset.py index d91eb4c0..1af185ca 100644 --- a/src/pytorch_metric_learning/datasets/base_dataset.py +++ b/src/pytorch_metric_learning/datasets/base_dataset.py @@ -14,8 +14,6 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No self.download_and_remove() elif os.listdir(self.root) == []: self.download_and_remove() - else: - raise ValueError("The given directory exists and is not empty.") else: # The given directory does not exist so the user should be aware of downloading it # Otherwise proceed as usual diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py index ce1da23c..7cefc20d 100644 --- a/src/pytorch_metric_learning/datasets/cars196.py +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -53,7 +53,7 @@ def download_and_remove(self): os.remove(download_folder_path) # if __name__ == "__main__": -# train_dataset = Cars196(root="data_cars", split="train", download=False) -# test_dataset = Cars196(root="data_cars", split="test", download=False) +# train_dataset = Cars196(root="data_cars", split="train", download=True) +# test_dataset = Cars196(root="data_cars", split="test", download=True) # train_test_dataset = Cars196(root="data_cars", split="train+test", download=False) # print(len(train_dataset), len(test_dataset), len(train_test_dataset)) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py index 708c3a93..104b18f4 100644 --- a/src/pytorch_metric_learning/datasets/cub.py +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -36,9 +36,6 @@ def generate_split(self): self.paths.append(img_path) self.labels.append(class_idx) - # Normalize labels to start from 0 - self.labels = [x - min(self.labels) for x in self.labels] - def download_and_remove(self): os.makedirs(self.root, exist_ok=True) download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split('/')[-1]) diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py index 17d897e6..3d3cd004 100644 ---
a/src/pytorch_metric_learning/datasets/inaturalist2018.py +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -77,7 +77,7 @@ def download_and_remove(self): os.remove(download_folder_path) # if __name__ == "__main__": -# train_test_dataset = INaturalist2018(root="data", split="train+test", download=False) -# train_dataset = INaturalist2018(root="data", split="train", download=False) -# test_dataset = INaturalist2018(root="data", split="test", download=False) +# train_test_dataset = INaturalist2018(root="data", split="train+test", download=True) +# train_dataset = INaturalist2018(root="data", split="train", download=True) +# test_dataset = INaturalist2018(root="data", split="test", download=True) # print(len(train_test_dataset), len(train_dataset), len(test_dataset)) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/sop.py b/src/pytorch_metric_learning/datasets/sop.py new file mode 100644 index 00000000..0a53ef62 --- /dev/null +++ b/src/pytorch_metric_learning/datasets/sop.py @@ -0,0 +1,48 @@ +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve +import os +import zipfile + +class StanfordOnlineProducts(BaseDataset): + + DOWNLOAD_URL = "https://drive.usercontent.google.com/download?id=1TclrpQOF_ullUP99wk_gjGN8pKvtErG8&export=download&authuser=0&confirm=t" + + def generate_split(self): + if self.split in ["train", "test"]: + paths, labels = self._load_split_txt(self.split) + elif self.split == "train+test": + train_paths, train_labels = self._load_split_txt("train") + test_paths, test_labels = self._load_split_txt("test") + paths = train_paths + test_paths + labels = train_labels + test_labels + + self.paths = paths + self.labels = labels + + print(len(self.paths), len(self.labels)) + + def _load_split_txt(self, split): + paths, labels = [], [] + with open(os.path.join(self.root, "Stanford_Online_Products", f"Ebay_{split}.txt")) as f: + for i, l in enumerate(f): + if i == 0: + continue + l_split = l.strip().split() + label, path = int(l_split[1]), l_split[3] + paths.append(os.path.join(self.root, "Stanford_Online_Products", path)) + labels.append(label) + return paths, labels + + def download_and_remove(self): + os.makedirs(self.root, exist_ok=True) + download_folder_path = os.path.join(self.root, StanfordOnlineProducts.DOWNLOAD_URL.split('/')[-1]) + _urlretrieve(url=StanfordOnlineProducts.DOWNLOAD_URL, filename=download_folder_path) + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: + zip_ref.extractall(self.root) + os.remove(download_folder_path) + +# if __name__ == "__main__": +# train_dataset = StanfordOnlineProducts(root="data_sop", split="train", download=True) +# train_dataset = StanfordOnlineProducts(root="data_sop", split="test", download=True) +# train_dataset = StanfordOnlineProducts(root="data_sop", split="train+test", download=True) + From 4315b7a09ae10a6bc6c78b4fe4a5a5c3caa3008a Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 23:46:31 +0200 Subject: [PATCH 11/20] update paths --- .../datasets/cars196.py | 24 +++++++++---------- src/pytorch_metric_learning/datasets/cub.py | 7 +----- .../datasets/inaturalist2018.py | 17 ++++++------- src/pytorch_metric_learning/datasets/sop.py | 2 -- 4 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py index 7cefc20d..7f9c2d42 100644 --- a/src/pytorch_metric_learning/datasets/cars196.py +++ 
b/src/pytorch_metric_learning/datasets/cars196.py @@ -16,22 +16,26 @@ def generate_split(self): else: classes = set(range(1, 197)) + with open(os.path.join(self.root, "names.csv"), "r") as f: + names = [x.strip() for x in f.readlines()] + paths_train, labels_train = self._load_csv( - os.path.join(self.root, "anno_train.csv"), split="train" + os.path.join(self.root, "anno_train.csv"), names, split="train" ) paths_test, labels_test = self._load_csv( - os.path.join(self.root, "anno_test.csv"), split="test" + os.path.join(self.root, "anno_test.csv"), names, split="test" ) paths = paths_train + paths_test labels = labels_train + labels_test + self.paths, self.labels = [], [] for p, l in zip(paths, labels): if l in classes: self.paths.append(p) self.labels.append(l) - - def _load_csv(self, path, split): + + def _load_csv(self, path, names, split): all_paths, all_labels = [], [] with open(path, "r") as f: for l in f: @@ -39,7 +43,9 @@ def _load_csv(self, path, split): curr_path = path_annos[0] curr_label = path_annos[-1] all_paths.append( - os.path.join(self.root, "car_data", "car_data", split, curr_path) + os.path.join( + self.root, "car_data", "car_data", split, names[int(curr_label) - 1].replace("/","-"), curr_path + ) ) all_labels.append(int(curr_label)) return all_paths, all_labels @@ -50,10 +56,4 @@ def download_and_remove(self): _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path) with zipfile.ZipFile(download_folder_path, 'r') as zip_ref: zip_ref.extractall(self.root) - os.remove(download_folder_path) - -# if __name__ == "__main__": -# train_dataset = Cars196(root="data_cars", split="train", download=True) -# test_dataset = Cars196(root="data_cars", split="test", download=True) -# train_test_dataset = Cars196(root="data_cars", split="train+test", download=False) -# print(len(train_dataset), len(test_dataset), len(train_test_dataset)) \ No newline at end of file + os.remove(download_folder_path) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py index 104b18f4..4a64a31c 100644 --- a/src/pytorch_metric_learning/datasets/cub.py +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -33,7 +33,7 @@ def generate_split(self): # If the image ids correspond it's a match if img_idx1 == img_idx2: - self.paths.append(img_path) + self.paths.append(os.path.join(self.root, dir_name, "images", img_path)) self.labels.append(class_idx) def download_and_remove(self): @@ -43,8 +43,3 @@ def download_and_remove(self): with tarfile.open(download_folder_path, "r:gz") as tar: tar.extractall(self.root) os.remove(download_folder_path) - - -# if __name__ == "__main__": -# train_dataset = CUB(root="data_cub", split="test", download=True) -# print(len(train_dataset)) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py index 3d3cd004..eb72ff0f 100644 --- a/src/pytorch_metric_learning/datasets/inaturalist2018.py +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -13,8 +13,11 @@ class INaturalist2018(BaseDataset): SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98" def generate_split(self): - train_json = json.load(open(os.path.join(self.root, "train2018.json"))) - val_json = json.load(open(os.path.join(self.root, "val2018.json"))) + with open(os.path.join(self.root, "train2018.json"), "r") as train_f: + train_json = json.load(train_f) + + with open(os.path.join(self.root, 
"val2018.json"), "r") as val_f: + val_json = json.load(val_f) val_imgs, val_anns = val_json["images"], val_json["annotations"] train_imgs, train_anns = train_json["images"], train_json["annotations"] @@ -41,7 +44,7 @@ def generate_split(self): paths = train_paths + test_paths labels = train_labels + test_labels - self.paths = paths + self.paths = [os.path.join(self.root, p) for p in paths] self.labels = labels def _load_split_txt(self, split): @@ -74,10 +77,4 @@ def download_and_remove(self): _urlretrieve(url=INaturalist2018.SPLITS_URL, filename=download_folder_path) with zipfile.ZipFile(download_folder_path, "r") as zip_ref: zip_ref.extractall(self.root) - os.remove(download_folder_path) - -# if __name__ == "__main__": -# train_test_dataset = INaturalist2018(root="data", split="train+test", download=True) -# train_dataset = INaturalist2018(root="data", split="train", download=True) -# test_dataset = INaturalist2018(root="data", split="test", download=True) -# print(len(train_test_dataset), len(train_dataset), len(test_dataset)) \ No newline at end of file + os.remove(download_folder_path) \ No newline at end of file diff --git a/src/pytorch_metric_learning/datasets/sop.py b/src/pytorch_metric_learning/datasets/sop.py index 0a53ef62..3b0de133 100644 --- a/src/pytorch_metric_learning/datasets/sop.py +++ b/src/pytorch_metric_learning/datasets/sop.py @@ -19,8 +19,6 @@ def generate_split(self): self.paths = paths self.labels = labels - print(len(self.paths), len(self.labels)) - def _load_split_txt(self, split): paths, labels = [], [] with open(os.path.join(self.root, "Stanford_Online_Products", f"Ebay_{split}.txt")) as f: From e86eab08d19825186d0cee9e41107e58c0491302 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sat, 26 Oct 2024 23:46:41 +0200 Subject: [PATCH 12/20] add tests --- tests/datasets/test_cars196.py | 46 +++++++++++++++++++++++ tests/datasets/test_cub.py | 46 +++++++++++++++++++++++ tests/datasets/test_inaturalist2018.py | 52 ++++++++++++++++++++++++++ tests/datasets/test_sop.py | 49 ++++++++++++++++++++++++ 4 files changed, 193 insertions(+) create mode 100644 tests/datasets/test_cars196.py create mode 100644 tests/datasets/test_cub.py create mode 100644 tests/datasets/test_inaturalist2018.py create mode 100644 tests/datasets/test_sop.py diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py new file mode 100644 index 00000000..95f400d9 --- /dev/null +++ b/tests/datasets/test_cars196.py @@ -0,0 +1,46 @@ +import unittest +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from pytorch_metric_learning.datasets.cars196 import Cars196 +import shutil +import os + +class TestCars196(unittest.TestCase): + + CARS_196_ROOT = "test_cars196" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.CARS_196_ROOT): + cls.ALREADY_EXISTS = True + + def test_Cars196(self): + train_test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train+test", download=True) + train_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train", download=True) + test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="test", download=False) + + self.assertTrue(len(train_test_data) == 16185) + self.assertTrue(len(train_data) == 8054) + self.assertTrue(len(test_data) == 8131) + + def test_CARS_196_dataloader(self): + test_data = Cars196( + root=TestCars196.CARS_196_ROOT, + transform=transforms.Compose([ + transforms.Resize(size=(224, 224)), + 
transforms.PILToTensor() + ]), + split="test", + download=True + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.CARS_196_ROOT) \ No newline at end of file diff --git a/tests/datasets/test_cub.py b/tests/datasets/test_cub.py new file mode 100644 index 00000000..cc56501f --- /dev/null +++ b/tests/datasets/test_cub.py @@ -0,0 +1,46 @@ +import unittest +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from pytorch_metric_learning.datasets.cub import CUB +import shutil +import os + +class TestCUB(unittest.TestCase): + + CUB_ROOT = "test_cub" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.CUB_ROOT): + cls.ALREADY_EXISTS = True + + def test_CUB(self): + train_test_data = CUB(root=TestCUB.CUB_ROOT, split="train+test", download=True) + train_data = CUB(root=TestCUB.CUB_ROOT, split="train", download=True) + test_data = CUB(root=TestCUB.CUB_ROOT, split="test", download=False) + + self.assertTrue(len(train_test_data) == 11788) + self.assertTrue(len(train_data) == 5864) + self.assertTrue(len(test_data) == 5924) + + def test_CUB_dataloader(self): + test_data = CUB( + root=TestCUB.CUB_ROOT, + transform=transforms.Compose([ + transforms.Resize(size=(224, 224)), + transforms.PILToTensor() + ]), + split="test", + download=True + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.CUB_ROOT) \ No newline at end of file diff --git a/tests/datasets/test_inaturalist2018.py b/tests/datasets/test_inaturalist2018.py new file mode 100644 index 00000000..4e10dc07 --- /dev/null +++ b/tests/datasets/test_inaturalist2018.py @@ -0,0 +1,52 @@ +import unittest +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018 +import shutil +import os + +class TestINaturalist2018(unittest.TestCase): + + INATURALIST2018_ROOT = "data" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.INATURALIST2018_ROOT): + cls.ALREADY_EXISTS = True + + def test_INaturalist2018(self): + train_test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, split="train+test", download=True + ) + train_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, split="train", download=True + ) + test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, split="test", download=False + ) + + self.assertTrue(len(train_test_data) == 461939) + self.assertTrue(len(train_data) == 325846) + self.assertTrue(len(test_data) == 136093) + + def test_INaturalist2018_dataloader(self): + test_data = INaturalist2018( + root=TestINaturalist2018.INATURALIST2018_ROOT, + transform=transforms.Compose([ + transforms.Resize(size=(224, 224)), + transforms.PILToTensor() + ]), + split="test", + download=True + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels 
= next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.INATURALIST2018_ROOT) \ No newline at end of file diff --git a/tests/datasets/test_sop.py b/tests/datasets/test_sop.py new file mode 100644 index 00000000..d66a98af --- /dev/null +++ b/tests/datasets/test_sop.py @@ -0,0 +1,49 @@ +import unittest +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts +import shutil +import os + +class TestStanfordOnlineProducts(unittest.TestCase): + + SOP_ROOT = "test_sop" + ALREADY_EXISTS = False + + # In the rare case the user has an already existing directory, do not delete it + @classmethod + def setUpClass(cls): + if os.path.exists(cls.SOP_ROOT): + cls.ALREADY_EXISTS = True + + def test_SOP(self): + train_test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True) + train_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="train", download=True) + test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, split="test", download=False) + + self.assertTrue(len(train_test_data) == 120053) + self.assertTrue(len(train_data) == 59551) + self.assertTrue(len(test_data) == 60502) + + def test_SOP_dataloader(self): + test_data = StanfordOnlineProducts( + root=TestStanfordOnlineProducts.SOP_ROOT, + transform=transforms.Compose([ + transforms.Resize(size=(224, 224)), + transforms.PILToTensor() + ]), + split="test", + download=True + ) + loader = DataLoader(test_data, batch_size=8) + inputs, labels = next(iter(loader)) + self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224)) + self.assertTupleEqual(tuple(labels.shape), (8,)) + + @classmethod + def tearDownClass(cls): + if not cls.ALREADY_EXISTS: + shutil.rmtree(cls.SOP_ROOT) \ No newline at end of file From b629fffac638d4aa64cce01970789d17be1aa6d3 Mon Sep 17 00:00:00 2001 From: ir2718 Date: Sun, 27 Oct 2024 00:42:36 +0200 Subject: [PATCH 13/20] format code --- .../datasets/base_dataset.py | 30 +++++++---- .../datasets/cars196.py | 26 ++++++---- src/pytorch_metric_learning/datasets/cub.py | 18 ++++--- .../datasets/inaturalist2018.py | 50 +++++++++++++------ src/pytorch_metric_learning/datasets/sop.py | 26 ++++++---- .../samplers/m_per_class_sampler.py | 2 + .../utils/common_functions.py | 6 ++- tests/datasets/test_cars196.py | 36 +++++++------ tests/datasets/test_cub.py | 24 +++++---- tests/datasets/test_inaturalist2018.py | 28 ++++++----- tests/datasets/test_sop.py | 33 ++++++------ 11 files changed, 177 insertions(+), 102 deletions(-) diff --git a/src/pytorch_metric_learning/datasets/base_dataset.py b/src/pytorch_metric_learning/datasets/base_dataset.py index 1af185ca..c9613b99 100644 --- a/src/pytorch_metric_learning/datasets/base_dataset.py +++ b/src/pytorch_metric_learning/datasets/base_dataset.py @@ -1,11 +1,20 @@ -from PIL import Image -from torch.utils.data import Dataset import os from abc import ABC, abstractmethod +from PIL import Image +from torch.utils.data import Dataset + + class BaseDataset(ABC, Dataset): - def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False): + def __init__( + self, + root, + split="train+test", + transform=None, + target_transform=None, + download=False, + ): self.root = root if download: @@ -18,7 
+27,8 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No # The given directory does not exist so the user should be aware of downloading it # Otherwise proceed as usual if not os.path.isdir(self.root): - raise ValueError("The given path does not exist. " + raise ValueError( + "The given path does not exist. " "You should probably initialize the dataset with download=True." ) @@ -26,8 +36,10 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No self.target_transform = target_transform if split not in self.get_available_splits(): - raise ValueError(f"Supported splits are: {', '.join(self.get_available_splits())}") - + raise ValueError( + f"Supported splits are: {', '.join(self.get_available_splits())}" + ) + self.split = split self.generate_split() @@ -35,7 +47,7 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No @abstractmethod def generate_split(self): raise NotImplementedError - + @abstractmethod def download_and_remove(self): raise NotImplementedError @@ -45,7 +57,7 @@ def get_available_splits(self): return ["train", "test", "train+test"] def __len__(self): return len(self.labels) - + def __getitem__(self, idx): img = Image.open(self.paths[idx]) label = self.labels[idx] @@ -56,4 +68,4 @@ def __getitem__(self, idx): if self.target_transform is not None: label = self.target_transform(label) - return (img, label) \ No newline at end of file + return (img, label) diff --git a/src/pytorch_metric_learning/datasets/cars196.py b/src/pytorch_metric_learning/datasets/cars196.py index 7f9c2d42..9e71e6aa 100644 --- a/src/pytorch_metric_learning/datasets/cars196.py +++ b/src/pytorch_metric_learning/datasets/cars196.py @@ -1,8 +1,10 @@ -from ..datasets.base_dataset import BaseDataset -from ..utils.common_functions import _urlretrieve import os import zipfile +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + class Cars196(BaseDataset): DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder" @@ -15,7 +17,7 @@ def generate_split(self): classes = set(range(99, 197)) else: classes = set(range(1, 197)) - + with open(os.path.join(self.root, "names.csv"), "r") as f: names = [x.strip() for x in f.readlines()] @@ -28,13 +30,12 @@ def generate_split(self): paths = paths_train + paths_test labels = labels_train + labels_test - self.paths, self.labels = [], [] for p, l in zip(paths, labels): if l in classes: self.paths.append(p) self.labels.append(l) - + def _load_csv(self, path, names, split): all_paths, all_labels = [], [] with open(path, "r") as f: @@ -44,7 +45,12 @@ def _load_csv(self, path, names, split): curr_label = path_annos[-1] all_paths.append( os.path.join( - self.root, "car_data", "car_data", split, names[int(curr_label) - 1].replace("/","-"), curr_path + self.root, + "car_data", + "car_data", + split, + names[int(curr_label) - 1].replace("/", "-"), + curr_path, ) ) all_labels.append(int(curr_label)) @@ -52,8 +58,10 @@ def _load_csv(self, path, names, split): def download_and_remove(self): os.makedirs(self.root, exist_ok=True) - download_folder_path = os.path.join(self.root, Cars196.DOWNLOAD_URL.split('/')[-1]) + download_folder_path = os.path.join( + self.root, Cars196.DOWNLOAD_URL.split("/")[-1] + ) _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path) - with zipfile.ZipFile(download_folder_path, 'r') as zip_ref: + with zipfile.ZipFile(download_folder_path, "r") as zip_ref: zip_ref.extractall(self.root) -
os.remove(download_folder_path) \ No newline at end of file + os.remove(download_folder_path) diff --git a/src/pytorch_metric_learning/datasets/cub.py b/src/pytorch_metric_learning/datasets/cub.py index 4a64a31c..61f0d664 100644 --- a/src/pytorch_metric_learning/datasets/cub.py +++ b/src/pytorch_metric_learning/datasets/cub.py @@ -1,14 +1,16 @@ -from ..datasets.base_dataset import BaseDataset -from ..utils.common_functions import _urlretrieve import os import tarfile +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + class CUB(BaseDataset): DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz" def generate_split(self): - dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "") + dir_name = CUB.DOWNLOAD_URL.split("/")[-1].replace(".tgz", "") # Training split is first 100 classes, other 100 is test if self.split == "train": @@ -17,14 +19,14 @@ def generate_split(self): classes = set(range(101, 201)) else: classes = set(range(1, 201)) - + # Find ids which correspond to the classes in the split self.paths, self.labels = [], [] with open(os.path.join(self.root, dir_name, "image_class_labels.txt")) as f1: with open(os.path.join(self.root, dir_name, "images.txt")) as f2: for l1, l2 in zip(f1, f2): img_idx1, class_idx = list(map(int, l1.split())) - + if class_idx not in classes: continue @@ -33,12 +35,14 @@ def generate_split(self): # If the image ids correspond it's a match if img_idx1 == img_idx2: - self.paths.append(os.path.join(self.root, dir_name, "images", img_path)) + self.paths.append( + os.path.join(self.root, dir_name, "images", img_path) + ) self.labels.append(class_idx) def download_and_remove(self): os.makedirs(self.root, exist_ok=True) - download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split('/')[-1]) + download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split("/")[-1]) _urlretrieve(url=CUB.DOWNLOAD_URL, filename=download_folder_path) with tarfile.open(download_folder_path, "r:gz") as tar: tar.extractall(self.root) diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py index eb72ff0f..aeb4b7f9 100644 --- a/src/pytorch_metric_learning/datasets/inaturalist2018.py +++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py @@ -1,15 +1,19 @@ -from ..datasets.base_dataset import BaseDataset -from ..utils.common_functions import _urlretrieve +import json import os import tarfile import zipfile -import json + +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + class INaturalist2018(BaseDataset): IMG_DOWNLOAD_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz" TRAIN_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train2018.json.tar.gz" - VAL_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz" + VAL_ANN_URL = ( + "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz" + ) SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98" def generate_split(self): @@ -24,8 +28,8 @@ def generate_split(self): imgs, anns = val_imgs + train_imgs, val_anns + train_anns - path2id = {x["file_name"]:x["id"] for x in imgs} - id2label = {x["image_id"]:x["category_id"] for x in anns} + path2id = {x["file_name"]: x["id"] for x in imgs} + id2label = {x["image_id"]: x["category_id"] for x in anns} if self.split in ["train", "test"]: paths 
= self._load_split_txt(self.split) @@ -36,7 +40,7 @@ def generate_split(self): train_paths = self._load_split_txt("train") train_ids = [path2id[p] for p in train_paths] train_labels = [id2label[i] for i in train_ids] - + test_paths = self._load_split_txt("test") test_ids = [path2id[p] for p in test_paths] test_labels = [id2label[i] for i in test_ids] @@ -49,32 +53,46 @@ def generate_split(self): def _load_split_txt(self, split): paths = [] - with open(os.path.join(self.root, "Inat_dataset_splits", f"Inaturalist_{split}_set1.txt")) as f: + with open( + os.path.join( + self.root, "Inat_dataset_splits", f"Inaturalist_{split}_set1.txt" + ) + ) as f: for l in f: paths.append(l.strip()) return paths - + def download_and_remove(self): - download_folder_path = os.path.join(self.root, INaturalist2018.IMG_DOWNLOAD_URL.split('/')[-1]) - _urlretrieve(url=INaturalist2018.IMG_DOWNLOAD_URL, filename=download_folder_path) + download_folder_path = os.path.join( + self.root, INaturalist2018.IMG_DOWNLOAD_URL.split("/")[-1] + ) + _urlretrieve( + url=INaturalist2018.IMG_DOWNLOAD_URL, filename=download_folder_path + ) with tarfile.open(download_folder_path, "r:gz") as tar: tar.extractall(self.root) os.remove(download_folder_path) - - download_folder_path = os.path.join(self.root, INaturalist2018.TRAIN_ANN_URL.split('/')[-1]) + + download_folder_path = os.path.join( + self.root, INaturalist2018.TRAIN_ANN_URL.split("/")[-1] + ) _urlretrieve(url=INaturalist2018.TRAIN_ANN_URL, filename=download_folder_path) with tarfile.open(download_folder_path, "r:gz") as tar: tar.extractall(self.root) os.remove(download_folder_path) - download_folder_path = os.path.join(self.root, INaturalist2018.VAL_ANN_URL.split('/')[-1]) + download_folder_path = os.path.join( + self.root, INaturalist2018.VAL_ANN_URL.split("/")[-1] + ) _urlretrieve(url=INaturalist2018.VAL_ANN_URL, filename=download_folder_path) with tarfile.open(download_folder_path, "r:gz") as tar: tar.extractall(self.root) os.remove(download_folder_path) - download_folder_path = os.path.join(self.root, INaturalist2018.SPLITS_URL.split('/')[-1]) + download_folder_path = os.path.join( + self.root, INaturalist2018.SPLITS_URL.split("/")[-1] + ) _urlretrieve(url=INaturalist2018.SPLITS_URL, filename=download_folder_path) with zipfile.ZipFile(download_folder_path, "r") as zip_ref: zip_ref.extractall(self.root) - os.remove(download_folder_path) \ No newline at end of file + os.remove(download_folder_path) diff --git a/src/pytorch_metric_learning/datasets/sop.py b/src/pytorch_metric_learning/datasets/sop.py index 3b0de133..335f3b4b 100644 --- a/src/pytorch_metric_learning/datasets/sop.py +++ b/src/pytorch_metric_learning/datasets/sop.py @@ -1,8 +1,10 @@ -from ..datasets.base_dataset import BaseDataset -from ..utils.common_functions import _urlretrieve import os import zipfile +from ..datasets.base_dataset import BaseDataset +from ..utils.common_functions import _urlretrieve + + class StanfordOnlineProducts(BaseDataset): DOWNLOAD_URL = "https://drive.usercontent.google.com/download?id=1TclrpQOF_ullUP99wk_gjGN8pKvtErG8&export=download&authuser=0&confirm=t" @@ -21,26 +23,32 @@ def generate_split(self): def _load_split_txt(self, split): paths, labels = [], [] - with open(os.path.join(self.root, "Stanford_Online_Products", f"Ebay_{split}.txt")) as f: + with open( + os.path.join(self.root, "Stanford_Online_Products", f"Ebay_{split}.txt") + ) as f: for i, l in enumerate(f): if i == 0: continue l_split = l.strip().split() - label, path = int(l_split[1]), l_split[3] + label, path = 
int(l_split[1]), l_split[3] paths.append(os.path.join(self.root, "Stanford_Online_Products", path)) labels.append(label) return paths, labels - + def download_and_remove(self): os.makedirs(self.root, exist_ok=True) - download_folder_path = os.path.join(self.root, StanfordOnlineProducts.DOWNLOAD_URL.split('/')[-1]) - _urlretrieve(url=StanfordOnlineProducts.DOWNLOAD_URL, filename=download_folder_path) + download_folder_path = os.path.join( + self.root, StanfordOnlineProducts.DOWNLOAD_URL.split("/")[-1] + ) + _urlretrieve( + url=StanfordOnlineProducts.DOWNLOAD_URL, filename=download_folder_path + ) with zipfile.ZipFile(download_folder_path, "r") as zip_ref: zip_ref.extractall(self.root) os.remove(download_folder_path) - + + # if __name__ == "__main__": # train_dataset = StanfordOnlineProducts(root="data_sop", split="train", download=True) # train_dataset = StanfordOnlineProducts(root="data_sop", split="test", download=True) # train_dataset = StanfordOnlineProducts(root="data_sop", split="train+test", download=True) - diff --git a/src/pytorch_metric_learning/samplers/m_per_class_sampler.py b/src/pytorch_metric_learning/samplers/m_per_class_sampler.py index c0444c76..1b263ac3 100644 --- a/src/pytorch_metric_learning/samplers/m_per_class_sampler.py +++ b/src/pytorch_metric_learning/samplers/m_per_class_sampler.py @@ -48,12 +48,14 @@ def __iter__(self): curr_label_set = self.labels else: curr_label_set = self.labels[: self.batch_size // self.m_per_class] + print(curr_label_set) for label in curr_label_set: t = self.labels_to_indices[label] idx_list[i : i + self.m_per_class] = c_f.safe_random_choice( t, size=self.m_per_class ) i += self.m_per_class + return iter(idx_list) def calculate_num_iters(self): diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py index 3b07803f..424f00a3 100644 --- a/src/pytorch_metric_learning/utils/common_functions.py +++ b/src/pytorch_metric_learning/utils/common_functions.py @@ -15,15 +15,19 @@ NUMPY_RANDOM = np.random COLLECT_STATS = False + # taken from: # https://github.com/pytorch/vision/blob/main/torchvision/datasets/utils.py#L27 def _urlretrieve(url, filename, chunk_size=1024 * 32): with urllib.request.urlopen(urllib.request.Request(url)) as response: - with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar: + with open(filename, "wb") as fh, tqdm( + total=response.length, unit="B", unit_scale=True + ) as pbar: while chunk := response.read(chunk_size): fh.write(chunk) pbar.update(len(chunk)) + def set_logger_name(name): global LOGGER_NAME global LOGGER diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py index 95f400d9..43d1b4d9 100644 --- a/tests/datasets/test_cars196.py +++ b/tests/datasets/test_cars196.py @@ -1,12 +1,15 @@ +import os +import shutil import unittest + import torchvision.transforms as transforms from torch.utils.data import DataLoader + from pytorch_metric_learning.datasets.cars196 import Cars196 -import shutil -import os + class TestCars196(unittest.TestCase): - + CARS_196_ROOT = "test_cars196" ALREADY_EXISTS = False @@ -17,9 +20,15 @@ def setUpClass(cls): cls.ALREADY_EXISTS = True def test_Cars196(self): - train_test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train+test", download=True) - train_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train", download=True) - test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="test", download=False) + train_test_data = Cars196( + 
diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py
index 95f400d9..43d1b4d9 100644
--- a/tests/datasets/test_cars196.py
+++ b/tests/datasets/test_cars196.py
@@ -1,12 +1,15 @@
+import os
+import shutil
 import unittest
+
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
+
 from pytorch_metric_learning.datasets.cars196 import Cars196
-import shutil
-import os
+
 
 class TestCars196(unittest.TestCase):
-
+
     CARS_196_ROOT = "test_cars196"
     ALREADY_EXISTS = False
 
@@ -17,9 +20,15 @@ def setUpClass(cls):
             cls.ALREADY_EXISTS = True
 
     def test_Cars196(self):
-        train_test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train+test", download=True)
-        train_data = Cars196(root=TestCars196.CARS_196_ROOT, split="train", download=True)
-        test_data = Cars196(root=TestCars196.CARS_196_ROOT, split="test", download=False)
+        train_test_data = Cars196(
+            root=TestCars196.CARS_196_ROOT, split="train+test", download=True
+        )
+        train_data = Cars196(
+            root=TestCars196.CARS_196_ROOT, split="train", download=True
+        )
+        test_data = Cars196(
+            root=TestCars196.CARS_196_ROOT, split="test", download=False
+        )
 
         self.assertTrue(len(train_test_data) == 16185)
         self.assertTrue(len(train_data) == 8054)
@@ -27,13 +36,12 @@ def test_Cars196(self):
 
     def test_CARS_196_dataloader(self):
         test_data = Cars196(
-            root=TestCars196.CARS_196_ROOT, 
-            transform=transforms.Compose([
-                transforms.Resize(size=(224, 224)),
-                transforms.PILToTensor()
-            ]),
-            split="test",
-            download=True
+            root=TestCars196.CARS_196_ROOT,
+            transform=transforms.Compose(
+                [transforms.Resize(size=(224, 224)), transforms.PILToTensor()]
+            ),
+            split="test",
+            download=True,
         )
         loader = DataLoader(test_data, batch_size=8)
         inputs, labels = next(iter(loader))
@@ -43,4 +51,4 @@ def test_CARS_196_dataloader(self):
     @classmethod
     def tearDownClass(cls):
         if not cls.ALREADY_EXISTS:
-            shutil.rmtree(cls.CARS_196_ROOT)
\ No newline at end of file
+            shutil.rmtree(cls.CARS_196_ROOT)
diff --git a/tests/datasets/test_cub.py b/tests/datasets/test_cub.py
index cc56501f..625e97d2 100644
--- a/tests/datasets/test_cub.py
+++ b/tests/datasets/test_cub.py
@@ -1,12 +1,15 @@
+import os
+import shutil
 import unittest
+
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
+
 from pytorch_metric_learning.datasets.cub import CUB
-import shutil
-import os
+
 
 class TestCUB(unittest.TestCase):
-
+
     CUB_ROOT = "test_cub"
     ALREADY_EXISTS = False
 
@@ -27,13 +30,12 @@ def test_CUB(self):
 
     def test_CUB_dataloader(self):
         test_data = CUB(
-            root=TestCUB.CUB_ROOT, 
-            transform=transforms.Compose([
-                transforms.Resize(size=(224, 224)),
-                transforms.PILToTensor()
-            ]),
-            split="test",
-            download=True
+            root=TestCUB.CUB_ROOT,
+            transform=transforms.Compose(
+                [transforms.Resize(size=(224, 224)), transforms.PILToTensor()]
+            ),
+            split="test",
+            download=True,
         )
         loader = DataLoader(test_data, batch_size=8)
         inputs, labels = next(iter(loader))
@@ -43,4 +45,4 @@ def test_CUB_dataloader(self):
     @classmethod
     def tearDownClass(cls):
         if not cls.ALREADY_EXISTS:
-            shutil.rmtree(cls.CUB_ROOT)
\ No newline at end of file
+            shutil.rmtree(cls.CUB_ROOT)
diff --git a/tests/datasets/test_inaturalist2018.py b/tests/datasets/test_inaturalist2018.py
index 4e10dc07..123dfd64 100644
--- a/tests/datasets/test_inaturalist2018.py
+++ b/tests/datasets/test_inaturalist2018.py
@@ -1,12 +1,15 @@
+import os
+import shutil
 import unittest
+
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
+
 from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018
-import shutil
-import os
+
 
 class TestINaturalist2018(unittest.TestCase):
-
+
     INATURALIST2018_ROOT = "data"
     ALREADY_EXISTS = False
 
@@ -18,7 +21,9 @@ def setUpClass(cls):
 
     def test_INaturalist2018(self):
         train_test_data = INaturalist2018(
-            root=TestINaturalist2018.INATURALIST2018_ROOT, split="train+test", download=True
+            root=TestINaturalist2018.INATURALIST2018_ROOT,
+            split="train+test",
+            download=True,
         )
         train_data = INaturalist2018(
             root=TestINaturalist2018.INATURALIST2018_ROOT, split="train", download=True
@@ -33,13 +38,12 @@ def test_INaturalist2018(self):
 
     def test_INaturalist2018_dataloader(self):
         test_data = INaturalist2018(
-            root=TestINaturalist2018.INATURALIST2018_ROOT, 
-            transform=transforms.Compose([
-                transforms.Resize(size=(224, 224)),
-                transforms.PILToTensor()
-            ]),
-            split="test",
-            download=True
+            root=TestINaturalist2018.INATURALIST2018_ROOT,
+            transform=transforms.Compose(
+                [transforms.Resize(size=(224, 224)), transforms.PILToTensor()]
+            ),
+            split="test",
+            download=True,
         )
         loader = DataLoader(test_data, batch_size=8)
         inputs, labels = next(iter(loader))
@@ -49,4 +53,4 @@ def test_INaturalist2018_dataloader(self):
     @classmethod
     def tearDownClass(cls):
         if not cls.ALREADY_EXISTS:
-            shutil.rmtree(cls.INATURALIST2018_ROOT)
\ No newline at end of file
+            shutil.rmtree(cls.INATURALIST2018_ROOT)
diff --git a/tests/datasets/test_sop.py b/tests/datasets/test_sop.py
index d66a98af..77142251 100644
--- a/tests/datasets/test_sop.py
+++ b/tests/datasets/test_sop.py
@@ -1,12 +1,15 @@
+import os
+import shutil
 import unittest
+
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
+
 from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts
-import shutil
-import os
+
 
 class TestStanfordOnlineProducts(unittest.TestCase):
-
+
     SOP_ROOT = "test_sop"
     ALREADY_EXISTS = False
 
@@ -18,11 +21,14 @@ def setUpClass(cls):
 
     def test_SOP(self):
         train_test_data = StanfordOnlineProducts(
-            root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True)
+            root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True
+        )
         train_data = StanfordOnlineProducts(
-            root=TestStanfordOnlineProducts.SOP_ROOT, split="train", download=True)
+            root=TestStanfordOnlineProducts.SOP_ROOT, split="train", download=True
+        )
         test_data = StanfordOnlineProducts(
-            root=TestStanfordOnlineProducts.SOP_ROOT, split="test", download=False)
+            root=TestStanfordOnlineProducts.SOP_ROOT, split="test", download=False
+        )
 
         self.assertTrue(len(train_test_data) == 120053)
         self.assertTrue(len(train_data) == 59551)
@@ -30,13 +36,12 @@ def test_SOP(self):
 
     def test_SOP_dataloader(self):
         test_data = StanfordOnlineProducts(
-            root=TestStanfordOnlineProducts.SOP_ROOT, 
-            transform=transforms.Compose([
-                transforms.Resize(size=(224, 224)),
-                transforms.PILToTensor()
-            ]),
-            split="test",
-            download=True
+            root=TestStanfordOnlineProducts.SOP_ROOT,
+            transform=transforms.Compose(
+                [transforms.Resize(size=(224, 224)), transforms.PILToTensor()]
+            ),
+            split="test",
+            download=True,
         )
         loader = DataLoader(test_data, batch_size=8)
         inputs, labels = next(iter(loader))
@@ -46,4 +51,4 @@ def test_SOP_dataloader(self):
     @classmethod
     def tearDownClass(cls):
         if not cls.ALREADY_EXISTS:
-            shutil.rmtree(cls.SOP_ROOT)
\ No newline at end of file
+            shutil.rmtree(cls.SOP_ROOT)
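A detail worth noting in the formatting patch above: _urlretrieve in utils/common_functions.py is the download helper that every dataset's download_and_remove calls. It streams the HTTP response to disk in 32 KiB chunks while updating a tqdm progress bar, so large archives never have to fit in memory. A small illustrative sketch of calling it directly follows; the URL and filename are placeholders, and since the leading underscore marks the helper as internal, the import path should not be relied on:

    from pytorch_metric_learning.utils.common_functions import _urlretrieve

    # Streams the response body to disk chunk by chunk, with a progress bar.
    _urlretrieve(
        url="https://example.com/archive.tgz",  # placeholder, not a real dataset URL
        filename="archive.tgz",
        chunk_size=1024 * 32,  # the default chunk size
    )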
From ad585a2ec5a877a46e36b6197f916d32be0860b5 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:18:55 +0000
Subject: [PATCH 14/20] update mkdocs.yml

---
 mkdocs.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mkdocs.yml b/mkdocs.yml
index a05c6d4c..731fc84c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,6 +1,7 @@
 site_name: PyTorch Metric Learning
 nav:
   - Home: index.md
+  - Datasets: datasets.md
   - Distances: distances.md
   - Losses: losses.md
   - Miners: miners.md
@@ -16,6 +17,7 @@ nav:
     - Common Functions: common_functions.md
     - Distributed: distributed.md
   - How to extend this library:
+    - Custom datasets: extend/datasets.md
     - Custom losses: extend/losses.md
     - Custom miners: extend/miners.md
   - Frequently Asked Questions: faq.md
From cc1148893b6364ebc1c24d1d90fcc74bf5ae06a3 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:21:34 +0000
Subject: [PATCH 15/20] delete print statement

---
 src/pytorch_metric_learning/samplers/m_per_class_sampler.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/pytorch_metric_learning/samplers/m_per_class_sampler.py b/src/pytorch_metric_learning/samplers/m_per_class_sampler.py
index 1b263ac3..c0444c76 100644
--- a/src/pytorch_metric_learning/samplers/m_per_class_sampler.py
+++ b/src/pytorch_metric_learning/samplers/m_per_class_sampler.py
@@ -48,14 +48,12 @@ def __iter__(self):
             curr_label_set = self.labels
         else:
             curr_label_set = self.labels[: self.batch_size // self.m_per_class]
-        print(curr_label_set)
         for label in curr_label_set:
             t = self.labels_to_indices[label]
             idx_list[i : i + self.m_per_class] = c_f.safe_random_choice(
                 t, size=self.m_per_class
             )
             i += self.m_per_class
-
         return iter(idx_list)
 
     def calculate_num_iters(self):
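With the stray debug print removed, MPerClassSampler iterates silently again. The sampler pairs naturally with the new datasets, since each dataset exposes a plain labels list. An illustrative sketch of wiring the two together follows; the root path and the m/batch_size values are arbitrary choices for this example:

    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    from pytorch_metric_learning.datasets.cub import CUB
    from pytorch_metric_learning.samplers import MPerClassSampler

    train_data = CUB(
        root="data_cub",  # arbitrary example path
        split="train",
        transform=transforms.Compose(
            [transforms.Resize(size=(224, 224)), transforms.PILToTensor()]
        ),
        download=True,
    )

    # Each batch of 32 contains 4 images from each of 8 distinct classes.
    sampler = MPerClassSampler(
        train_data.labels, m=4, batch_size=32, length_before_new_iter=len(train_data)
    )
    loader = DataLoader(train_data, batch_size=32, sampler=sampler)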
From b64c9b2fc1f7c48a57e99f60dc1184d6f7be9676 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:27:29 +0000
Subject: [PATCH 16/20] add comments indicating where the google drive links
 are from

---
 src/pytorch_metric_learning/datasets/inaturalist2018.py | 2 ++
 src/pytorch_metric_learning/datasets/sop.py             | 1 +
 2 files changed, 3 insertions(+)

diff --git a/src/pytorch_metric_learning/datasets/inaturalist2018.py b/src/pytorch_metric_learning/datasets/inaturalist2018.py
index aeb4b7f9..478d8a20 100644
--- a/src/pytorch_metric_learning/datasets/inaturalist2018.py
+++ b/src/pytorch_metric_learning/datasets/inaturalist2018.py
@@ -14,6 +14,8 @@ class INaturalist2018(BaseDataset):
     VAL_ANN_URL = (
         "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz"
     )
+
+    # Link from https://github.com/Andrew-Brown1/Smooth_AP?tab=readme-ov-file#data
     SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98"
 
     def generate_split(self):
diff --git a/src/pytorch_metric_learning/datasets/sop.py b/src/pytorch_metric_learning/datasets/sop.py
index 335f3b4b..cf298a54 100644
--- a/src/pytorch_metric_learning/datasets/sop.py
+++ b/src/pytorch_metric_learning/datasets/sop.py
@@ -7,6 +7,7 @@
 
 class StanfordOnlineProducts(BaseDataset):
 
+    # Link from https://github.com/rksltnl/Deep-Metric-Learning-CVPR16?tab=readme-ov-file#stanford-online-products-dataset
     DOWNLOAD_URL = "https://drive.usercontent.google.com/download?id=1TclrpQOF_ullUP99wk_gjGN8pKvtErG8&export=download&authuser=0&confirm=t"
 
     def generate_split(self):
From bf156a0ec2d8ac366660b6c052ff153a2c132a67 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:30:11 +0000
Subject: [PATCH 17/20] add test_datasets.yml

---
 .github/workflows/test_datasets.yml | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 .github/workflows/test_datasets.yml

diff --git a/.github/workflows/test_datasets.yml b/.github/workflows/test_datasets.yml
new file mode 100644
index 00000000..359ea8e5
--- /dev/null
+++ b/.github/workflows/test_datasets.yml
@@ -0,0 +1,15 @@
+name: datasets
+
+on:
+  pull_request:
+    branches: [ master, dev ]
+    paths:
+      - 'src/**'
+      - 'tests/**'
+      - '.github/workflows/**'
+
+jobs:
+  call-base-test-workflow:
+    uses: ./.github/workflows/base_test_workflow.yml
+    with:
+      module-to-test: datasets
\ No newline at end of file
From 9b4d5264e1745bf4728b0d6184d31de5f1506b39 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:33:36 +0000
Subject: [PATCH 18/20] add __init__.py to tests/datasets

---
 tests/datasets/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 tests/datasets/__init__.py

diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
From 539c93ab543b60214d90c901c5c16f366d399c98 Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:44:33 +0000
Subject: [PATCH 19/20] remove test_datasets.yml because it takes too long to
 run on github

---
 .github/workflows/test_datasets.yml | 15 ---------------
 1 file changed, 15 deletions(-)
 delete mode 100644 .github/workflows/test_datasets.yml

diff --git a/.github/workflows/test_datasets.yml b/.github/workflows/test_datasets.yml
deleted file mode 100644
index 359ea8e5..00000000
--- a/.github/workflows/test_datasets.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-name: datasets
-
-on:
-  pull_request:
-    branches: [ master, dev ]
-    paths:
-      - 'src/**'
-      - 'tests/**'
-      - '.github/workflows/**'
-
-jobs:
-  call-base-test-workflow:
-    uses: ./.github/workflows/base_test_workflow.yml
-    with:
-      module-to-test: datasets
\ No newline at end of file
From c46dd3349b1c70640bdd3616f444bf45d0370abf Mon Sep 17 00:00:00 2001
From: KevinMusgrave
Date: Wed, 11 Dec 2024 16:49:23 +0000
Subject: [PATCH 20/20] bump version

---
 src/pytorch_metric_learning/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py
index 2614ce9d..892994aa 100644
--- a/src/pytorch_metric_learning/__init__.py
+++ b/src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "2.7.0"
+__version__ = "2.8.0"