Skip to content

Commit

Permalink
Merge branch 'main' into references/recipe_detection
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox authored Apr 1, 2022
2 parents 8455b37 + ec1c2a1 commit 162f76d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._color import ColorJitter
from ._color import ColorJitter, RandomPhotometricDistort
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
Resize,
Expand Down
68 changes: 67 additions & 1 deletion torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms import functional as _F

from ._utils import is_simple_tensor
from ._utils import is_simple_tensor, get_image_dimensions, query_image

T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)

Expand Down Expand Up @@ -120,5 +121,70 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:

for transform in params["image_transforms"]:
input = transform(input)
return input


class _RandomChannelShuffle(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(permutation=torch.randperm(num_channels))

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
return input

image = input
if isinstance(input, PIL.Image.Image):
image = _F.pil_to_tensor(image)

output = image[..., params["permutation"], :, :]

if isinstance(input, features.Image):
output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
elif isinstance(input, PIL.Image.Image):
output = _F.to_pil_image(output)

return output


class RandomPhotometricDistort(Transform):
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self._brightness = ColorJitter(brightness=brightness)
self._contrast = ColorJitter(contrast=contrast)
self._hue = ColorJitter(hue=hue)
self._saturation = ColorJitter(saturation=saturation)
self._channel_shuffle = _RandomChannelShuffle()
self.p = p

def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
torch.rand(6) < self.p,
),
contrast_before=torch.rand(()) < 0.5,
)

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if params["brightness"]:
input = self._brightness(input)
if params["contrast1"] and params["contrast_before"]:
input = self._contrast(input)
if params["saturation"]:
input = self._saturation(input)
if params["saturation"]:
input = self._saturation(input)
if params["contrast2"] and not params["contrast_before"]:
input = self._contrast(input)
if params["channel_shuffle"]:
input = self._channel_shuffle(input)
return input

0 comments on commit 162f76d

Please sign in to comment.