Skip to content

Commit 6e18cea

Browse files
RichienbNicolasHug
andauthored
Add GaussianNoise transforms (#8381)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent b0f9f7b commit 6e18cea

File tree

6 files changed

+147
-2
lines changed

6 files changed

+147
-2
lines changed

docs/source/transforms.rst

+2
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ Color
350350
v2.RGB
351351
v2.RandomGrayscale
352352
v2.GaussianBlur
353+
v2.GaussianNoise
353354
v2.RandomInvert
354355
v2.RandomPosterize
355356
v2.RandomSolarize
@@ -368,6 +369,7 @@ Functionals
368369
v2.functional.grayscale_to_rgb
369370
v2.functional.to_grayscale
370371
v2.functional.gaussian_blur
372+
v2.functional.gaussian_noise
371373
v2.functional.invert
372374
v2.functional.posterize
373375
v2.functional.solarize

test/test_transforms_v2.py

+76-2
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
111111

112112
input = input.as_subclass(torch.Tensor)
113113
with ignore_jit_no_profile_information_warning():
114-
actual = kernel_scripted(input, *args, **kwargs)
115-
expected = kernel(input, *args, **kwargs)
114+
with freeze_rng_state():
115+
actual = kernel_scripted(input, *args, **kwargs)
116+
with freeze_rng_state():
117+
expected = kernel(input, *args, **kwargs)
116118

117119
assert_close(actual, expected, rtol=rtol, atol=atol)
118120

@@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
32383240
torch.testing.assert_close(actual, expected, rtol=0, atol=1)
32393241

32403242

3243+
class TestGaussianNoise:
3244+
@pytest.mark.parametrize(
3245+
"make_input",
3246+
[make_image_tensor, make_image, make_video],
3247+
)
3248+
def test_kernel(self, make_input):
3249+
check_kernel(
3250+
F.gaussian_noise,
3251+
make_input(dtype=torch.float32),
3252+
# This cannot pass because the noise on a batch in not per-image
3253+
check_batched_vs_unbatched=False,
3254+
)
3255+
3256+
@pytest.mark.parametrize(
3257+
"make_input",
3258+
[make_image_tensor, make_image, make_video],
3259+
)
3260+
def test_functional(self, make_input):
3261+
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
3262+
3263+
@pytest.mark.parametrize(
3264+
("kernel", "input_type"),
3265+
[
3266+
(F.gaussian_noise, torch.Tensor),
3267+
(F.gaussian_noise_image, tv_tensors.Image),
3268+
(F.gaussian_noise_video, tv_tensors.Video),
3269+
],
3270+
)
3271+
def test_functional_signature(self, kernel, input_type):
3272+
check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)
3273+
3274+
@pytest.mark.parametrize(
3275+
"make_input",
3276+
[make_image_tensor, make_image, make_video],
3277+
)
3278+
def test_transform(self, make_input):
3279+
def adapter(_, input, __):
3280+
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
3281+
# Same for PIL images
3282+
for key, value in input.items():
3283+
if isinstance(value, torch.Tensor) and not value.is_floating_point():
3284+
input[key] = value.to(torch.float32)
3285+
if isinstance(value, PIL.Image.Image):
3286+
input[key] = F.pil_to_tensor(value).to(torch.float32)
3287+
return input
3288+
3289+
check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)
3290+
3291+
def test_bad_input(self):
3292+
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
3293+
F.gaussian_noise(make_image_pil())
3294+
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
3295+
F.gaussian_noise(make_image(dtype=torch.uint8))
3296+
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
3297+
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)
3298+
3299+
def test_clip(self):
3300+
img = make_image(dtype=torch.float32)
3301+
3302+
out = F.gaussian_noise(img, mean=100, clip=False)
3303+
assert out.min() > 50
3304+
3305+
out = F.gaussian_noise(img, mean=100, clip=True)
3306+
assert (out == 1).all()
3307+
3308+
out = F.gaussian_noise(img, mean=-100, clip=False)
3309+
assert out.min() < -50
3310+
3311+
out = F.gaussian_noise(img, mean=-100, clip=True)
3312+
assert (out == 0).all()
3313+
3314+
32413315
class TestAutoAugmentTransforms:
32423316
# These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
32433317
# It's typically very hard to test the effect on some parameters without heavy mocking logic.

torchvision/transforms/v2/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ._misc import (
4646
ConvertImageDtype,
4747
GaussianBlur,
48+
GaussianNoise,
4849
Identity,
4950
Lambda,
5051
LinearTransformation,

torchvision/transforms/v2/_misc.py

+27
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,33 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
205205
return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
206206

207207

208+
class GaussianNoise(Transform):
209+
"""Add gaussian noise to images or videos.
210+
211+
The input tensor is expected to be in [..., 1 or 3, H, W] format,
212+
where ... means it can have an arbitrary number of leading dimensions.
213+
Each image or frame in a batch will be transformed independently i.e. the
214+
noise added to each image will be different.
215+
216+
The input tensor is also expected to be of float dtype in ``[0, 1]``.
217+
This transform does not support PIL images.
218+
219+
Args:
220+
mean (float): Mean of the sampled normal distribution. Default is 0.
221+
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
222+
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
223+
"""
224+
225+
def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:
226+
super().__init__()
227+
self.mean = mean
228+
self.sigma = sigma
229+
self.clip = clip
230+
231+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
232+
return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip)
233+
234+
208235
class ToDtype(Transform):
209236
"""Converts the input to a specific dtype, optionally scaling the values for images or videos.
210237

torchvision/transforms/v2/functional/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@
136136
gaussian_blur,
137137
gaussian_blur_image,
138138
gaussian_blur_video,
139+
gaussian_noise,
140+
gaussian_noise_image,
141+
gaussian_noise_video,
139142
normalize,
140143
normalize_image,
141144
normalize_video,

torchvision/transforms/v2/functional/_misc.py

+38
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,44 @@ def gaussian_blur_video(
181181
return gaussian_blur_image(video, kernel_size, sigma)
182182

183183

184+
def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
185+
"""See :class:`~torchvision.transforms.v2.GaussianNoise`"""
186+
if torch.jit.is_scripting():
187+
return gaussian_noise_image(inpt, mean=mean, sigma=sigma)
188+
189+
_log_api_usage_once(gaussian_noise)
190+
191+
kernel = _get_kernel(gaussian_noise, type(inpt))
192+
return kernel(inpt, mean=mean, sigma=sigma, clip=clip)
193+
194+
195+
@_register_kernel_internal(gaussian_noise, torch.Tensor)
196+
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
197+
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
198+
if not image.is_floating_point():
199+
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
200+
if sigma < 0:
201+
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")
202+
203+
noise = mean + torch.randn_like(image) * sigma
204+
out = image + noise
205+
if clip:
206+
out = torch.clamp(out, 0, 1)
207+
return out
208+
209+
210+
@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
211+
def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
212+
return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip)
213+
214+
215+
@_register_kernel_internal(gaussian_noise, PIL.Image.Image)
216+
def _gaussian_noise_pil(
217+
video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True
218+
) -> PIL.Image.Image:
219+
raise ValueError("Gaussian Noise is not implemented for PIL images.")
220+
221+
184222
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
185223
"""See :func:`~torchvision.transforms.v2.ToDtype` for details."""
186224
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)