Skip to content

Commit d2583d4

Browse files
authored
Add asserts & more tests for RandomScale transforms
1 parent f862df8 commit d2583d4

File tree

2 files changed

+524
-72
lines changed

2 files changed

+524
-72
lines changed

captum/optim/_param/image/transforms.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,10 @@ class RandomScale(nn.Module):
358358
"scale",
359359
"mode",
360360
"align_corners",
361-
"has_align_corners",
361+
"_has_align_corners",
362362
"recompute_scale_factor",
363-
"has_recompute_scale_factor",
364-
"is_distribution",
363+
"_has_recompute_scale_factor",
364+
"_is_distribution",
365365
]
366366

367367
def __init__(
@@ -390,8 +390,9 @@ def __init__(
390390
super().__init__()
391391
if isinstance(scale, torch.distributions.distribution.Distribution):
392392
# Distributions are not supported by TorchScript / JIT yet
393+
assert scale.batch_shape == torch.Size([])
393394
self.scale_distribution = scale
394-
self.is_distribution = True
395+
self._is_distribution = True
395396
self.scale = []
396397
else:
397398
assert hasattr(scale, "__iter__")
@@ -400,33 +401,12 @@ def __init__(
400401
scale = scale.tolist()
401402
assert len(scale) > 0
402403
self.scale = [float(s) for s in scale]
403-
self.is_distribution = False
404+
self._is_distribution = False
404405
self.mode = mode
405406
self.align_corners = align_corners if mode not in ["nearest", "area"] else None
406407
self.recompute_scale_factor = recompute_scale_factor
407-
self.has_align_corners = torch.__version__ >= "1.3.0"
408-
self.has_recompute_scale_factor = torch.__version__ >= "1.6.0"
409-
410-
def _get_scale_mat(
411-
self,
412-
m: float,
413-
device: torch.device,
414-
dtype: torch.dtype,
415-
) -> torch.Tensor:
416-
"""
417-
Create a scale matrix tensor.
418-
419-
Args:
420-
421-
m (float): The scale value to use.
422-
423-
Returns:
424-
**scale_mat** (torch.Tensor): A scale matrix.
425-
"""
426-
scale_mat = torch.tensor(
427-
[[m, 0.0, 0.0], [0.0, m, 0.0]], device=device, dtype=dtype
428-
)
429-
return scale_mat
408+
self._has_align_corners = torch.__version__ >= "1.3.0"
409+
self._has_recompute_scale_factor = torch.__version__ >= "1.6.0"
430410

431411
def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
432412
"""
@@ -440,8 +420,8 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
440420
Returns:
441421
**x** (torch.Tensor): A scaled NCHW image tensor.
442422
"""
443-
if self.has_align_corners:
444-
if self.has_recompute_scale_factor:
423+
if self._has_align_corners:
424+
if self._has_recompute_scale_factor:
445425
x = F.interpolate(
446426
x,
447427
scale_factor=scale,
@@ -472,8 +452,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
472452
**x** (torch.Tensor): A randomly scaled NCHW image *tensor*.
473453
"""
474454
assert x.dim() == 4
475-
if self.is_distribution:
476-
scale = self.scale_distribution.sample().item()
455+
if self._is_distribution:
456+
scale = float(self.scale_distribution.sample().item())
477457
else:
478458
n = int(
479459
torch.randint(
@@ -508,8 +488,8 @@ class RandomScaleAffine(nn.Module):
508488
"mode",
509489
"padding_mode",
510490
"align_corners",
511-
"has_align_corners",
512-
"is_distribution",
491+
"_has_align_corners",
492+
"_is_distribution",
513493
]
514494

515495
def __init__(
@@ -539,8 +519,9 @@ def __init__(
539519
super().__init__()
540520
if isinstance(scale, torch.distributions.distribution.Distribution):
541521
# Distributions are not supported by TorchScript / JIT yet
522+
assert scale.batch_shape == torch.Size([])
542523
self.scale_distribution = scale
543-
self.is_distribution = True
524+
self._is_distribution = True
544525
self.scale = []
545526
else:
546527
assert hasattr(scale, "__iter__")
@@ -549,11 +530,11 @@ def __init__(
549530
scale = scale.tolist()
550531
assert len(scale) > 0
551532
self.scale = [float(s) for s in scale]
552-
self.is_distribution = False
533+
self._is_distribution = False
553534
self.mode = mode
554535
self.padding_mode = padding_mode
555536
self.align_corners = align_corners
556-
self.has_align_corners = torch.__version__ >= "1.3.0"
537+
self._has_align_corners = torch.__version__ >= "1.3.0"
557538

558539
def _get_scale_mat(
559540
self,
@@ -591,7 +572,7 @@ def _scale_tensor(self, x: torch.Tensor, scale: float) -> torch.Tensor:
591572
scale_matrix = self._get_scale_mat(scale, x.device, x.dtype)[None, ...].repeat(
592573
x.shape[0], 1, 1
593574
)
594-
if self.has_align_corners:
575+
if self._has_align_corners:
595576
# Pass align_corners explicitly for torch >= 1.3.0
596577
grid = F.affine_grid(
597578
scale_matrix, x.size(), align_corners=self.align_corners
@@ -620,8 +601,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
620601
**x** (torch.Tensor): A randomly scaled NCHW image *tensor*.
621602
"""
622603
assert x.dim() == 4
623-
if self.is_distribution:
624-
scale = self.scale_distribution.sample().item()
604+
if self._is_distribution:
605+
scale = float(self.scale_distribution.sample().item())
625606
else:
626607
n = int(
627608
torch.randint(
@@ -1021,7 +1002,7 @@ class TransformationRobustness(nn.Module):
10211002

10221003
def __init__(
10231004
self,
1024-
padding_transform: Optional[nn.Module] = None,
1005+
padding_transform: Optional[nn.Module] = nn.ConstantPad2d(2, value=0.5),
10251006
translate: Optional[Union[int, List[int]]] = [4] * 10,
10261007
scale: Optional[NumSeqOrTensorOrProbDistType] = [
10271008
0.995 ** n for n in range(-5, 80)
@@ -1039,7 +1020,7 @@ def __init__(
10391020
10401021
padding_transform (nn.Module, optional): A padding module instance. No
10411022
padding will be applied before transforms if set to None.
1042-
Default: None
1023+
Default: nn.ConstantPad2d(2, value=0.5)
10431024
translate (int or list of int, optional): The max horizontal and vertical
10441025
translation to use for each jitter transform.
10451026
Default: [4] * 10

0 commit comments

Comments (0)