| 1 | +"""captum.optim.objectives.""" |
| 2 | + |
| 3 | +from contextlib import suppress |
| 4 | +from typing import Callable, Iterable, List, Optional |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | +import torch.optim as optim |
| 9 | +from tqdm.auto import tqdm |
| 10 | + |
| 11 | +from captum.optim._core.output_hook import AbortForwardException, ModuleOutputsHook |
| 12 | +from captum.optim._param.image.images import InputParameterization, NaturalImage |
| 13 | +from captum.optim._param.image.transform import RandomAffine |
| 14 | +from captum.optim._utils.typing import ( |
| 15 | + LossFunction, |
| 16 | + ModuleOutputMapping, |
| 17 | + Objective, |
| 18 | + Parameterized, |
| 19 | + SingleTargetLossFunction, |
| 20 | + StopCriteria, |
| 21 | +) |
| 22 | + |
| 23 | + |
| 24 | +class InputOptimization(Objective, Parameterized): |
| 25 | + """ |
| 26 | + Core function that optimizes an input to maximize a target (aka objective). |
| 27 | + This is similar to gradient-based methods for adversarial examples, such |
| 28 | + as FGSM. The code for this was based on the implementation by the authors of Lucid. |
| 29 | + For more details, see the following: |
| 30 | + https://github.com/tensorflow/lucid |
| 31 | + https://distill.pub/2017/feature-visualization/ |
| 32 | + """ |
| 33 | + |
    def __init__(
        self,
        model: nn.Module,
        input_param: Optional[InputParameterization],
        transform: Optional[nn.Module],
        target_modules: Iterable[nn.Module],
        loss_function: LossFunction,
    ):
        r"""
        Args:
            model (nn.Module): A reference to the PyTorch model instance.
            input_param (nn.Module, optional): A module that generates an input,
                consumed by the model.
            transform (nn.Module, optional): A module that transforms or preprocesses
                the input before it is passed to the model.
            target_modules (iterable of nn.Module): A list of target modules whose
                captured outputs are used to compute the loss function.
            loss_function (callable): The loss function to minimize during
                optimization.
        """
        self.model = model
        self.hooks = ModuleOutputsHook(target_modules)
        self.input_param = input_param or NaturalImage((224, 224))
        self.transform = transform or RandomAffine(scale=True, translate=True)
        self.loss_function = loss_function

    def loss(self) -> torch.Tensor:
        r"""Compute the loss value for the current iteration.
        Returns:
            *tensor* representing the **loss**:
            - **loss** (*tensor*):
                The size of the tensor corresponds to the targets passed.
        """
        image = self.input_param()._t[None, ...]

        if self.transform:
            image = self.transform(image)

        with suppress(AbortForwardException):
            _unreachable = self.model(image)  # noqa: F841

        # consume_outputs returns the captured values and resets the hook's state
        module_outputs = self.hooks.consume_outputs()
        loss_value = self.loss_function(module_outputs)
        return loss_value

    def cleanup(self):
        r"""Garbage collection, mainly removing hooks."""
        self.hooks.remove_hooks()

    # Targets are managed by ModuleOutputsHook; we mainly just want a convenient setter
    @property
    def targets(self):
        return self.hooks.targets

    @targets.setter
    def targets(self, value):
        self.hooks.remove_hooks()
        self.hooks = ModuleOutputsHook(value)

    def parameters(self):
        return self.input_param.parameters()

    def optimize(
        self,
        stop_criteria: Optional[StopCriteria] = None,
        optimizer: Optional[optim.Optimizer] = None,
    ):
        r"""Optimize the input based on the loss function and objectives.
        Args:
            stop_criteria (StopCriteria, optional): A function that is called
                every iteration and returns a bool that determines whether
                to stop the optimization.
                See captum.optim.typing.StopCriteria for details.
            optimizer (Optimizer, optional): A torch.optim.Optimizer used to
                optimize the input based on the loss function.
        Returns:
            *list* of *np.arrays* representing the **history**:
            - **history** (*list*):
                A list of loss values per iteration.
                The length of the list corresponds to the number of iterations.
        """
        stop_criteria = stop_criteria or n_steps(1024)
        optimizer = optimizer or optim.Adam(self.parameters(), lr=0.025)
        assert isinstance(optimizer, optim.Optimizer)

        history = []
        step = 0
        while stop_criteria(step, self, history, optimizer):
            optimizer.zero_grad()
            loss_value = self.loss()
            history.append(loss_value.cpu().detach().numpy())
            # maximize the objective by minimizing its negation
            (-1 * loss_value.mean()).backward()
            optimizer.step()
            step += 1

        self.cleanup()
        return history


def n_steps(n: int) -> StopCriteria:
    """StopCriteria generator that uses the number of steps as the stop criterion.
    Args:
        n (int): Number of steps to run the optimization.
    Returns:
        *StopCriteria* callable
    """
    pbar = tqdm(total=n, unit="step")

    def continue_while(step, obj, history, optim):
        if len(history) > 0:
            pbar.set_postfix({"Objective": f"{history[-1].mean():.1f}"}, refresh=False)
        if step < n:
            pbar.update()
            return True
        else:
            pbar.close()
            return False

    return continue_while

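# A hypothetical sketch (not part of the module) of a custom StopCriteria: any
# callable taking (step, objective, history, optimizer) and returning a bool can
# be passed to InputOptimization.optimize as stop_criteria, mirroring
# continue_while above. This one stops once the objective stops improving.
#
# def until_loss_plateaus(patience: int = 32, tol: float = 1e-3) -> StopCriteria:
#     def continue_while(step, obj, history, optim):
#         if len(history) <= patience:
#             return True
#         return abs(history[-1].mean() - history[-1 - patience].mean()) > tol
#
#     return continue_while
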

def channel_activation(target: nn.Module, channel_index: int) -> LossFunction:
    def loss_function(targets_to_values: ModuleOutputMapping):
        activations = targets_to_values[target]
        assert activations is not None
        # ensure channel_index is valid
        assert channel_index < activations.shape[1]
        # assume NCHW
        # NOTE: not necessarily true e.g. for Linear layers
        # assert len(activations.shape) == 4
        return activations[:, channel_index, ...]

    return loss_function

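# A minimal usage sketch (`model` and `layer` are placeholder names, for
# illustration only): channel_activation can serve directly as the
# loss_function of InputOptimization.
#
# loss_fn = channel_activation(layer, channel_index=13)
# obj = InputOptimization(model, None, None, [layer], loss_fn)
# history = obj.optimize(n_steps(256))
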

def neuron_activation(
    target: nn.Module,
    channel_index: int,
    x: Optional[int] = None,
    y: Optional[int] = None,
) -> LossFunction:
    # ensure channel_index will be valid
    assert channel_index < target.out_channels

    def loss_function(targets_to_values: ModuleOutputMapping):
        activations = targets_to_values[target]
        assert activations is not None
        assert len(activations.shape) == 4  # assume NCHW
        _, _, H, W = activations.shape

        if x is None:
            _x = W // 2
        else:
            assert x < W
            _x = x

        if y is None:
            _y = H // 2
        else:
            assert y < H
            _y = y

        # NCHW: index the height dimension with _y and the width dimension with _x
        return activations[:, channel_index, _y, _x]

    return loss_function

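# The same pattern for a single neuron; x and y default to the spatial center of
# the activation map when omitted. `conv_layer` is a stand-in for any nn.Conv2d
# inside the model being visualized.
#
# loss_fn = neuron_activation(conv_layer, channel_index=3, x=2, y=5)
# obj = InputOptimization(model, None, None, [conv_layer], loss_fn)
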

def single_target_objective(
    target: nn.Module, loss_function: SingleTargetLossFunction
) -> LossFunction:
    def inner(targets_to_values: ModuleOutputMapping):
        value = targets_to_values[target]
        return loss_function(value)

    return inner

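# A sketch of wrapping a plain tensor-to-tensor loss: single_target_objective
# adapts any SingleTargetLossFunction (here, an assumed mean over a layer's
# activations) into the ModuleOutputMapping-based LossFunction interface.
#
# loss_fn = single_target_objective(layer, lambda activations: activations.mean())
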

class SingleTargetObjective(Objective):
    def __init__(
        self,
        model: nn.Module,
        target: nn.Module,
        loss_function: Callable[[torch.Tensor], torch.Tensor],
    ):
        super(SingleTargetObjective, self).__init__(model=model, targets=[target])
        self.loss_function = loss_function

    def loss(self, targets_to_values):
        assert len(self.targets) == 1
        target = self.targets[0]
        target_value = targets_to_values[target]
        loss_value = self.loss_function(target_value)
        self.history.append(loss_value.sum().cpu().detach().numpy().squeeze().item())
        return loss_value


class MultiObjective(Objective):
    def __init__(
        self, objectives: List[Objective], weights: Optional[Iterable[float]] = None
    ):
        model = objectives[0].model
        assert all(o.model == model for o in objectives)
        targets = (target for objective in objectives for target in objective.targets)
        super(MultiObjective, self).__init__(model=model, targets=targets)
        self.objectives = objectives
        self.weights = weights or len(objectives) * [1]

    def loss(self, targets_to_values):
        losses = (
            objective.loss(targets_to_values) for objective in self.objectives
        )
        # pair each sub-objective's loss with its corresponding weight
        weighted = (loss * weight for loss, weight in zip(losses, self.weights))
        loss_value = sum(weighted)
        self.history.append(loss_value.cpu().detach().numpy().squeeze().item())
        return loss_value

    @property
    def histories(self) -> List[List[float]]:
        return [objective.history for objective in self.objectives]

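# A hedged sketch of combining objectives: MultiObjective sums the individual
# losses, optionally weighted. The objectives below are placeholders; any
# Objective instances sharing the same model could be combined this way.
#
# combined = MultiObjective([objective_a, objective_b], weights=[1.0, 0.5])
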

# class ChannelObjective(SingleTargetObjective):
#     def __init__(self, channel: int, *args, **kwargs):
#         loss_function = lambda activation: activation[:, channel, :, :].mean()
#         super(ChannelObjective, self).__init__(
#             *args, loss_function=loss_function, **kwargs
#         )
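
# An end-to-end usage sketch, commented out like the example class above. The
# torchvision model and layer name are assumptions for illustration; any
# nn.Module with an accessible target submodule would work the same way.
#
# from torchvision import models
#
# model = models.googlenet(pretrained=True).eval()
# target = model.inception4a
# objective = InputOptimization(
#     model=model,
#     input_param=None,  # defaults to NaturalImage((224, 224))
#     transform=None,  # defaults to RandomAffine(scale=True, translate=True)
#     target_modules=[target],
#     loss_function=channel_activation(target, channel_index=0),
# )
# history = objective.optimize(stop_criteria=n_steps(512))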