diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e893858aff3..bca8925c71d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): elif data_augmentation == "ssd": t = [ transforms.RandomPhotometricDistort(p=1), - transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})), + transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1), transforms.RandomIoUCrop(), transforms.RandomHorizontalFlip(p=1), to_tensor, @@ -1935,7 +1935,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): # param is True. # Note that the values below are probably specific to the random seed # set above (which is fine). - (True, "ssd"): 4, + (True, "ssd"): 5, (True, "ssdlite"): 4, }.get((sanitize, data_augmentation), num_boxes) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 4ad534c988b..7dd8eeae236 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -228,19 +228,22 @@ def __init__( def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) - return dict( - zip( - ["brightness", "contrast1", "saturation", "hue", "contrast2"], - (torch.rand(5) < self.p).tolist(), - ), - contrast_before=bool(torch.rand(()) < 0.5), - channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, - ) + params: Dict[str, Any] = { + key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None + for key, range in [ + ("brightness_factor", self.brightness), + ("contrast_factor", self.contrast), + ("saturation_factor", self.saturation), + ("hue_factor", self.hue), + ] + } + params["contrast_before"] = bool(torch.rand(()) < 0.5) + params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None + return params def _permute_channels( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor ) -> Union[datapoints._ImageType, datapoints._VideoType]: - orig_inpt = inpt if isinstance(orig_inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) @@ -256,24 +259,16 @@ def _permute_channels( def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: - if params["brightness"]: - inpt = F.adjust_brightness( - inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) - ) - if params["contrast1"] and params["contrast_before"]: - inpt = F.adjust_contrast( - inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) - ) - if params["saturation"]: - inpt = F.adjust_saturation( - inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1]) - ) - if params["hue"]: - inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1])) - if params["contrast2"] and not params["contrast_before"]: - inpt = F.adjust_contrast( - inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) - ) + if params["brightness_factor"] is not None: + inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"]) + if params["contrast_factor"] is not None and params["contrast_before"]: + inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) + if params["saturation_factor"] is not None: + inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"]) + if params["hue_factor"] is not None: + inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"]) + if params["contrast_factor"] is not None and not params["contrast_before"]: + inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) return inpt