11import os
22import pathlib
3- from typing import Optional , Callable , Union
3+ from typing import Optional , Callable
44
55import PIL .Image
66
@@ -14,7 +14,13 @@ class DTD(VisionDataset):
1414 Args:
1515 root (string): Root directory of the dataset.
1616 split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
17- fold (string or int, optional): The dataset fold. Should be ``1 <= fold <= 10``. Defaults to ``1``.
17+ partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
18+
19+ .. note::
20+
21+ The partition only changes which split each image belongs to. Thus, regardless of the selected
22+ partition, combining all splits will result in all images.
23+
1824 download (bool, optional): If True, downloads the dataset from the internet and
1925 puts it in root directory. If dataset is already downloaded, it is not
2026 downloaded again.
@@ -30,13 +36,18 @@ def __init__(
3036 self ,
3137 root : str ,
3238 split : str = "train" ,
33- fold : Union [ str , int ] = 1 ,
39+ partition : int = 1 ,
3440 download : bool = True ,
3541 transform : Optional [Callable ] = None ,
3642 target_transform : Optional [Callable ] = None ,
3743 ) -> None :
3844 self ._split = verify_str_arg (split , "split" , ("train" , "val" , "test" ))
39- self ._fold = verify_str_arg (str (fold ), "fold" , [str (i ) for i in range (1 , 11 )])
45+ if not isinstance (partition , int ) and not (1 <= partition <= 10 ):
46+ raise ValueError (
47+ f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
48+ f"but got { partition } instead"
49+ )
50+ self ._partition = partition
4051
4152 super ().__init__ (root , transform = transform , target_transform = target_transform )
4253 self ._base_folder = pathlib .Path (self .root ) / type (self ).__name__ .lower ()
@@ -52,7 +63,7 @@ def __init__(
5263
5364 self ._image_files = []
5465 classes = []
55- with open (self ._meta_folder / f"{ self ._split } { self ._fold } .txt" ) as file :
66+ with open (self ._meta_folder / f"{ self ._split } { self ._partition } .txt" ) as file :
5667 for line in file :
5768 cls , name = line .strip ().split ("/" )
5869 self ._image_files .append (self ._images_folder .joinpath (cls , name ))
@@ -78,7 +89,7 @@ def __getitem__(self, idx):
7889 return image , label
7990
8091 def extra_repr (self ) -> str :
81- return f"split={ self ._split } , fold ={ self ._fold } "
92+ return f"split={ self ._split } , partition ={ self ._partition } "
8293
8394 def _check_exists (self ) -> bool :
8495 return os .path .exists (self ._data_folder ) and os .path .isdir (self ._data_folder )
0 commit comments