Merged
Changes from all commits
77 commits
1d0d3ae
Add SharedImage, RGBToBGR, & ScaleInputRange
ProGamerGov Nov 28, 2020
5802844
Fix class formatting
ProGamerGov Nov 28, 2020
334ab79
Improvements to SharedImage & more asserts
ProGamerGov Nov 30, 2020
6f26c1c
Fix lint error
ProGamerGov Nov 30, 2020
a1f51f9
Add offset parameter to SharedImage
ProGamerGov Dec 1, 2020
98a5f7c
Add missing type hint
ProGamerGov Dec 1, 2020
ae94426
Fix type hint
ProGamerGov Dec 1, 2020
6baf6c1
Correct type hint
ProGamerGov Dec 1, 2020
29148da
Remove Optional from type hint
ProGamerGov Dec 1, 2020
0e327cd
Fix size mismatch when FFTImage init height > width
ProGamerGov Dec 8, 2020
eaaedcd
Merge branch 'optim-wip' into optim-wip
ProGamerGov Dec 9, 2020
ac4954c
Add tests for SharedImage & 2 new transforms
ProGamerGov Dec 9, 2020
6510d8f
Preliminary circuit functions, fixes, & param name change
ProGamerGov Dec 11, 2020
3df40bc
Add ChannelReducer
ProGamerGov Dec 11, 2020
0f58272
Fix test name
ProGamerGov Dec 11, 2020
e5ede3a
Reorganize new code
ProGamerGov Dec 12, 2020
f54a989
Move ActivationCatcher test
ProGamerGov Dec 12, 2020
a3e1e60
Forgot to move import
ProGamerGov Dec 12, 2020
3da14cb
Fix import order
ProGamerGov Dec 12, 2020
8ed6cd3
Fix import
ProGamerGov Dec 12, 2020
43ed093
Update inception_v1.py
ProGamerGov Dec 12, 2020
1c4f533
Update inception_v1.py
ProGamerGov Dec 12, 2020
c783a35
Fix model class & improvements
ProGamerGov Dec 13, 2020
443ea1d
Fix model lint
ProGamerGov Dec 13, 2020
55610ad
Fix layer order in InceptionModule & fix expanded weights
ProGamerGov Dec 13, 2020
cf88521
Add weight visualization tutorial
ProGamerGov Dec 13, 2020
50ac26f
Change some tutorial text
ProGamerGov Dec 13, 2020
206880d
Improve ChannelReducer
ProGamerGov Dec 13, 2020
745d322
Fixes & improvements
ProGamerGov Dec 13, 2020
129f6b5
Improve weight visualization tutorial
ProGamerGov Dec 14, 2020
3e17421
Fix some objective call functions
ProGamerGov Dec 14, 2020
7c4bd96
Additional ChannelReducer tests and CustomModule tutorial
ProGamerGov Dec 14, 2020
e205f5c
Move posneg to reducer
ProGamerGov Dec 14, 2020
024b87e
Remove duplicate class
ProGamerGov Dec 15, 2020
bd1b4ba
Add missing type hints & tutorial improvements
ProGamerGov Dec 15, 2020
b7c067c
Add ability to hide progress bar & weight vis update
ProGamerGov Dec 16, 2020
8b03cc2
Make tqdm optional
ProGamerGov Dec 16, 2020
6bb6aa1
Change description based on comment
ProGamerGov Dec 17, 2020
40e26e6
Fixes & improvements
ProGamerGov Dec 17, 2020
61cb050
Linting
ProGamerGov Dec 17, 2020
58d2c5a
Update test_reducer.py
ProGamerGov Dec 17, 2020
a240d10
Changes to Custom Modules tutorial based on feedback
ProGamerGov Dec 17, 2020
bbfd07d
Remove set_image functions
ProGamerGov Dec 17, 2020
c3b1359
Minor correction & remove PyTorch UserWarnings
ProGamerGov Dec 17, 2020
2520d1c
Add simple n_channels to RGB function
ProGamerGov Dec 18, 2020
ff7cdff
Implement changes based on feedback
ProGamerGov Dec 19, 2020
a276a8c
Implement additional changes based on feedback
ProGamerGov Dec 19, 2020
1f18010
Fix neuron objectives
ProGamerGov Dec 20, 2020
4ae005e
DirectionNeuron -> NeuronDirection
ProGamerGov Dec 21, 2020
995eb36
Fix 4 dimensional reflection padding
ProGamerGov Dec 21, 2020
50a726f
Implement first batch of changes based on feedback
ProGamerGov Dec 22, 2020
6235a2f
Fix lint errors
ProGamerGov Dec 22, 2020
2142dd7
Fix test
ProGamerGov Dec 22, 2020
c1a0b3d
Second batch of changes based on feedback
ProGamerGov Dec 22, 2020
eae361d
Fix flake8 error
ProGamerGov Dec 22, 2020
51886e1
Cast outputs to avoid Mypy error
ProGamerGov Dec 22, 2020
b828cbc
Third batch of changes based on feedback
ProGamerGov Dec 22, 2020
123e70a
Fix lint & test errors
ProGamerGov Dec 22, 2020
edeba3b
Remove 'object' reference from ChannelReducer
ProGamerGov Dec 22, 2020
859e6c4
Replace 4D reflect pad with NumPy symmetric pad + autograd
ProGamerGov Dec 22, 2020
c3dafad
Remove NumPy import
ProGamerGov Dec 22, 2020
ff63e1d
Remove old function import
ProGamerGov Dec 22, 2020
80c0e8f
Update offset test code
ProGamerGov Dec 22, 2020
c8c10dc
Reorganize utils
ProGamerGov Dec 23, 2020
58cefae
Update transform.py
ProGamerGov Dec 23, 2020
dd3fd43
Update test_transforms.py
ProGamerGov Dec 23, 2020
fb82ac4
Changes based on feedback
ProGamerGov Dec 23, 2020
52466c9
Fix lint errors
ProGamerGov Dec 23, 2020
7f0f4b5
Fix imports
ProGamerGov Dec 23, 2020
fe44e49
Remove unused code from weight-viz tutorial
ProGamerGov Dec 23, 2020
596b5f6
Make posneg function use F.relu
ProGamerGov Dec 23, 2020
17d7516
Changes based on feedback
ProGamerGov Dec 23, 2020
9b095bd
Fix lint errors
ProGamerGov Dec 23, 2020
987e28e
More changes based on feedback
ProGamerGov Dec 23, 2020
7b34517
Update test_transforms.py (#144)
ProGamerGov Dec 23, 2020
5e07a8f
Fix asserts
ProGamerGov Dec 23, 2020
c8d90b4
Oops: aelf -> self
ProGamerGov Dec 23, 2020
4 changes: 3 additions & 1 deletion captum/optim/__init__.py
@@ -6,4 +6,6 @@
from captum.optim._param.image import images # noqa: F401
from captum.optim._param.image import transform # noqa: F401
from captum.optim._param.image.images import ImageTensor # noqa: F401
-from captum.optim._utils import models # noqa: F401
+from captum.optim._utils import circuits, models, reducer # noqa: F401
+from captum.optim._utils.image.common import nchannels_to_rgb # noqa: F401
+from captum.optim._utils.image.common import weights_to_heatmap_2d # noqa: F401
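The two image helpers exported here back the new weight-visualization tutorial. A minimal usage sketch, assuming the shape contracts of later captum releases (the printed shapes are assumptions, not taken from this PR):

import torch
from captum.optim import nchannels_to_rgb, weights_to_heatmap_2d

# Collapse a 6-channel activation map into 3 displayable RGB channels
# (assumed NCHW in, NCHW out with C == 3).
acts = torch.randn(1, 6, 32, 32)
rgb = nchannels_to_rgb(acts)
print(rgb.shape)  # torch.Size([1, 3, 32, 32]) (assumed)

# Render a 2D weight matrix as an RGB heatmap tensor (assumed 3 x H x W out).
w = torch.randn(16, 16)
heatmap = weights_to_heatmap_2d(w)
print(heatmap.shape)  # torch.Size([3, 16, 16]) (assumed)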
34 changes: 18 additions & 16 deletions captum/optim/_core/loss.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

-from captum.optim._utils.images import get_neuron_pos
+from captum.optim._utils.image.common import get_neuron_pos
from captum.optim._utils.typing import ModuleOutputMapping


@@ -66,18 +66,16 @@ def __init__(
self.x = x
self.y = y

-# ensure channel_index will be valid
-assert self.channel_index < self.target.out_channels

-def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
assert activations is not None
assert self.channel_index < activations.shape[1]
assert len(activations.shape) == 4 # assume NCHW
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)

-return activations[:, self.channel_index, _x, _y]
+return activations[:, self.channel_index, _x : _x + 1, _y : _y + 1]


class DeepDream(Loss):
@@ -98,7 +96,7 @@ class TotalVariation(Loss):
https://arxiv.org/abs/1412.0035
"""

-def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
x_diff = activations[..., 1:, :] - activations[..., :-1, :]
y_diff = activations[..., :, 1:] - activations[..., :, :-1]
@@ -115,7 +113,7 @@ def __init__(self, target: nn.Module, constant: float = 0.0) -> None:
self.target = target
self.constant = constant

-def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
return torch.abs(activations - self.constant).sum()

@@ -132,7 +130,7 @@ def __init__(
self.constant = constant
self.epsilon = epsilon

-def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = (activations - self.constant).sum()
return torch.sqrt(self.epsilon + activations)
@@ -145,7 +143,7 @@ class Diversity(Loss):
https://distill.pub/2017/feature-visualization/#diversity
"""

-def _call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
return -sum(
[
@@ -260,7 +258,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
return torch.cosine_similarity(self.direction, activations)


-class DirectionNeuron(Loss):
+class NeuronDirection(Loss):
"""
Visualize a single (x, y) position for a direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
@@ -271,16 +269,16 @@ def __init__(
self,
target: nn.Module,
vec: torch.Tensor,
-channel_index: int,
x: Optional[int] = None,
y: Optional[int] = None,
+channel_index: Optional[int] = None,
) -> None:
super(Loss, self).__init__()
self.target = target
self.direction = vec.reshape((1, -1, 1, 1))
-self.channel_index = channel_index
self.x = x
self.y = y
+self.channel_index = channel_index

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
@@ -290,8 +288,10 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
-activations = activations[:, self.channel_index, _x, _y]
-return torch.cosine_similarity(self.direction, activations[None, None, None])
+activations = activations[:, :, _x : _x + 1, _y : _y + 1]
+if self.channel_index is not None:
+activations = activations[:, self.channel_index, ...][:, None, ...]
+return torch.cosine_similarity(self.direction, activations)


class TensorDirection(Loss):
@@ -361,7 +361,9 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
-activations = activations[..., _x, _y].squeeze() * self.weights
+activations = (
+activations[..., _x : _x + 1, _y : _y + 1].squeeze() * self.weights
+)
else:
activations = activations[
..., self.y : self.y + self.wy, self.x : self.x + self.wx
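The recurring change in this file, from integer indexing to one-element slices, keeps the selected activations 4D so they broadcast against the (1, C, 1, 1) direction vector. A standalone illustration of the shape difference (plain PyTorch, not code from this PR):

import torch

acts = torch.randn(1, 8, 5, 5)  # NCHW activations
vec = torch.randn(1, 8, 1, 1)   # direction vector, reshaped as in NeuronDirection

# Integer indexing drops the indexed dimensions entirely:
print(acts[:, :, 2, 2].shape)  # torch.Size([1, 8])

# One-element slices keep the tensor 4D, so cosine_similarity can
# reduce over the channel dimension (dim=1 by default):
sliced = acts[:, :, 2:3, 2:3]  # torch.Size([1, 8, 1, 1])
print(torch.cosine_similarity(vec, sliced).shape)  # torch.Size([1, 1, 1])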
24 changes: 18 additions & 6 deletions captum/optim/_core/objectives.py
@@ -10,7 +10,10 @@
try:
from tqdm.auto import tqdm
except (ImportError, AssertionError):
print("The tqdm package is required to use Captum's Optim library")
print(
"The tqdm package is required to use captum.optim's"
+ " n_steps stop criteria with progress bar"
)

from captum.optim._core.output_hook import AbortForwardException, ModuleOutputsHook
from captum.optim._param.image.images import InputParameterization, NaturalImage
@@ -147,23 +150,32 @@ def optimize(
return history


-def n_steps(n: int) -> StopCriteria:
+def n_steps(n: int, show_progress: bool = True) -> StopCriteria:
"""StopCriteria generator that uses number of steps as a stop criteria.
Args:
n (int): Number of steps to run optimization.
+show_progress (bool, optional): Whether or not to show progress bar.
+Default: True
Returns:
*StopCriteria* callable
"""
pbar = tqdm(total=n, unit="step")

+if show_progress:
+pbar = tqdm(total=n, unit=" step")

def continue_while(step, obj, history, optim) -> bool:
if len(history) > 0:
pbar.set_postfix({"Objective": f"{history[-1].mean():.1f}"}, refresh=False)
if show_progress:
pbar.set_postfix(
{"Objective": f"{history[-1].mean():.1f}"}, refresh=False
)
if step < n:
-pbar.update()
+if show_progress:
+pbar.update()
return True
else:
-pbar.close()
+if show_progress:
+pbar.close()
return False

return continue_while
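With tqdm now optional, the show_progress flag is what lets n_steps run without it. A minimal sketch of the stop-criteria contract, using only the module path shown in this diff:

from captum.optim._core.objectives import n_steps

# show_progress=False never touches tqdm, exercising the
# optional-dependency path added above.
criteria = n_steps(64, show_progress=False)

# The returned callable follows the StopCriteria signature used by the
# optimize method above: (step, obj, history, optim) -> bool.
step = 0
while criteria(step, None, [], None):
    step += 1
print(step)  # 64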
25 changes: 23 additions & 2 deletions captum/optim/_core/output_hook.py
@@ -1,9 +1,10 @@
-from typing import Iterable
+from contextlib import suppress
+from typing import Iterable, List, Union
from warnings import warn

import torch.nn as nn

-# from clarity.pytorch import ModuleOutputMapping
+from captum.optim._utils.typing import ModelInputType, ModuleOutputMapping


class AbortForwardException(Exception):
@@ -88,3 +89,23 @@ def remove_hooks(self) -> None:
def __del__(self) -> None:
# print(f"DEL HOOKS!: {list(self.outputs.keys())}")
self.remove_hooks()
+
+
+class ActivationFetcher:
+"""
+Simple module for collecting activations from model targets.
+"""
+
+def __init__(self, model, targets: Union[nn.Module, List[nn.Module]]) -> None:
+super(ActivationFetcher, self).__init__()
+self.model = model
+self.layers = ModuleOutputsHook(targets)
+
+def __call__(self, input_t: ModelInputType) -> ModuleOutputMapping:
+try:
+with suppress(AbortForwardException):
+self.model(input_t)
+activations = self.layers.consume_outputs()
+finally:
+self.layers.remove_hooks()
+return activations
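ActivationFetcher turns the hook machinery into a one-shot callable: register hooks on the targets, run a forward pass while suppressing the AbortForwardException the hooks raise, and return the captured outputs. A hedged usage sketch; the torchvision model is a stand-in, and the list form of targets assumes ModuleOutputsHook accepts what the type hint advertises:

import torch
import torchvision.models as models
from captum.optim._core.output_hook import ActivationFetcher

net = models.alexnet(pretrained=False).eval()
target = net.features[3]  # an intermediate Conv2d layer

fetch = ActivationFetcher(net, targets=[target])
activations = fetch(torch.randn(1, 3, 224, 224))
print(activations[target].shape)  # torch.Size([1, 192, 27, 27])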
14 changes: 7 additions & 7 deletions captum/optim/_models/inception_v1.py
@@ -67,7 +67,6 @@ def __init__(
out_channels=64,
kernel_size=(7, 7),
stride=(2, 2),
-padding=(3, 3),
groups=1,
bias=True,
)
@@ -89,7 +88,6 @@
out_channels=192,
kernel_size=(3, 3),
stride=(1, 1),
-padding=(1, 1),
groups=1,
bias=True,
)
@@ -132,6 +130,7 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._transform_input(x)
+x = F.pad(x, (2, 3, 2, 3))
x = self.conv1(x)
x = self.conv1_relu(x)
x = F.pad(x, (0, 1, 0, 1), value=float("-inf"))
@@ -140,6 +139,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

x = self.conv2(x)
x = self.conv2_relu(x)
+x = F.pad(x, (1, 1, 1, 1))
x = self.conv3(x)
x = self.conv3_relu(x)
x = self.localresponsenorm2(x)
@@ -214,7 +214,6 @@ def __init__(
out_channels=c3x3,
kernel_size=(3, 3),
stride=(1, 1),
-padding=(1, 1),
groups=1,
bias=True,
)
@@ -234,7 +233,6 @@
out_channels=c5x5,
kernel_size=(5, 5),
stride=(1, 1),
-padding=(2, 2),
groups=1,
bias=True,
)
@@ -257,18 +255,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

c3x3 = self.conv_3x3_reduce(x)
c3x3 = self.conv_3x3_reduce_relu(c3x3)
+c3x3 = F.pad(c3x3, (1, 1, 1, 1))
c3x3 = self.conv_3x3(c3x3)
c3x3 = self.conv_3x3_relu(c3x3)

c5x5 = self.conv_5x5_reduce(x)
c5x5 = self.conv_5x5_reduce_relu(c5x5)
+c5x5 = F.pad(c5x5, (2, 2, 2, 2))
c5x5 = self.conv_5x5(c5x5)
c5x5 = self.conv_5x5_relu(c5x5)

-px = self.pool_proj(x)
-px = self.pool_proj_relu(px)
-px = F.pad(px, (1, 1, 1, 1), value=float("-inf"))
+px = F.pad(x, (1, 1, 1, 1), value=float("-inf"))
px = self.pool(px)
+px = self.pool_proj(px)
+px = self.pool_proj_relu(px)
return torch.cat([c1x1, c3x3, c5x5, px], dim=1)
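Replacing the convolutions' built-in padding with explicit F.pad calls allows the asymmetric scheme (2 before, 3 after for the stride-2 7x7 conv1) that TensorFlow's SAME padding produces, presumably so the ported GoogLeNet weights stay aligned with the original graph. A standalone sketch of the size arithmetic (not code from this PR):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 224, 224)
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2)  # no built-in padding

# SAME padding for kernel 7, stride 2 on a 224 input needs 5 extra pixels
# per spatial dim: (112 - 1) * 2 + 7 - 224 = 5, split 2 before / 3 after.
y = conv(F.pad(x, (2, 3, 2, 3)))
print(y.shape)  # torch.Size([1, 64, 112, 112]), i.e. ceil(224 / 2)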

