From 9bb7e1831b0cb11717a7f6dc67a9ba5fbe46d163 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 27 Feb 2023 13:29:19 +0100
Subject: [PATCH 1/4] add docstring for dataset wrapper (#7333)

Co-authored-by: Nicolas Hug
---
 docs/source/datasets.rst                   |  9 +++
 torchvision/datapoints/_dataset_wrapper.py | 65 +++++++++++++++++++++-
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 68c72e7af8c..35e5eaf2a9f 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -169,3 +169,12 @@ Base classes for custom datasets
     DatasetFolder
     ImageFolder
     VisionDataset
+
+Transforms v2
+-------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    wrap_dataset_for_transforms_v2
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index e358c83d9d1..87ce3ba93a1 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -14,8 +14,71 @@
 __all__ = ["wrap_dataset_for_transforms_v2"]
 
-# TODO: naming!
 def wrap_dataset_for_transforms_v2(dataset):
+    """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
+
+    .. v2betastatus:: wrap_dataset_for_transforms_v2 function
+
+    Example:
+        >>> dataset = torchvision.datasets.CocoDetection(...)
+        >>> dataset = wrap_dataset_for_transforms_v2(dataset)
+
+    .. note::
+
+        For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
+        configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error
+        prompting you to raise an issue with ``torchvision`` for a dataset or configuration that you need,
+        please do so.
+
+    The dataset samples are wrapped according to the description below.
+
+    Special cases:
+
+        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the
+          wrapper returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate
+          format), ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding
+          ``torchvision.datapoints``. The original keys are preserved.
+        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added
+          to the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
+          preserved.
+        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
+          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper
+          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and
+          wrap the data in the corresponding ``torchvision.datapoints``. The original keys are preserved.
+        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped
+          into a :class:`~torchvision.datapoints.Mask` datapoint.
+        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
+          :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced*
+          by a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint)
+          and ``"labels"``.
+        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to
+          ``XYXY`` coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+
+    Image classification datasets
+
+    This wrapper is a no-op for image classification datasets, since they were already fully supported by
+    :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.
+
+    Segmentation datasets
+
+    Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
+    :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
+    segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
+
+    Video classification datasets
+
+    Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
+    :class:`torch.Tensor` for the video and audio and an :class:`int` as the label. This wrapper wraps the video
+    into a :class:`~torchvision.datapoints.Video` while leaving the other items as is.
+
+    .. note::
+
+        Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
+        ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.
+
+    Args:
+        dataset: the dataset instance to wrap for compatibility with transforms v2.
+    """
     return VisionDatasetDatapointWrapper(dataset)

From 037c00620c95fb2eac14aa3d54b18726d8a50f54 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 2 Mar 2023 17:13:58 +0000
Subject: [PATCH 2/4] empty commit

From fd11675bbe8e6487543c32869e7e65dacbddc74f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 22 Mar 2023 14:49:46 +0100
Subject: [PATCH 3/4] move parameter sampling of RandomPhotometricDistort into _get_params (#7442)

---
 torchvision/transforms/v2/_color.py | 49 +++++++++++++----------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 4ad534c988b..7dd8eeae236 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -228,19 +228,22 @@ def __init__(
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         num_channels, *_ = query_chw(flat_inputs)
-        return dict(
-            zip(
-                ["brightness", "contrast1", "saturation", "hue", "contrast2"],
-                (torch.rand(5) < self.p).tolist(),
-            ),
-            contrast_before=bool(torch.rand(()) < 0.5),
-            channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
-        )
+        params: Dict[str, Any] = {
+            key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
+            for key, range in [
+                ("brightness_factor", self.brightness),
+                ("contrast_factor", self.contrast),
+                ("saturation_factor", self.saturation),
+                ("hue_factor", self.hue),
+            ]
+        }
+        params["contrast_before"] = bool(torch.rand(()) < 0.5)
+        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
+        return params
 
     def _permute_channels(
         self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-
         orig_inpt = inpt
         if isinstance(orig_inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
@@ -256,24 +259,16 @@ def _permute_channels(
     def _transform(
         self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-        if params["brightness"]:
-            inpt = F.adjust_brightness(
-                inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
-            )
-        if params["contrast1"] and params["contrast_before"]:
-            inpt = F.adjust_contrast(
-                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
-            )
-        if params["saturation"]:
-            inpt = F.adjust_saturation(
-                inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
-            )
-        if params["hue"]:
-            inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
-        if params["contrast2"] and not params["contrast_before"]:
-            inpt = F.adjust_contrast(
-                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
-            )
+        if params["brightness_factor"] is not None:
+            inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
+        if params["contrast_factor"] is not None and params["contrast_before"]:
+            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+        if params["saturation_factor"] is not None:
+            inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
+        if params["hue_factor"] is not None:
+            inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
+        if params["contrast_factor"] is not None and not params["contrast_before"]:
+            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
         if params["channel_permutation"] is not None:
             inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
         return inpt

From bf888d0ef4ee921421a1be409a4aed1f66b549a6 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 23 Mar 2023 10:26:09 +0100
Subject: [PATCH 4/4] fix test_detection_preset for ssd data augmentation (#7447)

---
 test/test_transforms_v2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index e893858aff3..bca8925c71d 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     elif data_augmentation == "ssd":
         t = [
             transforms.RandomPhotometricDistort(p=1),
-            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
+            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
             transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor,
             # param is True.
             # Note that the values below are probably specific to the random seed
             # set above (which is fine).
-            (True, "ssd"): 4,
+            (True, "ssd"): 5,
             (True, "ssdlite"): 4,
         }.get((sanitize, data_augmentation), num_boxes)
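
Usage note (not part of the patches above): a minimal sketch of how the wrapper from PATCH 1/4 is meant to be used. The COCO paths below are placeholders, and the assertion mirrors the ``CocoDetection`` special case described in the docstring.

    import torchvision
    from torchvision import datapoints
    from torchvision.datapoints import wrap_dataset_for_transforms_v2

    # Placeholder paths -- point these at a local COCO-style dataset.
    dataset = torchvision.datasets.CocoDetection("path/to/images", "path/to/annotations.json")
    dataset = wrap_dataset_for_transforms_v2(dataset)

    image, target = dataset[0]
    # Per the docstring, the wrapped target is a dict whose "boxes", "masks" and
    # "labels" entries are wrapped in the corresponding torchvision.datapoints.
    assert isinstance(target["boxes"], datapoints.BoundingBox)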
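PATCH 3/4 moves all randomness into ``_get_params`` so that ``_transform`` becomes deterministic given the sampled parameters. A standalone sketch of that sample-once/apply-deterministically split, using a hypothetical ``RandomBrightnessSketch`` class that is not part of torchvision:

    from typing import Any, Dict

    import torch

    class RandomBrightnessSketch:
        """Hypothetical transform illustrating the _get_params/_transform split."""

        def __init__(self, p: float = 0.5) -> None:
            self.p = p

        def _get_params(self) -> Dict[str, Any]:
            # All randomness lives here: the factor is either sampled or None.
            if torch.rand(1) < self.p:
                factor = float(torch.empty(1).uniform_(0.875, 1.125))
            else:
                factor = None
            return {"brightness_factor": factor}

        def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
            # Deterministic given params, so the sampled factor can be logged,
            # tested, or applied consistently across several inputs.
            if params["brightness_factor"] is not None:
                inpt = inpt * params["brightness_factor"]
            return inpt

        def __call__(self, inpt: torch.Tensor) -> torch.Tensor:
            return self._transform(inpt, self._get_params())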
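The ``fill`` argument in the SSD preset of PATCH 4/4 maps datapoint types to fill values via a ``defaultdict``. A small sketch of how that mapping behaves; the concrete values are taken from the test above:

    from collections import defaultdict

    from torchvision import datapoints

    # Images fall through to the (123.0, 117.0, 104.0) default, while masks are
    # padded with 0 so the padded area does not introduce a spurious label.
    fill = defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})

    assert fill[datapoints.Mask] == 0
    assert fill[datapoints.Image] == (123.0, 117.0, 104.0)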