24 changes: 20 additions & 4 deletions captum/optim/_param/image/images.py
@@ -478,23 +478,38 @@ class NaturalImage(ImageParameterization):
or rescaling to [0,255], it can perform those steps with the provided transforms or
inside its computation.
For example, our GoogleNet factory function has a `transform_input=True` argument.

Arguments:
size (list of int): The height and width to use for the nn.Parameter tensor.
channels (int): The number of channels to use when creating the
nn.Parameter tensor.
batch (int): The batch size to use when creating the
nn.Parameter tensor, or when stacking init images.
parameterization (ImageParameterization): An image parameterization class.
squash_func (SquashFunc): The squash function to use after
color recorrelation. A function or lambda function.
decorrelation_module (nn.Module): A ToRGB instance.
decorrelate_init (bool): Whether or not to apply color decorrelation to the
init tensor input.
"""

def __init__(
self,
size: InitSize = None,
channels: int = 3,
batch: int = 1,
parameterization: ImageParameterization = FFTImage,
init: Optional[torch.Tensor] = None,
decorrelate_init: bool = True,
parameterization: ImageParameterization = FFTImage,
squash_func: Optional[SquashFunc] = None,
decorrelation_module: nn.Module = ToRGB(transform_matrix="klt"),
decorrelate_init: bool = True,
) -> None:
super().__init__()
self.decorrelate = ToRGB(transform_name="klt")
self.decorrelate = decorrelation_module
if init is not None:
assert init.dim() == 3 or init.dim() == 4
if decorrelate_init:
assert self.decorrelate is not None
init = (
init.refine_names("B", "C", "H", "W")
if init.dim() == 4
@@ -513,6 +528,7 @@ def __init__(

def forward(self) -> torch.Tensor:
image = self.parameterization()
image = self.decorrelate(image)
if self.decorrelate is not None:
image = self.decorrelate(image)
image = image.rename(None) # TODO: the world is not yet ready
return CudaImageTensor(self.squash_func(image))
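To make the reworked NaturalImage signature concrete, here is a minimal usage sketch. The import paths follow the file paths in this diff; the size value is illustrative, and passing decorrelation_module=None to disable decorrelation is an assumption based on the `if self.decorrelate is not None` check in forward() and the review discussion below, not a documented API.

```python
import torch
from captum.optim._param.image.images import NaturalImage
from captum.optim._param.image.transform import ToRGB

# Default: FFT parameterization with the built-in KLT color decorrelation.
image = NaturalImage(size=(224, 224))

# Swap in a custom 3x3 decorrelation matrix (identity here, purely for illustration).
custom = NaturalImage(
    size=(224, 224),
    decorrelation_module=ToRGB(transform_matrix=torch.eye(3)),
)

# Assumed: None disables decorrelation entirely (Lucid / Lucent style).
raw = NaturalImage(size=(224, 224), decorrelation_module=None)

out = image()  # a squashed, recorrelated image tensor
```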
92 changes: 53 additions & 39 deletions captum/optim/_param/image/transform.py
@@ -1,6 +1,6 @@
import math
import numbers
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
@@ -73,15 +73,21 @@ def i1i2i3_transform() -> torch.Tensor:
]
return torch.Tensor(i1i2i3_matrix)

def __init__(self, transform_name: str = "klt") -> None:
def __init__(self, transform_matrix: Union[str, torch.Tensor] = "klt") -> None:
Contributor: In the description of the PR it is mentioned that we can disable transform_matrix in NaturalImage - do you mean that we can default to it?

Contributor Author: The ToRGB transform can optionally be turned off if one does not wish to use color decorrelation (like in Lucid & Lucent). The transform_matrix argument (previously called transform_name) can be the matrix created by the dataset KLT function located in _utils/image/dataset, or it can come from somewhere else. If you don't specify a matrix tensor, then the default KLT or i1i2i3 one can be used.

super().__init__()

if transform_name == "klt":
assert isinstance(transform_matrix, str) or torch.is_tensor(transform_matrix)
if torch.is_tensor(transform_matrix):
assert list(transform_matrix.shape) == [3, 3]
self.register_buffer("transform", transform_matrix)
elif transform_matrix == "klt":
self.register_buffer("transform", ToRGB.klt_transform())
elif transform_name == "i1i2i3":
elif transform_matrix == "i1i2i3":
self.register_buffer("transform", ToRGB.i1i2i3_transform())
else:
raise ValueError("transform_name has to be either 'klt' or 'i1i2i3'")
raise ValueError(
"transform_matrix has to be either 'klt', 'i1i2i3',"
+ " or a matrix tensor."
)

def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
assert x.dim() == 3 or x.dim() == 4
@@ -118,57 +124,65 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:

class CenterCrop(torch.nn.Module):
"""
Center crop the specified amount of pixels from the edges.
Center crop a specified amount from a tensor.
Arguments:
size (int or sequence of int): Number of pixels to center crop away.
pixels_from_edges (bool): Whether to treat crop size values as the number
of pixels from the tensor's edge, or an exact shape in the center.
"""

def __init__(self, size: TransformSize = 0) -> None:
def __init__(self, size: TransformSize = 0, pixels_from_edges: bool = True) -> None:
super(CenterCrop, self).__init__()
if type(size) is list or type(size) is tuple:
assert len(size) == 2, (
"CenterCrop requires a single crop value or a tuple of (height,width)"
+ "in pixels for cropping."
)
self.crop_val = size
else:
self.crop_val = [size] * 2
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges

def forward(self, input: torch.Tensor) -> torch.Tensor:
assert (
input.dim() == 3 or input.dim() == 4
), "Input to CenterCrop must be 3D or 4D"
if input.dim() == 4:
h, w = input.size(2), input.size(3)
elif input.dim() == 3:
h, w = input.size(1), input.size(2)
h_crop = h - self.crop_val[0]
w_crop = w - self.crop_val[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
return input[..., sh : sh + h_crop, sw : sw + w_crop]
"""
Center crop an input.
Arguments:
input (torch.Tensor): Input to center crop.
Returns:
tensor (torch.Tensor): A center cropped tensor.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)

def center_crop_shape(input: torch.Tensor, output_size: List[int]) -> torch.Tensor:

def center_crop(
input: torch.Tensor, crop_vals: TransformSize, pixels_from_edges: bool = True
) -> torch.Tensor:
"""
Crop NCHW & CHW outputs by specifying the desired output shape.
Center crop a specified amount from a tensor.
Arguments:
input (tensor): A CHW or NCHW image tensor to center crop.
crop_vals (int or sequence of int): Number of pixels to center crop away.
pixels_from_edges (bool): Whether to treat crop size values as the number
of pixels from the tensor's edge, or an exact shape in the center.
Returns:
*tensor*: A center cropped tensor.
"""

assert input.dim() == 4 or input.dim() == 3
output_size = [output_size] if not hasattr(output_size, "__iter__") else output_size
assert len(output_size) == 1 or len(output_size) == 2
output_size = output_size * 2 if len(output_size) == 1 else output_size
assert input.dim() == 3 or input.dim() == 4
crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
assert len(crop_vals) == 1 or len(crop_vals) == 2
crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals

if input.dim() == 4:
h, w = input.size(2), input.size(3)
if input.dim() == 3:
h, w = input.size(1), input.size(2)

h_crop = h - int(round((h - output_size[0]) / 2.0))
w_crop = w - int(round((w - output_size[1]) / 2.0))

return input[
..., h_crop - output_size[0] : h_crop, w_crop - output_size[1] : w_crop
]
if pixels_from_edges:
h_crop = h - crop_vals[0]
w_crop = w - crop_vals[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x


def rand_select(transform_values: TransformValList) -> TransformVal:
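A quick reviewer's sketch of the two crop modes introduced above, plus the widened ToRGB constructor; the expected shapes follow directly from the arithmetic in center_crop:

```python
import torch
from captum.optim._param.image.transform import ToRGB, center_crop

x = torch.ones(1, 3, 10, 10)

# pixels_from_edges=True: crop_vals counts pixels removed around the edges,
# so a value of 4 leaves 10 - 4 = 6 pixels per side.
assert center_crop(x, 4, pixels_from_edges=True).shape == (1, 3, 6, 6)

# pixels_from_edges=False: crop_vals is the exact output shape,
# matching the removed center_crop_shape behavior.
assert center_crop(x, 4, pixels_from_edges=False).shape == (1, 3, 4, 4)

# ToRGB now also accepts a user-supplied 3x3 matrix tensor.
to_rgb = ToRGB(transform_matrix=torch.eye(3))
```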
8 changes: 4 additions & 4 deletions captum/optim/_utils/circuits.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn

from captum.optim._param.image.transform import center_crop_shape
from captum.optim._param.image.transform import center_crop
from captum.optim._utils.models import collect_activations
from captum.optim._utils.typing import ModelInputType, TransformSize

@@ -56,8 +56,8 @@ def get_expanded_weights(
retain_graph=True,
)[0]
A.append(x.squeeze(0))
exapnded_weights = torch.stack(A, 0)
expanded_weights = torch.stack(A, 0)

if crop_shape is not None:
exapnded_weights = center_crop_shape(exapnded_weights, crop_shape)
return exapnded_weights
expanded_weights = center_crop(expanded_weights, crop_shape, False)
return expanded_weights
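A sanity sketch of why the trailing False keeps this call behavior-preserving: with pixels_from_edges=False, center_crop treats crop_shape as the exact output shape, exactly as the removed center_crop_shape did. The weight shapes here are illustrative:

```python
import torch
from captum.optim._param.image.transform import center_crop

expanded_weights = torch.randn(8, 16, 9, 9)  # stands in for the stacked gradients
crop_shape = (5, 5)

# Equivalent to the old center_crop_shape(expanded_weights, crop_shape).
cropped = center_crop(expanded_weights, crop_shape, False)
assert cropped.shape == (8, 16, 5, 5)
```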
2 changes: 1 addition & 1 deletion captum/optim/_utils/typing.py
@@ -37,5 +37,5 @@ def cleanup(self):
SquashFunc = Callable[[Tensor], Tensor]
TransformValList = Union[Sequence[int], Sequence[float], Tensor]
TransformVal = Union[int, float, Tensor]
TransformSize = Union[List[int], Tuple[int], int]
TransformSize = Union[List[int], Tuple[int], Tuple[int, int], int]
ModelInputType = Union[Tuple[Tensor], Tensor]
2 changes: 1 addition & 1 deletion tests/optim/helpers/numpy_image.py
@@ -12,7 +12,7 @@ def setup_batch(x: np.ndarray, batch: int = 1, dim: int = 3) -> np.ndarray:
return x


class FFTImage(object):
class FFTImage:
"""Parameterize an image using inverse real 2D FFT"""

def __init__(
102 changes: 73 additions & 29 deletions tests/optim/helpers/numpy_transforms.py
@@ -1,7 +1,9 @@
from typing import Optional
from typing import List, Optional, Tuple, Union, cast

import numpy as np

from captum.optim._utils.typing import TransformSize


class BlendAlpha(object):
"""
@@ -53,33 +55,67 @@ def jitter(self, x: np.ndarray) -> np.ndarray:
return self.translate_array(x, insets)


class CenterCrop(object):
class CenterCrop:
"""
NumPy version of the CenterCrop transform
Center crop a specified amount from an array.
Arguments:
size (int or sequence of int): Number of pixels to center crop away.
pixels_from_edges (bool): Whether to treat crop size values as the number
of pixels from the tensor's edge, or an exact shape in the center.
"""

def __init__(self, size=0) -> None:
super().__init__()
if type(size) is list or type(size) is tuple:
assert len(size) == 2, (
"CenterCrop requires a single crop value or a tuple of (height,width)"
+ "in pixels for cropping."
)
self.crop_val = size
else:
self.crop_val = [size] * 2
assert len(self.crop_val) == 2

def crop(self, input: np.ndarray) -> np.ndarray:
assert input.ndim == 3 or input.ndim == 4
if input.ndim == 4:
h, w = input.shape[2], input.shape[3]
elif input.ndim == 3:
h, w = input.shape[1], input.shape[2]
h_crop = h - self.crop_val[0]
w_crop = w - self.crop_val[1]
def __init__(self, size: TransformSize = 0, pixels_from_edges: bool = True) -> None:
super(CenterCrop, self).__init__()
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges

def forward(self, input: np.ndarray) -> np.ndarray:
"""
Center crop an input.
Arguments:
input (array): Input to center crop.
Returns:
array (np.ndarray): A center cropped array.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)


def center_crop(
input: np.ndarray, crop_vals: TransformSize, pixels_from_edges: bool = True
) -> np.ndarray:
"""
Center crop a specified amount from an array.
Arguments:
input (array): A CHW or NCHW image array to center crop.
crop_vals (int or sequence of int): Number of pixels to center crop away.
pixels_from_edges (bool): Whether to treat crop size values as the number
of pixels from the array's edge, or an exact shape in the center.
Returns:
*array*: A center cropped array.
"""

assert input.ndim == 3 or input.ndim == 4
crop_vals = [crop_vals] if not hasattr(crop_vals, "__iter__") else crop_vals
crop_vals = cast(Union[List[int], Tuple[int], Tuple[int, int]], crop_vals)
assert len(crop_vals) == 1 or len(crop_vals) == 2
crop_vals = crop_vals * 2 if len(crop_vals) == 1 else crop_vals

if input.ndim == 4:
h, w = input.shape[2], input.shape[3]
if input.ndim == 3:
h, w = input.shape[1], input.shape[2]

if pixels_from_edges:
h_crop = h - crop_vals[0]
w_crop = w - crop_vals[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
return input[..., sh : sh + h_crop, sw : sw + w_crop]
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x


class ToRGB(object):
@@ -103,15 +139,23 @@ def i1i2i3_transform() -> np.ndarray:
]
return np.array(i1i2i3_matrix, dtype=float)

def __init__(self, transform_name: str = "klt") -> None:
def __init__(self, transform_matrix: Union[str, np.ndarray] = "klt") -> None:
super().__init__()

if transform_name == "klt":
assert isinstance(transform_matrix, str) or isinstance(
transform_matrix, np.ndarray
)
if isinstance(transform_matrix, np.ndarray):
assert list(transform_matrix.shape) == [3, 3]
self.transform = transform_matrix
elif transform_matrix == "klt":
self.transform = ToRGB.klt_transform()
elif transform_name == "i1i2i3":
elif transform_matrix == "i1i2i3":
self.transform = ToRGB.i1i2i3_transform()
else:
raise ValueError("transform_name has to be either 'klt' or 'i1i2i3'")
raise ValueError(
"transform_matrix has to be either 'klt', 'i1i2i3',"
+ " or a matrix array."
)

def to_rgb(self, x: np.ndarray, inverse: bool = False) -> np.ndarray:
assert x.ndim == 3 or x.ndim == 4
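Since these NumPy helpers exist to mirror the PyTorch transforms in tests, a parity check along the following lines is what they enable. This is a sketch: the random input and the loop are mine, while the import paths come from the files in this diff.

```python
import numpy as np
import torch
from captum.optim._param.image import transform as torch_transforms
from tests.optim.helpers import numpy_transforms

x_np = np.random.randn(1, 3, 8, 8).astype(np.float32)
x_t = torch.from_numpy(x_np)

# Both center_crop implementations should agree in each mode.
for pixels_from_edges in (True, False):
    out_np = numpy_transforms.center_crop(x_np, 4, pixels_from_edges)
    out_t = torch_transforms.center_crop(x_t, 4, pixels_from_edges)
    assert np.allclose(out_np, out_t.numpy())
```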