Merged

49 commits
6892f3c
Add model linearization, and expanded weights spatial positions
ProGamerGov Dec 27, 2020
336c62f
Spelling fixes and add top channels to tutorial
ProGamerGov Dec 28, 2020
dba1e48
Optionally remove nonlinear ReLU & full expanded weights test
ProGamerGov Dec 29, 2020
83bc5b8
Use torch.norm for PyTorch 1.3.0 only
ProGamerGov Dec 29, 2020
94fe929
Improve weight vis tutorial descriptions
ProGamerGov Dec 30, 2020
690fc79
Move ignore_layer & max2avg_pool2d to utils/models
ProGamerGov Dec 31, 2020
caf0cc7
Improve weight vis tutorial
ProGamerGov Jan 1, 2021
620f1ee
Improve intro & expanded weight description
ProGamerGov Jan 1, 2021
7515fd2
Improvements
ProGamerGov Jan 2, 2021
2e86d98
Merge branch 'optim-wip' into optim-wip-circuits
ProGamerGov Jan 4, 2021
752186a
Replace round with math.ceil in CenterCrop
ProGamerGov Jan 6, 2021
4116ecd
Add optional center crop offset parameter for uneven sides
ProGamerGov Jan 6, 2021
f1f73b0
Improve center crop parameter description
ProGamerGov Jan 8, 2021
ebaacbd
Update weight vis tutorial for new center crop
ProGamerGov Jan 8, 2021
891fd97
Remove suppression of PyTorch UserWarnings
ProGamerGov Jan 9, 2021
18a9d8c
Changes based on feedback
ProGamerGov Jan 9, 2021
49bee39
Update model factory tests
ProGamerGov Jan 9, 2021
70bd89c
Merge branch 'optim-wip' of https://github.com/pytorch/captum into op…
Jan 11, 2021
899ee6f
Update weight vis tutorial with colorspace fix
ProGamerGov Jan 11, 2021
bf5b991
Improve weight vis tutorial
ProGamerGov Jan 15, 2021
9a59014
Minor fixes & improvements to tutorial notebook
ProGamerGov Jan 16, 2021
790066b
Link spatial positions back to weight heatmap
ProGamerGov Jan 16, 2021
441b751
Test version check improvements
ProGamerGov Jan 16, 2021
7212dd8
Changes based on feedback
ProGamerGov Jan 20, 2021
7071396
Fix tests and InceptionV1 model
ProGamerGov Jan 20, 2021
c141bd7
Remove non-working check
ProGamerGov Jan 20, 2021
4237205
Changes based on feedback
ProGamerGov Jan 21, 2021
58a8c3b
Remove redundant skip_layer function
ProGamerGov Jan 21, 2021
4d3c686
Re-add skip_layer function
ProGamerGov Jan 21, 2021
2570eed
Improve replace layers
ProGamerGov Jan 22, 2021
f447533
Remove param transfer from skip_layers
ProGamerGov Jan 22, 2021
d472a91
Changes based on feedback part 1
ProGamerGov Jan 24, 2021
88a88ed
Remove placeholder Any type hints & fix instance creation
ProGamerGov Jan 24, 2021
908bec0
Add type hints for layers, model, and layer instances
ProGamerGov Jan 24, 2021
075ac59
max2avg_pool2d -> replace_max_with_avgconst_pool2d
ProGamerGov Jan 24, 2021
73175c0
Improve _check_layer_in_model test
ProGamerGov Jan 24, 2021
537fe79
Remove unused type hint import
ProGamerGov Jan 24, 2021
0a8b6e2
Revert layer check test and add type hints
ProGamerGov Jan 24, 2021
866faac
Add number of expected layers to tests
ProGamerGov Jan 24, 2021
ba07685
Change _check_layer_in_model based on feedback
ProGamerGov Jan 24, 2021
130513b
Changes to tutorial based on feedback
ProGamerGov Jan 25, 2021
fae868b
Address notebook feedback - part 2
ProGamerGov Jan 25, 2021
b9cbf02
Better wording for new sentence
ProGamerGov Jan 25, 2021
4d85d13
get_expanded_weights -> extract_expanded_weights
ProGamerGov Jan 26, 2021
83921ed
Improve weight vis notebook introduction
ProGamerGov Jan 26, 2021
f45b208
Change NMF link location
ProGamerGov Jan 26, 2021
bc7fa97
Add return to _check_layer_in_model test
ProGamerGov Jan 26, 2021
162e47a
Remove return
ProGamerGov Jan 26, 2021
2fc46d0
GPU Test Fix
ProGamerGov Jan 26, 2021
31 changes: 26 additions & 5 deletions captum/optim/_param/image/transform.py
@@ -135,14 +135,22 @@ class CenterCrop(torch.nn.Module):
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
offset_left (bool, optional): If the cropped-away sides are unequal
in size, offset the crop to the left. Defaults to False, which
offsets the crop to the right. This parameter is only valid when
pixels_from_edges is False.
"""

def __init__(
self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
self,
size: IntSeqOrIntType = 0,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> None:
super(CenterCrop, self).__init__()
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges
self.offset_left = offset_left

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
@@ -153,11 +161,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
tensor (torch.Tensor): A center cropped tensor.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)
return center_crop(
input, self.crop_vals, self.pixels_from_edges, self.offset_left
)


def center_crop(
input: torch.Tensor, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
input: torch.Tensor,
crop_vals: IntSeqOrIntType,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> torch.Tensor:
"""
Center crop a specified amount from a tensor.
@@ -167,6 +180,10 @@ def center_crop(
pixels_from_edges (bool, optional): Whether to treat crop size
values as the number of pixels from the tensor's edge, or an
exact shape in the center.
offset_left (bool, optional): If the cropped-away sides are unequal
in size, offset the crop to the left. Defaults to False, which
offsets the crop to the right. This parameter is only valid when
pixels_from_edges is False.
Returns:
*tensor*: A center cropped tensor.
"""
@@ -188,8 +205,12 @@
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
h_crop = h - int(math.ceil((h - crop_vals[0]) / 2.0))
w_crop = w - int(math.ceil((w - crop_vals[1]) / 2.0))
if (h % 2 == 0 and crop_vals[0] % 2 != 0) or (h % 2 != 0 and crop_vals[0] % 2 == 0):
h_crop = h_crop + 1 if offset_left else h_crop
if (w % 2 == 0 and crop_vals[1] % 2 != 0) or (w % 2 != 0 and crop_vals[1] % 2 == 0):
w_crop = w_crop + 1 if offset_left else w_crop
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x

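To make the new behavior concrete, here is a minimal usage sketch of how offset_left decides which side loses the extra pixel when an even-sized input is cropped to an odd size. The import path matches the file shown in this diff; the toy tensor and variable names are illustrative only.

import torch

from captum.optim._param.image.transform import center_crop

x = torch.arange(36.0).reshape(1, 1, 6, 6)

# Cropping 6 -> 5 must remove one uneven row and one uneven column.
y_default = center_crop(x, [5, 5])                  # keeps rows/cols 0:5
y_left = center_crop(x, [5, 5], offset_left=True)   # keeps rows/cols 1:6

# The default drops the bottom row and rightmost column, while
# offset_left=True drops the top row and leftmost column instead.

This is why the uneven-padding test below pads with (5, 4, 5, 4) and needs offset_left=True to recover the original tensor exactly.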
75 changes: 73 additions & 2 deletions captum/optim/_utils/models.py
@@ -1,5 +1,5 @@
import math
from typing import List, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -70,8 +70,12 @@ class ReluLayer(nn.Module):
Basic Hookable & Replaceable ReLU layer.
"""

def __init__(self, inplace: bool = True) -> None:
super(ReluLayer, self).__init__()
self.inplace = inplace

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.relu(input, inplace=True)
return F.relu(input, inplace=self.inplace)
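A brief sketch of the newly configurable flag (assuming the ReluLayer defined above; the tensor and names are illustrative):

import torch

layer = ReluLayer(inplace=False)  # inplace=False leaves the input tensor untouched
x = torch.randn(2, 3)
out = layer(x)  # x keeps its negative values; out has them clamped to zero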


def replace_layers(model, old_layer=ReluLayer, new_layer=RedirectedReluLayer) -> None:
@@ -168,3 +172,70 @@ def collect_activations(
catch_activ = ActivationFetcher(model, targets)
activ_out = catch_activ(model_input)
return activ_out


def max2avg_pool2d(model, value: Optional[Any] = float("-inf")) -> None:
"""
Replace all nonlinear MaxPool2d layers with their linear AvgPool2d equivalents.
Review thread:

Contributor: It looks like here we also do replacement of layers, but don't use the replace_layers function?

ProGamerGov (Author): Yeah, I had to make a separate function because the replace_layers function doesn't copy the layer parameters.

Contributor: replace_layers is a generic name; if it doesn't replace any layer, then we can probably give it a more specific name, or actually try to access and copy all parameters from one module to another.

ProGamerGov (Author): Okay, I think I've figured out how to replace layers while attempting to copy over any parameters that are shared!

ProGamerGov (Author), Jan 22, 2021: @NarineK Let me know what you think of the new replace_layers function! As long as vars(layer1) matches the layer2 __init__ function signature, parameters are transferred.
This allows us to ignore nonlinear values when calculating expanded weights.

Args:
model (nn.Module): A PyTorch model instance.
value (Any, optional): A padding value that the pooling layers are meant
to ignore; outputs equal to this value are set back to zero after pooling.
"""

class AvgPool2dLayer(torch.nn.Module):
def __init__(
self,
kernel_size: Union[int, Tuple[int, ...]] = 2,
stride: Optional[Union[int, Tuple[int, ...]]] = 2,
padding: Union[int, Tuple[int, ...]] = 0,
ceil_mode: bool = False,
value: Optional[Any] = None,
) -> None:
super().__init__()
self.avgpool = torch.nn.AvgPool2d(
kernel_size=kernel_size,
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
)
self.value = value

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.avgpool(x)
if self.value is not None:
x[x == self.value] = 0.0
return x

for name, child in model._modules.items():
if isinstance(child, torch.nn.MaxPool2d):
new_layer = AvgPool2dLayer(
kernel_size=child.kernel_size,
stride=child.stride,
padding=child.padding,
ceil_mode=child.ceil_mode,
value=value,
)
setattr(model, name, new_layer)
elif child is not None:
max2avg_pool2d(child, value=value)
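A usage sketch, mirroring the unit test added later in this PR, with -inf padding standing in for values the pooling should ignore (the toy model and names are illustrative; in a standalone script max2avg_pool2d would be imported from captum.optim._utils.models):

import torch
import torch.nn.functional as F

model = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=3, stride=2))
max2avg_pool2d(model)  # the MaxPool2d is now an AvgPool2dLayer

x = torch.randn(1, 32, 16, 16)
x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))  # padding meant to be ignored
out = model(x)  # windows that averaged in -inf are reset to 0.0 by AvgPool2dLayer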


def ignore_layer(model, layer) -> None:
Review thread:

Contributor: nit: type hints for the args are missing?
nit: skip_layer(model, layer)?
Currently we don't use it, right? I think we need to find the best way of exposing replace_layers, skip_layer, and similar helper functions if they are going to be used by the user. Are these functions called externally? If the functions are private, they should be underscored.

ProGamerGov (Author), Jan 21, 2021: Yeah, both replace_layers and skip_layer are meant to help users get their model ready for Captum in some way (adding redirected ReLU, removing nonlinear layers, etc.).

NarineK (Contributor), Jan 21, 2021: @ProGamerGov, do we already have a skip_layer function somewhere else? My suggestion was to rename ignore_layer to skip_layer because it might sound better.

ProGamerGov (Author), Jan 21, 2021: @NarineK I removed the skip_layer function, as it was just replace_layers with SkipLayer as the layer argument. I can re-add it if you think it's useful to keep, so that users don't have to import both replace_layers and SkipLayer to accomplish the same task. I initially thought that it wasn't generic enough, but it's pretty generic.

ProGamerGov (Author), Jan 21, 2021: I think it would be useful to keep, so I'll re-add it. The task of replacing ReLU layers with SkipLayers seems common enough to warrant being its own function.

Contributor: @ProGamerGov, it's good to reduce the redundancy. I thought that "skip" is a better synonym for "ignore". I didn't see that you had originally removed a function called skip_layer.
"""
Replace target layers with layers that do nothing.
This is useful for removing the nonlinear ReLU
layers when creating expanded weights.

Args:
model (nn.Module): A PyTorch model instance.
layer (nn.Module): A layer class type.
"""

class IgnoreLayer(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x

replace_layers(model, layer, IgnoreLayer)
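A short usage sketch: ignore_layer is just replace_layers specialized to a pass-through layer, and per the review discussion above, replace_layers transfers parameters whenever vars(old_layer) matches the new layer's __init__ signature. The import path matches this file (and the test imports below); the toy model is illustrative only.

import torch

from captum.optim._utils.models import ignore_layer

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
)
ignore_layer(model, torch.nn.ReLU)  # the ReLU is now an identity pass-through

x = torch.randn(1, 3, 8, 8)
assert torch.equal(model(x), model[0](x))  # negatives are no longer clamped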
32 changes: 27 additions & 5 deletions tests/optim/helpers/numpy_transforms.py
@@ -1,3 +1,4 @@
import math
from typing import List, Optional, Tuple, Union, cast

import numpy as np
@@ -63,14 +64,22 @@ class CenterCrop:
pixels_from_edges (bool, optional): Whether to treat crop size values
as the number of pixels from the tensor's edge, or an exact shape
in the center.
offset_left (bool, optional): If the cropped-away sides are unequal
in size, offset the crop to the left. Defaults to False, which
offsets the crop to the right. This parameter is only valid when
pixels_from_edges is False.
"""

def __init__(
self, size: IntSeqOrIntType = 0, pixels_from_edges: bool = False
self,
size: IntSeqOrIntType = 0,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> None:
super(CenterCrop, self).__init__()
self.crop_vals = size
self.pixels_from_edges = pixels_from_edges
self.offset_left = offset_left

def forward(self, input: np.ndarray) -> np.ndarray:
"""
@@ -81,11 +90,16 @@ def forward(self, input: np.ndarray) -> np.ndarray:
array (np.ndarray): A center cropped array.
"""

return center_crop(input, self.crop_vals, self.pixels_from_edges)
return center_crop(
input, self.crop_vals, self.pixels_from_edges, self.offset_left
)


def center_crop(
input: np.ndarray, crop_vals: IntSeqOrIntType, pixels_from_edges: bool = False
input: np.ndarray,
crop_vals: IntSeqOrIntType,
pixels_from_edges: bool = False,
offset_left: bool = False,
) -> np.ndarray:
"""
Center crop a specified amount from an array.
@@ -95,6 +109,10 @@ def center_crop(
pixels_from_edges (bool, optional): Whether to treat crop size values
as the number of pixels from the array's edge, or an exact shape
in the center.
offset_left (bool, optional): If the cropped-away sides are unequal
in size, offset the crop to the left. Defaults to False, which
offsets the crop to the right. This parameter is only valid when
pixels_from_edges is False.
Returns:
*array*: A center cropped array.
"""
@@ -116,8 +134,12 @@
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
h_crop = h - int(round((h - crop_vals[0]) / 2.0))
w_crop = w - int(round((w - crop_vals[1]) / 2.0))
h_crop = h - int(math.ceil((h - crop_vals[0]) / 2.0))
w_crop = w - int(math.ceil((w - crop_vals[1]) / 2.0))
if h % 2 == 0 and crop_vals[0] % 2 != 0 or h % 2 != 0 and crop_vals[0] % 2 == 0:
h_crop = h_crop + 1 if offset_left else h_crop
if w % 2 == 0 and crop_vals[1] % 2 != 0 or w % 2 != 0 and crop_vals[1] % 2 == 0:
w_crop = w_crop + 1 if offset_left else w_crop
x = input[..., h_crop - crop_vals[0] : h_crop, w_crop - crop_vals[1] : w_crop]
return x

50 changes: 42 additions & 8 deletions tests/optim/param/test_transforms.py
@@ -189,11 +189,11 @@ def test_center_crop_one_number_exact(self) -> None:
[
torch.tensor(
[
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
]
)
]
@@ -224,6 +224,24 @@ def test_center_crop_two_numbers_exact(self) -> None:
).unsqueeze(0)
assertTensorAlmostEqual(self, cropped_tensor, expected_tensor, 0)

def test_center_crop_offset_left_uneven_sides(self) -> None:
crop_mod = transform.CenterCrop(
[5, 5], pixels_from_edges=False, offset_left=True
)
x = torch.ones(1, 3, 5, 5)
px = F.pad(x, (5, 4, 5, 4), value=float("-inf"))
cropped_tensor = crop_mod(px)
assertTensorAlmostEqual(self, x, cropped_tensor)

def test_center_crop_offset_left_even_sides(self) -> None:
crop_mod = transform.CenterCrop(
[5, 5], pixels_from_edges=False, offset_left=True
)
x = torch.ones(1, 3, 5, 5)
px = F.pad(x, (5, 5, 5, 5), value=float("-inf"))
cropped_tensor = crop_mod(px)
assertTensorAlmostEqual(self, x, cropped_tensor)


class TestCenterCropFunction(BaseTest):
def test_center_crop_one_number(self) -> None:
@@ -286,11 +304,11 @@ def test_center_crop_one_number_exact(self) -> None:
[
torch.tensor(
[
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0],
]
)
]
@@ -319,6 +337,22 @@ def test_center_crop_two_numbers_exact(self) -> None:
).unsqueeze(0)
assertTensorAlmostEqual(self, cropped_tensor, expected_tensor)

def test_center_crop_offset_left_uneven_sides(self) -> None:
x = torch.ones(1, 3, 5, 5)
px = F.pad(x, (5, 4, 5, 4), value=float("-inf"))
cropped_tensor = transform.center_crop(
px, crop_vals=[5, 5], pixels_from_edges=False, offset_left=True
)
assertTensorAlmostEqual(self, x, cropped_tensor)

def test_center_crop_offset_left_even_sides(self) -> None:
x = torch.ones(1, 3, 5, 5)
px = F.pad(x, (5, 5, 5, 5), value=float("-inf"))
cropped_tensor = transform.center_crop(
px, crop_vals=[5, 5], pixels_from_edges=False, offset_left=True
)
assertTensorAlmostEqual(self, x, cropped_tensor)


class TestBlendAlpha(BaseTest):
def test_blend_alpha(self) -> None:
32 changes: 32 additions & 0 deletions tests/optim/utils/test_circuits.py
@@ -5,6 +5,7 @@

import captum.optim._utils.circuits as circuits
from captum.optim._models.inception_v1 import googlenet
from captum.optim._utils.models import RedirectedReluLayer, ignore_layer, max2avg_pool2d
from tests.helpers.basic import BaseTest


@@ -45,6 +46,37 @@ def test_get_expanded_weights_crop_two_int(self) -> None:
)
self.assertEqual(list(output_tensor.shape), [480, 256, 5, 5])

def test_get_expanded_nonlinear_top_connections(self) -> None:
if torch.__version__ == "1.2.0":
raise unittest.SkipTest(
"Skipping get_expanded_weights nonlinear_top_connections test"
+ " due to insufficient Torch version."
)

if torch.__version__ == "1.3.0":
norm_func = torch.norm
else:
norm_func = torch.linalg.norm
model = googlenet(pretrained=True)
max2avg_pool2d(model)
ignore_layer(model, RedirectedReluLayer)
output_tensor = circuits.get_expanded_weights(
model, model.pool3, model.mixed4a, 5
)
self.assertEqual(list(output_tensor.shape), [508, 480, 5, 5])

top_connected_neurons = torch.argsort(
torch.stack(
[
-norm_func(output_tensor[i, 379, :, :])
for i in range(output_tensor.shape[0])
]
)
)[:10].tolist()

expected_list = [50, 437, 96, 398, 434, 423, 408, 436, 424, 168]
self.assertEqual(top_connected_neurons, expected_list)


if __name__ == "__main__":
unittest.main()
28 changes: 28 additions & 0 deletions tests/optim/utils/test_model_utils.py
@@ -317,5 +317,33 @@ def test_collect_activations(self) -> None:
self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14])


class TestMax2AvgPool2d(BaseTest):
def test_max2avg_pool2d(self) -> None:
model = torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
)

model_utils.max2avg_pool2d(model)

test_tensor = torch.randn(128, 32, 16, 16)
test_tensor = F.pad(test_tensor, (0, 1, 0, 1), value=float("-inf"))
out_tensor = model(test_tensor)

avg_pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=0)
expected_tensor = avg_pool(test_tensor)
expected_tensor[expected_tensor == float("-inf")] = 0.0

assertTensorAlmostEqual(self, out_tensor, expected_tensor, 0)


class TestIgnoreLayer(BaseTest):
def test_ignore_layer(self) -> None:
model = torch.nn.Sequential(torch.nn.ReLU())
x = torch.randn(1, 3, 4, 4)
model_utils.ignore_layer(model, torch.nn.ReLU)
output_tensor = model(x)
assertTensorAlmostEqual(self, x, output_tensor, 0)


if __name__ == "__main__":
unittest.main()