diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py index aa33793642..ad5b245c1c 100644 --- a/captum/optim/_core/loss.py +++ b/captum/optim/_core/loss.py @@ -448,14 +448,15 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec.reshape((1, -1, 1, 1)) + assert vec.dim() == 2 or vec.dim() == 4 + self.vec = vec.reshape((vec.size(0), -1, 1, 1)) if vec.dim() == 2 else vec self.cossim_pow = cossim_pow def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: activations = targets_to_values[self.target] - assert activations.size(1) == self.direction.size(1) + assert activations.size(1) == self.vec.size(1) activations = activations[self.batch_index[0] : self.batch_index[1]] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) @loss_wrapper @@ -477,7 +478,8 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec.reshape((1, -1, 1, 1)) + assert vec.dim() == 2 or vec.dim() == 4 + self.vec = vec.reshape((vec.size(0), -1, 1, 1)) if vec.dim() == 2 else vec self.x = x self.y = y self.channel_index = channel_index @@ -496,7 +498,81 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: ] if self.channel_index is not None: activations = activations[:, self.channel_index, ...][:, None, ...] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) + + +@loss_wrapper +class AngledNeuronDirection(BaseLoss): + """ + Visualize a direction vector with an optional whitened activation vector to + unstretch the activation space. Compared to the traditional Direction objectives, + this objective places more emphasis on angle by optionally multiplying the dot + product by the cosine similarity. + + When cossim_pow is equal to 0, this objective works as a euclidean + neuron objective. When cossim_pow is greater than 0, this objective works as a + cosine similarity objective. An additional whitened neuron direction vector + can optionally be supplied to improve visualization quality for some models. + + Carter, et al., "Activation Atlas", Distill, 2019. + https://distill.pub/2019/activation-atlas/ + Args: + target (nn.Module): A target layer instance. + vec (torch.Tensor): A neuron direction vector to use. + vec_whitened (torch.Tensor, optional): A whitened neuron direction vector. + cossim_pow (float, optional): The desired cosine similarity power to use. + x (int, optional): Optionally provide a specific x position for the target + neuron. + y (int, optional): Optionally provide a specific y position for the target + neuron. + eps (float, optional): If cossim_pow is greater than zero, the desired + epsilon value to use for cosine similarity calculations. 
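+        batch_index (int, optional): The index of the activations to optimize in the
+            batch dimension. When set to None (the default), all activations in the
+            batch are used.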
+ """ + + def __init__( + self, + target: torch.nn.Module, + vec: torch.Tensor, + vec_whitened: Optional[torch.Tensor] = None, + cossim_pow: float = 4.0, + x: Optional[int] = None, + y: Optional[int] = None, + eps: float = 1.0e-4, + batch_index: Optional[int] = None, + ) -> None: + BaseLoss.__init__(self, target, batch_index) + self.vec = vec.unsqueeze(0) if vec.dim() == 1 else vec + self.vec_whitened = vec_whitened + self.cossim_pow = cossim_pow + self.eps = eps + self.x = x + self.y = y + if self.vec_whitened is not None: + assert self.vec_whitened.dim() == 2 + assert self.vec.dim() == 2 + + def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: + activations = targets_to_values[self.target] + activations = activations[self.batch_index[0] : self.batch_index[1]] + assert activations.dim() == 4 or activations.dim() == 2 + assert activations.shape[1] == self.vec.shape[1] + if activations.dim() == 4: + _x, _y = get_neuron_pos( + activations.size(2), activations.size(3), self.x, self.y + ) + activations = activations[..., _x, _y] + + vec = ( + torch.matmul(self.vec, self.vec_whitened)[0] + if self.vec_whitened is not None + else self.vec + ) + if self.cossim_pow == 0: + return activations * vec + + dot = torch.mean(activations * vec) + cossims = dot / (self.eps + torch.sqrt(torch.sum(activations ** 2))) + return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow @loss_wrapper @@ -515,7 +591,8 @@ def __init__( batch_index: Optional[int] = None, ) -> None: BaseLoss.__init__(self, target, batch_index) - self.direction = vec + assert vec.dim() == 4 + self.vec = vec self.cossim_pow = cossim_pow def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: @@ -523,8 +600,8 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: assert activations.dim() == 4 - H_direction, W_direction = self.direction.size(2), self.direction.size(3) - H_activ, W_activ = activations.size(2), activations.size(3) + H_direction, W_direction = self.vec.shape[2:] + H_activ, W_activ = activations.shape[2:] H = (H_activ - H_direction) // 2 W = (W_activ - W_direction) // 2 @@ -535,7 +612,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: H : H + H_direction, W : W + W_direction, ] - return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow) + return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow) @loss_wrapper @@ -617,6 +694,7 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor: "Alignment", "Direction", "NeuronDirection", + "AngledNeuronDirection", "TensorDirection", "ActivationWeights", "default_loss_summarize", diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 93df78243e..1eac7af0cf 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -302,6 +302,10 @@ def __init__(self, scale: NumSeqOrTensorType) -> None: scale (float, sequence): Tuple of rescaling values to randomly select from. """ super().__init__() + assert hasattr(scale, "__iter__") + if torch.is_tensor(scale): + assert cast(torch.Tensor, scale).dim() == 1 + assert len(scale) > 0 self.scale = scale def get_scale_mat( @@ -384,6 +388,63 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return self.translate_tensor(input, insets) +class RandomRotation(nn.Module): + """ + Apply random rotation transforms on a NCHW tensor, using a sequence of degrees. 
+ + Arguments: + degrees (float, sequence): Tuple, List, or Tensor of degrees to randomly + select from. + """ + + def __init__( + self, degrees: Union[List[float], Tuple[float, ...], torch.Tensor] + ) -> None: + super().__init__() + assert hasattr(degrees, "__iter__") + if torch.is_tensor(degrees): + assert cast(torch.Tensor, degrees).dim() == 1 + assert len(degrees) > 0 + self.degrees = degrees + + def get_rot_mat( + self, + theta: Union[int, float, torch.Tensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + theta = torch.tensor(theta, device=device, dtype=dtype) + rot_mat = torch.tensor( + [ + [torch.cos(theta), -torch.sin(theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + ], + device=device, + dtype=dtype, + ) + return rot_mat + + def rotate_tensor( + self, x: torch.Tensor, theta: Union[int, float, torch.Tensor] + ) -> torch.Tensor: + theta = theta * math.pi / 180 + rot_matrix = self.get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat( + x.shape[0], 1, 1 + ) + if torch.__version__ >= "1.3.0": + # Pass align_corners explicitly for torch >= 1.3.0 + grid = F.affine_grid(rot_matrix, x.size(), align_corners=False) + x = F.grid_sample(x, grid, align_corners=False) + else: + grid = F.affine_grid(rot_matrix, x.size()) + x = F.grid_sample(x, grid) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + rotate_angle = _rand_select(self.degrees) + return self.rotate_tensor(x, rotate_angle) + + class ScaleInputRange(nn.Module): """ Multiplies the input by a specified multiplier for models with input ranges other diff --git a/captum/optim/_utils/image/common.py b/captum/optim/_utils/image/common.py index ac6f7ea95d..e1a5413ef3 100644 --- a/captum/optim/_utils/image/common.py +++ b/captum/optim/_utils/image/common.py @@ -115,14 +115,14 @@ def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor: Convert an NCHW image with n channels into a 3 channel RGB image. Args: - x (torch.Tensor): Image tensor to transform into RGB image. + x (torch.Tensor): NCHW image tensor to transform into RGB image. warp (bool, optional): Whether or not to make colors more distinguishable. Default: True Returns: - *tensor* RGB image + tensor (torch.Tensor): An NCHW RGB image tensor. """ - def hue_to_rgb(angle: float) -> torch.Tensor: + def hue_to_rgb(angle: float, device: torch.device) -> torch.Tensor: """ Create an RGB unit vector based on a hue of the input angle. """ @@ -136,7 +136,8 @@ def hue_to_rgb(angle: float) -> torch.Tensor: [0.0, 0.7071, 0.7071], [0.0, 0.0, 1.0], [0.7071, 0.0, 0.7071], - ] + ], + device=device, ) idx = math.floor(angle / 60) @@ -161,9 +162,9 @@ def adj(x: float) -> float: nc = x.size(1) for i in range(nc): rgb = rgb + x[:, i][:, None, :, :] - rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None] + rgb = rgb * hue_to_rgb(360 * i / nc, device=x.device)[None, :, None, None] - rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * ( + rgb = rgb + torch.ones(x.size(2), x.size(3), device=x.device)[None, None, :, :] * ( torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None] ) return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm( @@ -172,25 +173,26 @@ def adj(x: float) -> float: def weights_to_heatmap_2d( - weight: torch.Tensor, + tensor: torch.Tensor, colors: List[str] = ["0571b0", "92c5de", "f7f7f7", "f4a582", "ca0020"], ) -> torch.Tensor: """ - Create a color heatmap of an input weight tensor. 
- By default red represents excitatory values, - blue represents inhibitory values, and white represents + Create a color heatmap of an input weight tensor. By default red represents + excitatory values, blue represents inhibitory values, and white represents no excitation or inhibition. Args: weight (torch.Tensor): A 2d tensor to create the heatmap from. - colors (List of strings): A list of strings containing color - hex values to use for coloring the heatmap. + colors (list of str): A list of 5 strings containing hex triplet + (six digit), three-byte hexadecimal color values to use for coloring + the heatmap. Returns: - *tensor*: A weight heatmap. + color_tensor (torch.Tensor): A weight heatmap. """ - assert weight.dim() == 2 + assert tensor.dim() == 2 assert len(colors) == 5 + assert all([len(c) == 6 for c in colors]) def get_color(x: str, device: torch.device = torch.device("cpu")) -> torch.Tensor: def hex2base10(x: str) -> float: @@ -200,31 +202,19 @@ def hex2base10(x: str) -> float: [hex2base10(x[0:2]), hex2base10(x[2:4]), hex2base10(x[4:6])], device=device ) - def color_scale(x: torch.Tensor) -> torch.Tensor: - if x < 0: - x = -x - if x < 0.5: - x = x * 2 - return (1 - x) * get_color(colors[2], x.device) + x * get_color( - colors[1], x.device - ) - else: - x = (x - 0.5) * 2 - return (1 - x) * get_color(colors[1], x.device) + x * get_color( - colors[0], x.device - ) - else: - if x < 0.5: - x = x * 2 - return (1 - x) * get_color(colors[2], x.device) + x * get_color( - colors[3], x.device - ) - else: - x = (x - 0.5) * 2 - return (1 - x) * get_color(colors[3], x.device) + x * get_color( - colors[4], x.device - ) - - return torch.stack( - [torch.stack([color_scale(x) for x in t]) for t in weight] + color_list = [get_color(c, tensor.device) for c in colors] + x = tensor.expand((3, tensor.shape[0], tensor.shape[1])).permute(1, 2, 0) + + color_tensor = ( + (x >= 0) * (x < 0.5) * ((1 - x * 2) * color_list[2] + x * 2 * color_list[3]) + + (x >= 0) + * (x >= 0.5) + * ((1 - (x - 0.5) * 2) * color_list[3] + (x - 0.5) * 2 * color_list[4]) + + (x < 0) + * (x > -0.5) + * ((1 - (-x * 2)) * color_list[2] + (-x * 2) * color_list[1]) + + (x < 0) + * (x <= -0.5) + * ((1 - (-x - 0.5) * 2) * color_list[1] + (-x - 0.5) * 2 * color_list[0]) ).permute(2, 0, 1) + return color_tensor diff --git a/captum/optim/_utils/image/dataset.py b/captum/optim/_utils/image/dataset.py index fcc6d03742..28bc84dd25 100644 --- a/captum/optim/_utils/image/dataset.py +++ b/captum/optim/_utils/image/dataset.py @@ -1,7 +1,17 @@ +from typing import cast + import torch +try: + from tqdm.auto import tqdm +except (ImportError, AssertionError): + print( + "The tqdm package is required to use captum.optim's" + + " image dataset functions with progress bar" + ) + -def image_cov(tensor: torch.Tensor) -> torch.Tensor: +def image_cov(x: torch.Tensor) -> torch.Tensor: """ Calculate a tensor's RGB covariance matrix. @@ -11,12 +21,17 @@ def image_cov(tensor: torch.Tensor) -> torch.Tensor: *tensor*: An RGB covariance matrix for the specified tensor. 
""" - tensor = tensor.reshape(-1, 3) - tensor = tensor - tensor.mean(0, keepdim=True) - return 1 / (tensor.size(0) - 1) * tensor.T @ tensor + assert x.dim() > 1 + x = x.reshape(-1, x.size(1)).T + x = x - torch.mean(x, dim=-1).unsqueeze(-1) + return 1 / (x.shape[-1] - 1) * x @ x.transpose(-1, -2) -def dataset_cov_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor: +def dataset_cov_matrix( + loader: torch.utils.data.DataLoader, + show_progress: bool = False, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: """ Calculate the covariance matrix for an image dataset. @@ -27,12 +42,21 @@ def dataset_cov_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor: *tensor*: A covariance matrix for the specified dataset. """ - cov_mtx = torch.zeros(3, 3) + if show_progress: + pbar = tqdm(total=len(loader.dataset), unit=" images") # type: ignore + + cov_mtx = cast(torch.Tensor, 0.0) for images, _ in loader: - assert images.dim() == 4 - for b in range(images.size(0)): - cov_mtx = cov_mtx + image_cov(images[b].permute(1, 2, 0)) - cov_mtx = cov_mtx / len(loader.dataset) # type: ignore + assert images.dim() > 1 + images = images.to(device) + cov_mtx = cov_mtx + image_cov(images) + if show_progress: + pbar.update(images.size(0)) + + if show_progress: + pbar.close() + + cov_mtx = cov_mtx / cast(int, len(loader.dataset)) return cov_mtx @@ -58,7 +82,10 @@ def cov_matrix_to_klt( def dataset_klt_matrix( - loader: torch.utils.data.DataLoader, normalize: bool = False + loader: torch.utils.data.DataLoader, + normalize: bool = False, + show_progress: bool = False, + device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """ Calculate the color correlation matrix, also known as @@ -74,5 +101,5 @@ def dataset_klt_matrix( *tensor*: A KLT matrix for the specified dataset. """ - cov_mtx = dataset_cov_matrix(loader) + cov_mtx = dataset_cov_matrix(loader, show_progress=show_progress, device=device) return cov_matrix_to_klt(cov_mtx, normalize) diff --git a/captum/optim/models/_common.py b/captum/optim/models/_common.py index e9fba1ba27..a94536d786 100644 --- a/captum/optim/models/_common.py +++ b/captum/optim/models/_common.py @@ -1,6 +1,6 @@ import math from inspect import signature -from typing import Dict, List, Tuple, Type, Union, cast +from typing import Dict, List, Optional, Tuple, Type, Union, cast import torch import torch.nn as nn @@ -254,3 +254,38 @@ class type to replace in the model. layers = cast(List[Type[nn.Module]], layers) for target_layer in layers: replace_layers(model, target_layer, SkipLayer) + + +class MaxPool2dRelaxed(torch.nn.Module): + """ + A relaxed pooling layer, that's useful for calculating attributions of spatial + positions. This layer reduces Noise in the gradient through the use of a + continuous relaxation of the gradient. + + Args: + kernel_size (int or tuple of int): The size of the window to perform max & + average pooling with. + stride (int or tuple of int, optional): The stride window size to use. + padding (int or tuple of int): The amount of zero padding to add to both sides + in the nn.MaxPool2d & nn.AvgPool2d modules. + ceil_mode (bool, optional): Whether to use ceil or floor for creating the + output shape. 
+ """ + + def __init__( + self, + kernel_size: Union[int, Tuple[int, ...]], + stride: Optional[Union[int, Tuple[int, ...]]] = None, + padding: Union[int, Tuple[int, ...]] = 0, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.maxpool = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + self.avgpool = torch.nn.AvgPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.maxpool(x.detach()) + self.avgpool(x) - self.avgpool(x.detach()) diff --git a/setup.py b/setup.py index 64bdeee19a..19ceaa99ba 100755 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def report(*args): "insights/widget/static", ] -TUTORIALS_REQUIRES = INSIGHTS_REQUIRES + ["torchtext", "torchvision"] +TUTORIALS_REQUIRES = INSIGHTS_REQUIRES + ["torchtext", "torchvision", "umap-learn"] TEST_REQUIRES = ["pytest", "pytest-cov"] diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py index 566745de25..6f25273a80 100644 --- a/tests/optim/core/test_loss.py +++ b/tests/optim/core/test_loss.py @@ -133,6 +133,31 @@ def test_neuron_direction(self) -> None: self.assertAlmostEqual(get_loss_value(model, loss), dot, places=6) +class TestAngledNeuronDirection(BaseTest): + def test_angled_neuron_direction(self) -> None: + model = BasicModel_ConvNet_Optim() + loss = opt_loss.AngledNeuronDirection( + model.layer, vec=torch.ones(1, 2), cossim_pow=0 + ) + a = 1 + b = [CHANNEL_ACTIVATION_0_LOSS, CHANNEL_ACTIVATION_1_LOSS] + dot = np.sum(np.inner(a, b)) + self.assertAlmostEqual(np.sum(get_loss_value(model, loss)), dot, places=6) + + def test_angled_neuron_direction_whitened(self) -> None: + model = BasicModel_ConvNet_Optim() + loss = opt_loss.AngledNeuronDirection( + model.layer, + vec=torch.ones(1, 2), + vec_whitened=torch.ones(2, 2), + cossim_pow=0, + ) + a = 1 + b = [CHANNEL_ACTIVATION_0_LOSS, CHANNEL_ACTIVATION_1_LOSS] + dot = np.sum(np.inner(a, b)) * 2 + self.assertAlmostEqual(np.sum(get_loss_value(model, loss)), dot, places=6) + + class TestTensorDirection(BaseTest): def test_tensor_direction(self) -> None: model = BasicModel_ConvNet_Optim() diff --git a/tests/optim/helpers/numpy_common.py b/tests/optim/helpers/numpy_common.py deleted file mode 100644 index b432829694..0000000000 --- a/tests/optim/helpers/numpy_common.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List - -import numpy as np - - -def weights_to_heatmap_2d( - array: np.ndarray, - colors: List[str] = ["0571b0", "92c5de", "f7f7f7", "f4a582", "ca0020"], -) -> np.ndarray: - """ - Create a color heatmap of an input weight array. - By default red represents excitatory values, - blue represents inhibitory values, and white represents - no excitation or inhibition. - - Args: - weight (array): A 2d array to create the heatmap from. - colors (List of strings): A list of strings containing color - hex values to use for coloring the heatmap. - Returns: - *array*: A weight heatmap. 
- """ - - assert array.ndim == 2 - assert len(colors) == 5 - - def get_color(x: str) -> np.ndarray: - def hex2base10(x: str) -> float: - return int(x, 16) / 255.0 - - return np.array([hex2base10(x[0:2]), hex2base10(x[2:4]), hex2base10(x[4:6])]) - - def color_scale(x: np.ndarray) -> np.ndarray: - if x < 0: - x = -x - if x < 0.5: - x = x * 2 - return (1 - x) * get_color(colors[2]) + x * get_color(colors[1]) - else: - x = (x - 0.5) * 2 - return (1 - x) * get_color(colors[1]) + x * get_color(colors[0]) - else: - if x < 0.5: - x = x * 2 - return (1 - x) * get_color(colors[2]) + x * get_color(colors[3]) - else: - x = (x - 0.5) * 2 - return (1 - x) * get_color(colors[3]) + x * get_color(colors[4]) - - return np.stack([np.stack([color_scale(x) for x in a]) for a in array]).transpose( - 2, 0, 1 - ) diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index f6418b8d6c..176b10fff2 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -290,3 +290,48 @@ def test_skip_layers(self) -> None: model_utils.skip_layers(model, torch.nn.ReLU) output_tensor = model(x) assertTensorAlmostEqual(self, x, output_tensor, 0) + + +class TestMaxPool2dRelaxed(BaseTest): + def test_maxpool2d_relaxed_forward_data(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) + + test_input = torch.arange(0, 1 * 3 * 8 * 8).view(1, 3, 8, 8).float() + + test_output_relaxed = maxpool_relaxed(test_input.clone()) + test_output_max = maxpool(test_input.clone()) + + assertTensorAlmostEqual(self, test_output_relaxed, test_output_max) + + def test_maxpool2d_relaxed_gradient(self) -> None: + maxpool_relaxed = model_utils.MaxPool2dRelaxed( + kernel_size=3, stride=2, padding=0, ceil_mode=True + ) + test_input = torch.nn.Parameter( + torch.arange(0, 1 * 1 * 4 * 4).view(1, 1, 4, 4).float() + ) + + test_output = maxpool_relaxed(test_input) + + output_grad = torch.autograd.grad( + outputs=[test_output], + inputs=[test_input], + grad_outputs=[test_output], + )[0] + + expected_output = torch.tensor( + [ + [ + [ + [1.1111, 1.1111, 2.9444, 1.8333], + [1.1111, 1.1111, 2.9444, 1.8333], + [3.4444, 3.4444, 9.0278, 5.5833], + [2.3333, 2.3333, 6.0833, 3.7500], + ] + ] + ], + ) + assertTensorAlmostEqual(self, output_grad, expected_output, 0.0005) diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index ade3ba37e7..ff509a532e 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -72,6 +72,68 @@ def test_random_scale_matrix(self) -> None: ) +class TestRandomRotation(BaseTest): + def test_random_rotation_degrees(self) -> None: + test_degrees = [0.0, 1.0, 2.0, 3.0, 4.0] + rot_mod = transforms.RandomRotation(test_degrees) + degrees = rot_mod.degrees + self.assertTrue(hasattr(degrees, "__iter__")) + self.assertEqual(degrees, test_degrees) + + def test_random_rotation_matrix(self) -> None: + theta = 25.1 + theta = theta * 3.141592653589793 / 180 + rot_mod = transforms.RandomRotation([25.1]) + rot_matrix = rot_mod.get_rot_mat( + theta, device=torch.device("cpu"), dtype=torch.float32 + ) + expected_matrix = torch.tensor( + [[0.9056, -0.4242, 0.0000], [0.4242, 0.9056, 0.0000]] + ) + + assertTensorAlmostEqual(self, rot_matrix, expected_matrix) + + def test_random_rotation_rotate_tensor(self) -> None: + rot_mod = transforms.RandomRotation([25.0]) + + 
test_input = torch.eye(4, 4).repeat(3, 1, 1).unsqueeze(0) + test_output = rot_mod.rotate_tensor(test_input, 25.0) + + expected_output = ( + torch.tensor( + [ + [0.1143, 0.0000, 0.0000, 0.0000], + [0.5258, 0.6198, 0.2157, 0.0000], + [0.0000, 0.2157, 0.6198, 0.5258], + [0.0000, 0.0000, 0.0000, 0.1143], + ] + ) + .repeat(3, 1, 1) + .unsqueeze(0) + ) + assertTensorAlmostEqual(self, test_output, expected_output, 0.005) + + def test_random_rotation_forward(self) -> None: + rotate_transform = transforms.RandomRotation(list(range(-25, 25))) + x = torch.ones(1, 3, 224, 224) + output = rotate_transform(x) + + self.assertEqual(output.shape, x.shape) + + def test_random_rotation_forward_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping RandomRotation forward CUDA test due to not supporting" + + " CUDA." + ) + rotate_transform = transforms.RandomRotation(list(range(-25, 25))) + x = torch.ones(1, 3, 224, 224).cuda() + output = rotate_transform(x) + + self.assertTrue(output.is_cuda) + self.assertEqual(output.shape, x.shape) + + class TestRandomSpatialJitter(BaseTest): def test_random_spatial_jitter_hw(self) -> None: translate_vals = [4, 4] diff --git a/tests/optim/utils/image/common.py b/tests/optim/utils/image/common.py index 156f4617d6..048f692fcf 100644 --- a/tests/optim/utils/image/common.py +++ b/tests/optim/utils/image/common.py @@ -5,7 +5,6 @@ import captum.optim._utils.image.common as common from tests.helpers.basic import BaseTest, assertTensorAlmostEqual -from tests.optim.helpers import numpy_common class TestGetNeuronPos(unittest.TestCase): @@ -40,13 +39,100 @@ def test_get_neuron_pos_none_y(self) -> None: class TestNChannelsToRGB(BaseTest): def test_nchannels_to_rgb_collapse(self) -> None: - test_input = torch.randn(1, 6, 224, 224) - test_output = common.nchannels_to_rgb(test_input) - self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) + test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=True) + expected_output = torch.tensor( + [ + [ + [ + [31.6934, 32.6204, 33.5554, 34.4981], + [35.4482, 36.4053, 37.3690, 38.3390], + [39.3149, 40.2964, 41.2832, 42.2750], + [43.2715, 44.2725, 45.2776, 46.2866], + ], + [ + [20.6687, 21.5674, 22.4618, 23.3529], + [24.2417, 25.1290, 26.0154, 26.9013], + [27.7870, 28.6729, 29.5592, 30.4460], + [31.3335, 32.2217, 33.1109, 34.0009], + ], + [ + [46.3932, 47.4421, 48.5129, 49.6036], + [50.7125, 51.8380, 52.9788, 54.1335], + [55.3011, 56.4806, 57.6710, 58.8715], + [60.0815, 61.3001, 62.5268, 63.7611], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0) + + def test_nchannels_to_rgb_collapse_warp_false(self) -> None: + test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=False) + expected_output = torch.tensor( + [ + [ + [ + [28.4279, 29.3496, 30.2753, 31.2053], + [32.1396, 33.0782, 34.0210, 34.9679], + [35.9188, 36.8736, 37.8322, 38.7943], + [39.7598, 40.7286, 41.7006, 42.6756], + ], + [ + [20.5599, 21.4595, 22.3544, 23.2459], + [24.1351, 25.0225, 25.9088, 26.7946], + [27.6801, 28.5657, 29.4515, 30.3378], + [31.2247, 32.1124, 33.0008, 33.8900], + ], + [ + [48.5092, 49.5791, 50.6723, 51.7866], + [52.9201, 54.0713, 55.2386, 56.4206], + [57.6164, 58.8246, 60.0444, 61.2749], + [62.5153, 63.7649, 65.0231, 66.2892], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.001) def 
test_nchannels_to_rgb_increase(self) -> None: - test_input = torch.randn(1, 2, 224, 224) + test_input = torch.arange(0, 1 * 2 * 4 * 4).view(1, 2, 4, 4).float() + test_output = common.nchannels_to_rgb(test_input, warp=True) + expected_output = torch.tensor( + [ + [ + [ + [0.0000, 0.9234, 1.7311, 2.4623], + [3.1419, 3.7855, 4.4036, 5.0033], + [5.5894, 6.1654, 6.7337, 7.2961], + [7.8540, 8.4083, 8.9597, 9.5089], + ], + [ + [11.3136, 12.0238, 12.7476, 13.4895], + [14.2500, 15.0278, 15.8210, 16.6277], + [17.4464, 18.2754, 19.1135, 19.9595], + [20.8124, 21.6714, 22.5357, 23.4049], + ], + [ + [11.3136, 12.0238, 12.7476, 13.4895], + [14.2500, 15.0278, 15.8210, 16.6277], + [17.4464, 18.2754, 19.1135, 19.9595], + [20.8124, 21.6714, 22.5357, 23.4049], + ], + ] + ] + ) + assertTensorAlmostEqual(self, test_output, expected_output, delta=0.001) + + def test_nchannels_to_rgb_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping nchannels_to_rgb CUDA test due to not supporting CUDA." + ) + test_input = torch.randn(1, 6, 224, 224).cuda() test_output = common.nchannels_to_rgb(test_input) + self.assertTrue(test_output.is_cuda) self.assertEqual(list(test_output.size()), [1, 3, 224, 224]) @@ -60,13 +146,38 @@ def test_weights_to_heatmap_2d(self) -> None: x[4:5, 0:4] = x[4:5, 0:4] * -0.8 x_out = common.weights_to_heatmap_2d(x) - x_out_np = numpy_common.weights_to_heatmap_2d(x.numpy()) - assertTensorAlmostEqual(self, x_out, torch.as_tensor(x_out_np).float()) + + x_out_expected = torch.tensor( + [ + [ + [0.9639, 0.9639, 0.9639, 0.9639], + [0.8580, 0.8580, 0.8580, 0.8580], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8102, 0.8102, 0.8102, 0.8102], + [0.2408, 0.2408, 0.2408, 0.2408], + ], + [ + [0.8400, 0.8400, 0.8400, 0.8400], + [0.2588, 0.2588, 0.2588, 0.2588], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8902, 0.8902, 0.8902, 0.8902], + [0.5749, 0.5749, 0.5749, 0.5749], + ], + [ + [0.7851, 0.7851, 0.7851, 0.7851], + [0.2792, 0.2792, 0.2792, 0.2792], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.9294, 0.9294, 0.9294, 0.9294], + [0.7624, 0.7624, 0.7624, 0.7624], + ], + ] + ) + assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) def test_weights_to_heatmap_2d_cuda(self) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest( - "Skipping ImageTensor CUDA test due to not supporting CUDA." + "Skipping weights_to_heatmap_2d CUDA test due to not supporting CUDA." 
) x = torch.ones(5, 4) x[0:1, 0:4] = x[0:1, 0:4] * 0.2 @@ -76,6 +187,31 @@ def test_weights_to_heatmap_2d_cuda(self) -> None: x[4:5, 0:4] = x[4:5, 0:4] * -0.8 x_out = common.weights_to_heatmap_2d(x.cuda()) - x_out_np = numpy_common.weights_to_heatmap_2d(x.numpy()) - assertTensorAlmostEqual(self, x_out, torch.as_tensor(x_out_np).float()) + + x_out_expected = torch.tensor( + [ + [ + [0.9639, 0.9639, 0.9639, 0.9639], + [0.8580, 0.8580, 0.8580, 0.8580], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8102, 0.8102, 0.8102, 0.8102], + [0.2408, 0.2408, 0.2408, 0.2408], + ], + [ + [0.8400, 0.8400, 0.8400, 0.8400], + [0.2588, 0.2588, 0.2588, 0.2588], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.8902, 0.8902, 0.8902, 0.8902], + [0.5749, 0.5749, 0.5749, 0.5749], + ], + [ + [0.7851, 0.7851, 0.7851, 0.7851], + [0.2792, 0.2792, 0.2792, 0.2792], + [0.9686, 0.9686, 0.9686, 0.9686], + [0.9294, 0.9294, 0.9294, 0.9294], + [0.7624, 0.7624, 0.7624, 0.7624], + ], + ] + ) + assertTensorAlmostEqual(self, x_out, x_out_expected, delta=0.01) self.assertTrue(x_out.is_cuda) diff --git a/tests/optim/utils/image/dataset.py b/tests/optim/utils/image/dataset.py index c6c1581ac9..21e15d2cf9 100644 --- a/tests/optim/utils/image/dataset.py +++ b/tests/optim/utils/image/dataset.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + import torch import captum.optim._utils.image.dataset as dataset_utils diff --git a/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb new file mode 100644 index 0000000000..7600394cac --- /dev/null +++ b/tutorials/optimviz/atlas/ActivationAtlasSampleCollection_OptimViz.ipynb @@ -0,0 +1,592 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "ActivationAtlasSampleCollection_OptimViz.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "KP2PKna21WLK" + }, + "source": [ + "# Collecting Samples for Activation Atlases with captum.optim\n", + "\n", + "This notebook demonstrates how to collect the activation and corresponding attribution samples required for [Activation Atlases](https://distill.pub/2019/activation-atlas/) for the InceptionV1 model imported from Caffe." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v6T6jxWb4cil" + }, + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from typing import List, Optional, Tuple, cast\n", + "\n", + "import os\n", + "import torch\n", + "import torchvision\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from captum.optim.models import googlenet\n", + "\n", + "import captum.optim as opt\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dtE-t6ZG0-sJ" + }, + "source": [ + "### Dataset Download & Setup \n", + "\n", + "To begin, we'll need to download and setup the image dataset that our model was trained on. You can download ImageNet's ILSVRC2012 dataset from the [ImageNet website](http://www.image-net.org/challenges/LSVRC/2012/) or via BitTorrent from [Academic Torrents](https://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lDt-6WMp0qh3" + }, + "source": [ + "collect_attributions = True # Set to False for no attributions\n", + "\n", + "# Setup basic transforms\n", + "# The model has the normalization step in its internal transform_input\n", + "# function, so we don't need to normalize our inputs here.\n", + "transform_list = [\n", + " torchvision.transforms.Resize((224, 224)),\n", + " torchvision.transforms.ToTensor(),\n", + "]\n", + "transform_list = torchvision.transforms.Compose(transform_list)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i85yBIhL7owj" + }, + "source": [ + "To make it easier to load the ImageNet dataset, we can use [Torchvision](https://pytorch.org/vision/stable/datasets.html#imagenet)'s `torchvision.datasets.ImageNet` instead of the default `ImageFolder`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3oRqxlMq7gJ4" + }, + "source": [ + "# Load the dataset\n", + "image_dataset = torchvision.datasets.ImageNet(\n", + " root=\"path/to/dataset\", split=\"train\", transform=transform_list\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "573290Fr8KN7" + }, + "source": [ + "Now we wrap our dataset in a `torch.utils.data.DataLoader` instance, and set the desired batch size." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "DUCfwsvR7iGC" + }, + "source": [ + "# Set desired batch size & load dataset with torch.utils.DataLoader\n", + "image_loader = torch.utils.data.DataLoader(\n", + " image_dataset,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4qfpBAPu18jv" + }, + "source": [ + "We load our model, then set the desired model target layers and corresponding file names." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qMViqsJ82Mcp" + }, + "source": [ + "# Model to collect samples from, what layers of the model to collect samples from,\n", + "# and the desired names to use for the target layers.\n", + "sample_model = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + ")\n", + "sample_targets = [sample_model.mixed4c_relu]\n", + "sample_target_names = [\"mixed4c_relu_samples\"]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jl719nyZEGSt" + }, + "source": [ + "By default the activation samples will not have the right class attributions, so we remedy this by loading a second instance of our model. We then replace all `nn.MaxPool2d` layers in the second model instance with Captum's `MaxPool2dRelaxed` layer. The relaxed max pooling layer lets us estimate the sample class attributions by determining the rate at which increasing the neuron affects the output classes." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "A-VJyHRm1tqC" + }, + "source": [ + "# Optionally collect attributions from a copy of the first model that's\n", + "# been setup with relaxed pooling layers.\n", + "if collect_attributions:\n", + " sample_model_attr = (\n", + " googlenet(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True\n", + " )\n", + " .eval()\n", + " .to(device)\n", + " )\n", + " opt.models.replace_layers(\n", + " sample_model_attr,\n", + " torch.nn.MaxPool2d,\n", + " opt.models.MaxPool2dRelaxed,\n", + " transfer_vars=True,\n", + " )\n", + " sample_attr_targets = [sample_model_attr.mixed4c_relu]\n", + " sample_logit_target = sample_model_attr.fc\n", + "else:\n", + " sample_model_attr = None\n", + " sample_attr_targets = None\n", + " sample_logit_target = None" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "32zDGSR5-qDW" + }, + "source": [ + "With our dataset loaded and models ready to go, we can now start collecting our samples. To perform the sample collection, we define a function called `capture_activation_samples` to randomly sample an x and y position for every image for all specified target layers." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2YLBCYP0J4Gq" + }, + "source": [ + "def attribute_spatial_position(\n", + " target_activ: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_mask: torch.Tensor,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + "\n", + " logit_activ: Captured activations from the FC / logit layer.\n", + " target_activ: Captured activations from the target layer.\n", + " position_mask (torch.Tensor, optional): If using a batch size greater than\n", + " one, a mask is used to zero out all the non-target positions.\n", + "\n", + " Returns:\n", + " logit_attr (torch.Tensor): A sorted list of class attributions for the target\n", + " spatial positions.\n", + " \"\"\"\n", + "\n", + " assert target_activ.dim() == 2 or target_activ.dim() == 4\n", + " assert logit_activ.dim() == 2\n", + "\n", + " zeros = torch.nn.Parameter(torch.zeros_like(logit_activ))\n", + " target_zeros = target_activ * position_mask\n", + "\n", + " grad_one = torch.autograd.grad(\n", + " outputs=[logit_activ],\n", + " inputs=[target_activ],\n", + " grad_outputs=[zeros],\n", + " create_graph=True,\n", + " )\n", + " logit_attr = torch.autograd.grad(\n", + " outputs=grad_one,\n", + " inputs=[zeros],\n", + " grad_outputs=[target_zeros],\n", + " create_graph=True,\n", + " )[0]\n", + " return logit_attr\n", + "\n", + "\n", + "def capture_activation_samples(\n", + " loader: torch.utils.data.DataLoader,\n", + " model: torch.nn.Module,\n", + " targets: List[torch.nn.Module],\n", + " target_names: Optional[List[str]] = None,\n", + " sample_dir: str = \"\",\n", + " num_images: Optional[int] = None,\n", + " samples_per_image: int = 1,\n", + " input_device: torch.device = torch.device(\"cpu\"),\n", + " collect_attributions: bool = False,\n", + " attr_model: Optional[torch.nn.Module] = None,\n", + " attr_targets: Optional[List[torch.nn.Module]] = None,\n", + " logit_target: Optional[torch.nn.Module] = None,\n", + " show_progress: bool = False,\n", + "):\n", + " \"\"\"\n", + " Capture randomly sampled activations for an image dataset from one or multiple\n", + " target layers.\n", + "\n", + " Args:\n", + "\n", + " loader (torch.utils.data.DataLoader): A torch.utils.data.DataLoader\n", + " instance for an image dataset.\n", + " model (nn.Module): A 
PyTorch model instance.\n", + " targets (list of nn.Module): A list of layers to collect activation samples\n", + " from.\n", + " target_names (list of str, optional): A list of names to use when saving sample\n", + " tensors as files. Names will automatically be chosen if set to None.\n", + " Default: None\n", + " sample_dir (str): Path to where activation samples should be saved.\n", + " Default: \"\"\n", + " num_images (int, optional): How many images to collect samples from.\n", + " Default is to collect samples for every image in the dataset. Set to None\n", + " to collect samples from every image in the dataset.\n", + " Default: None\n", + " samples_per_image (int): How many samples to collect per image.\n", + " Default: 1\n", + " input_device (torch.device, optional): The device to use for model\n", + " inputs.\n", + " Default: torch.device(\"cpu\")\n", + " collect_attributions (bool, optional): Whether or not to collect attributions\n", + " for samples.\n", + " Default: False\n", + " attr_model (nn.Module, optional): A PyTorch model instance to use for\n", + " calculating sample attributions.\n", + " Default: None\n", + " attr_targets (list of nn.Module, optional): A list of attribution model layers\n", + " to collect attributions from. This should be the exact same as the targets\n", + " parameter, except for the attribution model.\n", + " Default: None\n", + " logit_target (nn.Module, optional): The final layer in the attribution model\n", + " that determines the classes. This parameter is only enabled if\n", + " collect_attributions is set to True.\n", + " Default: None\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + " \"\"\"\n", + "\n", + " if target_names is None:\n", + " target_names = [\"target\" + str(i) + \"_\" for i in range(len(targets))]\n", + "\n", + " assert len(target_names) == len(targets)\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " def random_sample(\n", + " activations: torch.Tensor,\n", + " ) -> Tuple[List[torch.Tensor], List[List[List[int]]]]:\n", + " \"\"\"\n", + " Randomly sample H & W dimensions of activations with 4 dimensions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " activation_samples: List = []\n", + " position_list: List = []\n", + "\n", + " with torch.no_grad():\n", + " for i in range(samples_per_image):\n", + " sample_position_list: List = []\n", + " for b in range(activations.size(0)):\n", + " if activations.dim() == 4:\n", + " h, w = activations.shape[2:]\n", + " y = torch.randint(low=1, high=h - 1, size=[1])\n", + " x = torch.randint(low=1, high=w - 1, size=[1])\n", + " activ = activations[b, :, y, x]\n", + " sample_position_list.append((b, y, x))\n", + " elif activations.dim() == 2:\n", + " activ = activations[b].unsqueeze(1)\n", + " sample_position_list.append(b)\n", + " activation_samples.append(activ)\n", + " position_list.append(sample_position_list)\n", + " return activation_samples, position_list\n", + "\n", + " def attribute_samples(\n", + " activations: torch.Tensor,\n", + " logit_activ: torch.Tensor,\n", + " position_list: List[List[List[int]]],\n", + " ) -> List[torch.Tensor]:\n", + " \"\"\"\n", + " Collect attributions for target sample positions.\n", + " \"\"\"\n", + " assert activations.dim() == 4 or activations.dim() == 2\n", + "\n", + " sample_attributions: List = []\n", + " with torch.set_grad_enabled(True):\n", + " zeros_mask = torch.zeros_like(activations)\n", + " for sample_pos_list in position_list:\n", + " 
for c in sample_pos_list:\n", + " if activations.dim() == 4:\n", + " zeros_mask[c[0], :, c[1], c[2]] = 1\n", + " elif activations.dim() == 2:\n", + " zeros_mask[c] = 1\n", + " attr = attribute_spatial_position(\n", + " activations, logit_activ, position_mask=zeros_mask\n", + " ).detach()\n", + " sample_attributions.append(attr)\n", + " return sample_attributions\n", + "\n", + " if collect_attributions:\n", + " logit_target == list(model.children())[len(list(model.children())) - 1 :][\n", + " 0\n", + " ] if logit_target is None else logit_target\n", + " attr_targets = cast(List[torch.nn.Module], attr_targets)\n", + " attr_targets += [cast(torch.nn.Module, logit_target)]\n", + "\n", + " if show_progress:\n", + " total = (\n", + " len(loader.dataset) if num_images is None else num_images # type: ignore\n", + " )\n", + " pbar = tqdm(total=total, unit=\" images\")\n", + "\n", + " image_count, batch_count = 0, 0\n", + " with torch.no_grad():\n", + " for inputs, _ in loader:\n", + " inputs = inputs.to(input_device)\n", + " image_count += inputs.size(0)\n", + " batch_count += 1\n", + "\n", + " target_activ_dict = opt.models.collect_activations(model, targets, inputs)\n", + " if collect_attributions:\n", + " with torch.set_grad_enabled(True):\n", + " target_activ_attr_dict = opt.models.collect_activations(\n", + " attr_model, attr_targets, inputs\n", + " )\n", + " logit_activ = target_activ_attr_dict[logit_target]\n", + " del target_activ_attr_dict[logit_target]\n", + "\n", + " sample_coords = []\n", + " for t, n in zip(target_activ_dict, target_names):\n", + " sample_tensors, p_list = random_sample(target_activ_dict[t])\n", + " torch.save(\n", + " sample_tensors,\n", + " os.path.join(\n", + " sample_dir, n + \"_activations_\" + str(batch_count) + \".pt\"\n", + " ),\n", + " )\n", + " sample_coords.append(p_list)\n", + "\n", + " if collect_attributions:\n", + " for t, n, s_coords in zip(\n", + " target_activ_attr_dict, target_names, sample_coords\n", + " ):\n", + " sample_attrs = attribute_samples(\n", + " target_activ_attr_dict[t], logit_activ, s_coords\n", + " )\n", + " torch.save(\n", + " sample_attrs,\n", + " os.path.join(\n", + " sample_dir,\n", + " n + \"_attributions_\" + str(batch_count) + \".pt\",\n", + " ),\n", + " )\n", + "\n", + " if show_progress:\n", + " pbar.update(inputs.size(0))\n", + "\n", + " if num_images is not None:\n", + " if image_count > num_images:\n", + " break\n", + "\n", + " if show_progress:\n", + " pbar.close()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IWsmPssJJ09E" + }, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uODdkyjY1lap" + }, + "source": [ + "# Directory to save sample files to\n", + "sample_dir = \"inceptionv1_samples\"\n", + "try:\n", + " os.mkdir(sample_dir)\n", + "except:\n", + " pass\n", + "\n", + "# Collect samples & optionally attributions as well\n", + "capture_activation_samples(\n", + " loader=image_loader,\n", + " model=sample_model,\n", + " targets=sample_targets,\n", + " target_names=sample_target_names,\n", + " attr_model=sample_model_attr,\n", + " attr_targets=sample_attr_targets,\n", + " input_device=device,\n", + " sample_dir=sample_dir,\n", + " show_progress=True,\n", + " collect_attributions=collect_attributions,\n", + " logit_target=sample_logit_target,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eMrBUaPi97fF" + }, + "source": [ + "Now that we've collected our 
samples, we need to combine them into a single tensor. Below we use the `consolidate_samples` function to load each list of tensor samples, and then concatinate them into a single tensor." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "LaFglPVYKbXj" + }, + "source": [ + "def consolidate_samples(\n", + " sample_dir: str,\n", + " sample_basename: str = \"\",\n", + " dim: int = 1,\n", + " num_files: Optional[int] = None,\n", + " show_progress: bool = False,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Combine samples collected from capture_activation_samples into a single tensor\n", + " with a shape of [n_channels, n_samples].\n", + "\n", + " Args:\n", + "\n", + " sample_dir (str): The directory where activation samples where saved.\n", + " sample_basename (str, optional): If samples from different layers are present\n", + " in sample_dir, then you can use samples from only a specific layer by\n", + " specifying the basename that samples of the same layer share.\n", + " Default: \"\"\n", + " dim (int, optional): The dimension to concatinate the samples together on.\n", + " Default: 1\n", + " show_progress (bool, optional): Whether or not to show progress.\n", + " Default: False\n", + "\n", + " Returns:\n", + " sample_tensor (torch.Tensor): A tensor containing all the specified sample\n", + " tensors with a shape of [n_channels, n_samples].\n", + " \"\"\"\n", + "\n", + " assert os.path.isdir(sample_dir)\n", + "\n", + " tensor_samples = [\n", + " os.path.join(sample_dir, name)\n", + " for name in os.listdir(sample_dir)\n", + " if sample_basename.lower() in name.lower()\n", + " and os.path.isfile(os.path.join(sample_dir, name))\n", + " ]\n", + " assert len(tensor_samples) > 0\n", + "\n", + " if show_progress:\n", + " total = len(tensor_samples) if num_files is None else num_files # type: ignore\n", + " pbar = tqdm(total=total, unit=\" sample batches collected\")\n", + "\n", + " samples: List[torch.Tensor] = []\n", + " for file in tensor_samples:\n", + " sample_batch = torch.load(file)\n", + " for s in sample_batch:\n", + " samples += [s.cpu()]\n", + " if show_progress:\n", + " pbar.update(1)\n", + "\n", + " if show_progress:\n", + " pbar.close()\n", + " return torch.cat(samples, dim)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BKUPszVR1Ew-" + }, + "source": [ + "# Combine our newly collected samples into single tensors.\n", + "# We load the sample tensors from sample_dir and then\n", + "# concatenate them.\n", + "\n", + "for name in sample_target_names:\n", + " print(\"Combining \" + name + \" samples:\")\n", + " activation_samples = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_activations\",\n", + " dim=1,\n", + " show_progress=True,\n", + " )\n", + " if collect_attributions:\n", + " sample_attributions = consolidate_samples(\n", + " sample_dir=sample_dir,\n", + " sample_basename=name + \"_attributions\",\n", + " dim=0,\n", + " show_progress=True,\n", + " )\n", + "\n", + " # Save the results\n", + " torch.save(activation_samples, name + \"activation_samples.pt\")\n", + " if collect_attributions:\n", + " torch.save(sample_attributions, name + \"attribution_samples.pt\")" + ], + "execution_count": null, + "outputs": [] + } + ] +}