Skip to content

Commit 7d21152

Browse files
committed
fold -> partition
1 parent 9e4476a commit 7d21152

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

test/test_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2212,7 +2212,7 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
22122212
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
22132213
split=("train", "test", "val"),
22142214
# There is no need to test the whole matrix here, since each fold is treated exactly the same
2215-
fold=(5,),
2215+
partition=(1, 5, 10),
22162216
)
22172217

22182218
def inject_fake_data(self, tmpdir: str, config):
@@ -2235,7 +2235,7 @@ def inject_fake_data(self, tmpdir: str, config):
22352235
meta_folder.mkdir()
22362236
image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files]
22372237
image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2)
2238-
with open(meta_folder / f"{config['split']}{config['fold']}.txt", "w") as file:
2238+
with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file:
22392239
file.write("\n".join(image_ids_in_config) + "\n")
22402240

22412241
return len(image_ids_in_config)

torchvision/datasets/dtd.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import pathlib
3-
from typing import Optional, Callable, Union
3+
from typing import Optional, Callable
44

55
import PIL.Image
66

@@ -14,7 +14,13 @@ class DTD(VisionDataset):
1414
Args:
1515
root (string): Root directory of the dataset.
1616
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
17-
fold (string or int, optional): The dataset fold. Should be ``1 <= fold <= 10``. Defaults to ``1``.
17+
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
18+
19+
.. note::
20+
21+
The partition only changes which split each image belongs to. Thus, regardless of the selected
22+
partition, combining all splits will result in all images.
23+
1824
download (bool, optional): If True, downloads the dataset from the internet and
1925
puts it in root directory. If dataset is already downloaded, it is not
2026
downloaded again.
@@ -30,13 +36,18 @@ def __init__(
3036
self,
3137
root: str,
3238
split: str = "train",
33-
fold: Union[str, int] = 1,
39+
partition: int = 1,
3440
download: bool = True,
3541
transform: Optional[Callable] = None,
3642
target_transform: Optional[Callable] = None,
3743
) -> None:
3844
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
39-
self._fold = verify_str_arg(str(fold), "fold", [str(i) for i in range(1, 11)])
45+
if not isinstance(partition, int) and not (1 <= partition <= 10):
46+
raise ValueError(
47+
f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
48+
f"but got {partition} instead"
49+
)
50+
self._partition = partition
4051

4152
super().__init__(root, transform=transform, target_transform=target_transform)
4253
self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
@@ -52,7 +63,7 @@ def __init__(
5263

5364
self._image_files = []
5465
classes = []
55-
with open(self._meta_folder / f"{self._split}{self._fold}.txt") as file:
66+
with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
5667
for line in file:
5768
cls, name = line.strip().split("/")
5869
self._image_files.append(self._images_folder.joinpath(cls, name))
@@ -78,7 +89,7 @@ def __getitem__(self, idx):
7889
return image, label
7990

8091
def extra_repr(self) -> str:
81-
return f"split={self._split}, fold={self._fold}"
92+
return f"split={self._split}, partition={self._partition}"
8293

8394
def _check_exists(self) -> bool:
8495
return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

0 commit comments

Comments
 (0)