Skip to content

Commit

Permalink
[Cherry-pick for 0.15.2] move parameter sampling of RandomPhotometric…
Browse files Browse the repository at this point in the history
…Distort into _get_params (#7444)

Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
NicolasHug and pmeier authored Mar 23, 2023
1 parent 42759b1 commit e872006
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 29 deletions.
4 changes: 2 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
Expand Down Expand Up @@ -1935,7 +1935,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
(True, "ssd"): 4,
(True, "ssd"): 5,
(True, "ssdlite"): 4,
}.get((sanitize, data_augmentation), num_boxes)

Expand Down
49 changes: 22 additions & 27 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,19 +228,22 @@ def __init__(

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"],
(torch.rand(5) < self.p).tolist(),
),
contrast_before=bool(torch.rand(()) < 0.5),
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
params: Dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params

def _permute_channels(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]:

orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
Expand All @@ -256,24 +259,16 @@ def _permute_channels(
def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
)
if params["contrast1"] and params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["saturation"]:
inpt = F.adjust_saturation(
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
)
if params["hue"]:
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
if params["contrast2"] and not params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["brightness_factor"] is not None:
inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
if params["contrast_factor"] is not None and params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["saturation_factor"] is not None:
inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
if params["hue_factor"] is not None:
inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt
Expand Down

0 comments on commit e872006

Please sign in to comment.