
Commit b02a9f9

Merge pull request #496 from greentfrapp/optim-wip

Preliminary PR on optim-wip

2 parents bc4cd67 + b54e23d, commit b02a9f9

29 files changed: +975 -882 lines

captum/optim/__init__.py

File mode changed from 100644 to 100755

Lines changed: 8 additions & 27 deletions
@@ -1,27 +1,8 @@
-from typing import Dict, Optional, Union, Callable, Iterable
-from typing_extensions import Protocol
-
-import torch
-import torch.nn as nn
-
-ParametersForOptimizers = Iterable[Union[torch.Tensor, Dict[str, torch.tensor]]]
-
-
-class HasLoss(Protocol):
-    def loss(self) -> torch.Tensor:
-        ...
-
-
-class Parameterized(Protocol):
-    parameters: ParametersForOptimizers
-
-
-class Objective(Parameterized, HasLoss):
-    def cleanup(self):
-        pass
-
-
-ModuleOutputMapping = Dict[nn.Module, Optional[torch.Tensor]]
-
-StopCriteria = Callable[[int, Objective, torch.optim.Optimizer], bool]
-
+"""optim submodule."""
+
+import captum.optim._core.objectives as objectives  # noqa: F401
+import captum.optim._param.image.images as images  # noqa: F401
+import captum.optim._param.image.transform as transform  # noqa: F401
+import captum.optim._utils.typing as typing  # noqa: F401
+from captum.optim._core.objectives import InputOptimization  # noqa: F401
+from captum.optim._param.image.images import ImageTensor  # noqa: F401
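
The rewritten __init__.py flattens the public API onto the package root. A minimal sketch of the resulting import surface, using only the names re-exported above (the alias name is illustrative):

import captum.optim as optimviz  # alias is illustrative

# Flat re-exports from the new __init__.py:
optimviz.InputOptimization  # captum.optim._core.objectives.InputOptimization
optimviz.ImageTensor        # captum.optim._param.image.images.ImageTensor

# The aliased submodules are reachable as attributes as well:
optimviz.objectives
optimviz.images
optimviz.transform
optimviz.typing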
File renamed without changes.

captum/optim/_core/objectives.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@

"""captum.optim.objectives."""

from contextlib import suppress
from typing import Callable, Iterable, List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from captum.optim._core.output_hook import AbortForwardException, ModuleOutputsHook
from captum.optim._param.image.images import InputParameterization, NaturalImage
from captum.optim._param.image.transform import RandomAffine
from captum.optim._utils.typing import (
    LossFunction,
    ModuleOutputMapping,
    Objective,
    Parameterized,
    SingleTargetLossFunction,
    StopCriteria,
)


class InputOptimization(Objective, Parameterized):
    """
    Core class that optimizes an input to maximize a target (aka objective).
    This is similar to gradient-based methods for adversarial examples, such
    as FGSM. The code for this was based on the implementation by the authors
    of Lucid. For more details, see the following:
        https://github.com/tensorflow/lucid
        https://distill.pub/2017/feature-visualization/
    """

    def __init__(
        self,
        model: nn.Module,
        input_param: Optional[InputParameterization],
        transform: Optional[nn.Module],
        target_modules: Iterable[nn.Module],
        loss_function: LossFunction,
    ):
        r"""
        Args:
            model (nn.Module): The reference to the PyTorch model instance.
            input_param (nn.Module, optional): A module that generates an input,
                        consumed by the model.
            transform (nn.Module, optional): A module that transforms or
                        preprocesses the input before it is passed to the model.
            target_modules (iterable of nn.Module): A list of targets (objectives)
                        that are used to compute the loss function.
            loss_function (callable): The loss function to minimize during
                        optimization.
        """
        self.model = model
        self.hooks = ModuleOutputsHook(target_modules)
        self.input_param = input_param or NaturalImage((224, 224))
        self.transform = transform or RandomAffine(scale=True, translate=True)
        self.loss_function = loss_function

    def loss(self) -> torch.Tensor:
        r"""Compute the loss value for the current iteration.
        Returns:
            *tensor* representing the **loss**:
            - **loss** (*tensor*):
                        Size of the tensor corresponds to the targets passed.
        """
        image = self.input_param()._t[None, ...]

        if self.transform:
            image = self.transform(image)

        with suppress(AbortForwardException):
            _unreachable = self.model(image)  # noqa: F841

        # consume_outputs returns the captured values and resets the hook's state
        module_outputs = self.hooks.consume_outputs()
        loss_value = self.loss_function(module_outputs)
        return loss_value

    def cleanup(self):
        r"""Garbage collection, mainly removing hooks."""
        self.hooks.remove_hooks()

    # Targets are managed by ModuleOutputsHook; we mainly just want a
    # convenient setter
    @property
    def targets(self):
        return self.hooks.targets

    @targets.setter
    def targets(self, value):
        self.hooks.remove_hooks()
        self.hooks = ModuleOutputsHook(value)

    def parameters(self):
        return self.input_param.parameters()

    def optimize(
        self,
        stop_criteria: Optional[StopCriteria] = None,
        optimizer: Optional[optim.Optimizer] = None,
    ):
        r"""Optimize the input based on the loss function and objectives.
        Args:
            stop_criteria (StopCriteria, optional): A function that is called
                        every iteration and returns a bool that determines
                        whether to stop the optimization.
                        See captum.optim.typing.StopCriteria for details.
            optimizer (Optimizer, optional): A torch.optim.Optimizer used to
                        optimize the input based on the loss function.
        Returns:
            *list* of *np.arrays* representing the **history**:
            - **history** (*list*):
                        A list of loss values per iteration. Length of the list
                        corresponds to the number of iterations.
        """
        stop_criteria = stop_criteria or n_steps(1024)
        optimizer = optimizer or optim.Adam(self.parameters(), lr=0.025)
        assert isinstance(optimizer, optim.Optimizer)

        history = []
        step = 0
        while stop_criteria(step, self, history, optimizer):
            optimizer.zero_grad()
            loss_value = self.loss()
            history.append(loss_value.cpu().detach().numpy())
            # maximize the objective by minimizing its negation
            (-1 * loss_value.mean()).backward()
            optimizer.step()
            step += 1

        self.cleanup()
        return history


def n_steps(n: int) -> StopCriteria:
    """StopCriteria generator that uses the number of steps as the stop criterion.
    Args:
        n (int): Number of steps to run the optimization for.
    Returns:
        *StopCriteria* callable
    """
    pbar = tqdm(total=n, unit="step")

    def continue_while(step, obj, history, optim):
        if len(history) > 0:
            pbar.set_postfix({"Objective": f"{history[-1].mean():.1f}"}, refresh=False)
        if step < n:
            pbar.update()
            return True
        else:
            pbar.close()
            return False

    return continue_while


def channel_activation(target: nn.Module, channel_index: int) -> LossFunction:
    def loss_function(targets_to_values: ModuleOutputMapping):
        activations = targets_to_values[target]
        assert activations is not None
        # ensure channel_index is valid
        assert channel_index < activations.shape[1]
        # assume NCHW
        # NOTE: not necessarily true e.g. for Linear layers
        # assert len(activations.shape) == 4
        return activations[:, channel_index, ...]

    return loss_function


def neuron_activation(
    target: nn.Module,
    channel_index: int,
    x: Optional[int] = None,
    y: Optional[int] = None,
) -> LossFunction:
    # ensure channel_index will be valid
    assert channel_index < target.out_channels

    def loss_function(targets_to_values: ModuleOutputMapping):
        activations = targets_to_values[target]
        assert activations is not None
        assert len(activations.shape) == 4  # assume NCHW
        _, _, H, W = activations.shape

        if x is None:
            _x = W // 2
        else:
            assert x < W
            _x = x

        if y is None:
            _y = H // 2
        else:
            assert y < H
            _y = y

        # index NCHW activations as [batch, channel, row (y), column (x)]
        return activations[:, channel_index, _y, _x]

    return loss_function


def single_target_objective(
    target: nn.Module, loss_function: SingleTargetLossFunction
) -> LossFunction:
    def inner(targets_to_values: ModuleOutputMapping):
        value = targets_to_values[target]
        return loss_function(value)

    return inner


class SingleTargetObjective(Objective):
    def __init__(
        self,
        model: nn.Module,
        target: nn.Module,
        loss_function: Callable[[torch.Tensor], torch.Tensor],
    ):
        super(SingleTargetObjective, self).__init__(model=model, targets=[target])
        self.loss_function = loss_function

    def loss(self, targets_to_values):
        assert len(self.targets) == 1
        target = self.targets[0]
        target_value = targets_to_values[target]
        loss_value = self.loss_function(target_value)
        self.history.append(loss_value.sum().cpu().detach().numpy().squeeze().item())
        return loss_value


class MultiObjective(Objective):
    def __init__(
        self, objectives: List[Objective], weights: Optional[Iterable[float]] = None
    ):
        model = objectives[0].model
        assert all(o.model == model for o in objectives)
        targets = (target for objective in objectives for target in objective.targets)
        super(MultiObjective, self).__init__(model=model, targets=targets)
        self.objectives = objectives
        self.weights = weights or len(objectives) * [1]

    def loss(self, targets_to_values):
        losses = (
            objective.loss_function(targets_to_values) for objective in self.objectives
        )
        # pair each objective's loss with its weight before summing
        weighted = (loss * weight for loss, weight in zip(losses, self.weights))
        loss_value = sum(weighted)
        self.history.append(loss_value.cpu().detach().numpy().squeeze().item())
        return loss_value

    @property
    def histories(self) -> List[List[float]]:
        return [objective.history for objective in self.objectives]


# class ChannelObjective(SingleTargetObjective):
#     def __init__(self, channel: int, *args, **kwargs):
#         loss_function = lambda activation: activation[:, channel, :, :].mean()
#         super(ChannelObjective, self).__init__(
#             *args, loss_function=loss_function, **kwargs
#         )
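
Taken together, the new pieces compose into a short visualization loop. The following is a hedged sketch, not part of the commit: torchvision's googlenet stands in for the model (the commit's own _models.inception_v1 import is still commented out), and the layer and channel choices are arbitrary.

import torchvision.models as models

from captum.optim import InputOptimization
from captum.optim._core.objectives import channel_activation, n_steps

# Stand-in model; any nn.Module with a reachable intermediate layer works.
model = models.googlenet(pretrained=True).eval()
target = model.inception4a

objective = InputOptimization(
    model=model,
    input_param=None,  # defaults to NaturalImage((224, 224))
    transform=None,    # defaults to RandomAffine(scale=True, translate=True)
    target_modules=[target],
    loss_function=channel_activation(target, channel_index=5),
)

# Adam (lr=0.025) by default; returns the per-step loss history.
history = objective.optimize(stop_criteria=n_steps(512))

# A custom StopCriteria uses the same 4-argument signature that optimize()
# calls with, (step, objective, history, optimizer), and returns True to
# continue iterating.
def until_mean_loss(threshold):
    def continue_while(step, obj, history, optim):
        return len(history) == 0 or history[-1].mean() < threshold

    return continue_while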

captum/optim/optim/output_hook.py renamed to captum/optim/_core/output_hook.py

Lines changed: 10 additions & 8 deletions
@@ -1,10 +1,9 @@
+from typing import Iterable
 from warnings import warn
-from typing import Iterable, Dict, Optional

-import torch
 import torch.nn as nn

-from clarity.pytorch import ModuleOutputMapping
+# from clarity.pytorch import ModuleOutputMapping


 class AbortForwardException(Exception):
@@ -39,7 +38,8 @@ class ModuleReuseException(Exception):

 class ModuleOutputsHook:
     def __init__(self, target_modules: Iterable[nn.Module]):
-        self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
+        # self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
+        self.outputs = dict.fromkeys(target_modules, None)
         self.hooks = [
             module.register_forward_hook(self._forward_hook())
             for module in target_modules
@@ -59,17 +59,19 @@ def forward_hook(module, input, output):
                 self.outputs[module] = output
             else:
                 warn(
-                    f"Hook attached to {module} was called multiple times. As of 2019-11-22 please don't reuse nn.Modules in your models."
+                    f"Hook attached to {module} was called multiple times. "
+                    "As of 2019-11-22 please don't reuse nn.Modules in your models."
                 )
             if self.is_ready:
                 raise AbortForwardException("Forward hook called, all outputs saved.")

         return forward_hook

-    def consume_outputs(self) -> ModuleOutputMapping:
+    def consume_outputs(self):  # -> ModuleOutputMapping:
         if not self.is_ready:
             warn(
-                "Consume captured outputs, but not all requested target outputs have been captured yet!"
+                "Consume captured outputs, but not all requested target outputs "
+                "have been captured yet!"
             )
         outputs = self.outputs
         self._reset_outputs()
@@ -84,5 +86,5 @@ def remove_hooks(self):
             hook.remove()

     def __del__(self):
-        print(f"DEL HOOKS!: {list(self.outputs.keys())}")
+        # print(f"DEL HOOKS!: {list(self.outputs.keys())}")
         self.remove_hooks()
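
For reference, a minimal sketch of the capture-and-abort pattern this file implements, assembled only from behavior visible in the diff (the toy model is illustrative): each target module gets a forward hook that records its output, and once every target has been captured the hook raises AbortForwardException to short-circuit the rest of the forward pass, which the caller suppresses.

from contextlib import suppress

import torch
import torch.nn as nn

from captum.optim._core.output_hook import AbortForwardException, ModuleOutputsHook

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
target = model[0]  # capture the first conv layer's output

hooks = ModuleOutputsHook([target])
with suppress(AbortForwardException):
    # Aborts as soon as all target outputs are saved, so layers after
    # the last target never run.
    model(torch.randn(1, 3, 32, 32))

outputs = hooks.consume_outputs()  # {target: activation tensor}; resets state
print(outputs[target].shape)
hooks.remove_hooks()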

captum/optim/_models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# from .inception_v1 import googlenet
