
Commit 1490767

remove or privatize functionality in features / datasets / transforms
1 parent f795349 commit 1490767

File tree

23 files changed: +46 -50 lines

test/builtin_dataset_mocks.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
-from torchvision.prototype.datasets._api import find
+from torchvision.prototype.datasets._api import _find
 from torchvision.prototype.utils._internal import sequence_to_str
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
@@ -31,7 +31,7 @@
 
 class DatasetMock:
     def __init__(self, name, mock_data_fn):
-        self.dataset = find(name)
+        self.dataset = _find(name)
         self.info = self.dataset.info
         self.name = self.info.name

torchvision/prototype/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,5 +11,5 @@
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, list_datasets, info, load  # usort: skip
+from ._api import list_datasets, load  # usort: skip
 from ._folder import from_data_folder, from_image_folder
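
With register and info no longer re-exported, the package-level surface of torchvision.prototype.datasets narrows to list_datasets, load, home, and the folder helpers. A minimal usage sketch of what remains; the dataset name "caltech101" is assumed here purely for illustration:

    from torchvision.prototype import datasets

    print(datasets.list_datasets())         # names of the registered built-in datasets
    datapipe = datasets.load("caltech101")  # returns an IterDataPipe of sample dicts
    sample = next(iter(datapipe))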

torchvision/prototype/datasets/_api.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())
 
 
-def find(name: str) -> Dataset:
+def _find(name: str) -> Dataset:
     name = name.lower()
     try:
         return DATASETS[name]
@@ -42,7 +42,7 @@ def find(name: str) -> Dataset:
 
 
 def info(name: str) -> DatasetInfo:
-    return find(name).info
+    return _find(name).info
 
 
 def load(
@@ -51,7 +51,7 @@
     skip_integrity_check: bool = False,
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
-    dataset = find(name)
+    dataset = _find(name)
 
     config = dataset.info.make_config(**options)
     root = os.path.join(home(), dataset.name)
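
Since find() is now private, the public entry points left in _api are info() and load(), both delegating to _find() as shown above. Code inside the repository that needs the raw Dataset object (the test mocks, the benchmark script, and the category-file generator in this same commit) switches to importing _find directly. An illustrative sketch of that internal pattern, with "coco" assumed as a registration name:

    from torchvision.prototype.datasets._api import _find

    dataset = _find("coco")   # raw Dataset object, not a datapipe
    print(dataset.info.name)  # DatasetInfo is still reachable via .info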

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
     OnlineResource,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
-from torchvision.prototype.features import Label, BoundingBox, Feature, EncodedImage
+from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
 
 
 class Caltech101(Dataset):
@@ -95,7 +95,7 @@ def _prepare_sample(
             bounding_box=BoundingBox(
                 ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size
             ),
-            contour=Feature(ann["obj_contour"].T),
+            contour=_Feature(ann["obj_contour"].T),
         )
 
     def _make_datapipe(

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import EncodedImage, Feature, Label, BoundingBox
+from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
 
 
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -140,7 +140,7 @@ def _prepare_sample(
                 image_size=image.image_size,
             ),
             landmarks={
-                landmark: Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                 for landmark in {key[:-2] for key in landmarks.keys()}
             },
         )

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,7 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import BoundingBox, Label, Feature, EncodedImage
+from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
 from torchvision.prototype.utils._internal import FrozenMapping
 
 
@@ -95,16 +95,16 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
         labels = [ann["category_id"] for ann in anns]
         return dict(
             # TODO: create a segmentation feature
-            segmentations=Feature(
+            segmentations=_Feature(
                 torch.stack(
                     [
                         self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
                         for ann in anns
                     ]
                 )
             ),
-            areas=Feature([ann["area"] for ann in anns]),
-            crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=_Feature([ann["area"] for ann in anns]),
+            crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
     path_comparator,
     path_accessor,
 )
-from torchvision.prototype.features import Label, BoundingBox, Feature, EncodedImage
+from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
 
 csv.register_dialect("cub200", delimiter=" ")
 
@@ -131,7 +131,7 @@ def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size:
                 format="xyxy",
                 image_size=image_size,
             ),
-            segmentation=Feature(content["seg"]),
+            segmentation=_Feature(content["seg"]),
         )
 
     def _prepare_sample(

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import Feature, EncodedImage
+from torchvision.prototype.features import _Feature, EncodedImage
 
 
 class SBD(Dataset):
@@ -81,8 +81,8 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
             # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-            boundaries=Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
-            segmentation=Feature(anns["Segmentation"].item()),
+            boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
+            segmentation=_Feature(anns["Segmentation"].item()),
         )
 
     def _make_datapipe(

torchvision/prototype/datasets/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def __init__(
         self.name = name
         self.variant = variant
 
-        self.new_raw_dataset = new_datasets._api.find(name)
+        self.new_raw_dataset = new_datasets._api._find(name)
         self.legacy_cls = legacy_cls or self._find_legacy_cls()
 
         if new_config is None:

torchvision/prototype/datasets/generate_category_files.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import sys
 
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import find
+from torchvision.prototype.datasets._api import _find
 from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
 
 
@@ -18,7 +18,7 @@ def main(*names, force=False):
         if path.exists() and not force:
             continue
 
-        dataset = find(name)
+        dataset = _find(name)
         try:
             categories = dataset._generate_categories(home / name)
         except NotImplementedError:
