
Commit e72ec5d

Ludwig Schubert authored and NarineK committed
DO NOT MERGE
WIP optim module
1 parent 91dd756 commit e72ec5d

23 files changed: +2021 −0 lines changed

captum/optim/README.md

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Captum "optim" module

This is a WIP PR to integrate existing feature visualization code from the authors of `tensorflow/lucid` into captum.
It is also an opportunity to review which parts of such interpretability tools still feel rough to implement in a system like PyTorch, and to make suggestions to the core PyTorch team for how to improve these aspects.

## Roadmap

* unify API with the Captum API: a single callable class per "technique"? (check details before implementing; see the sketch after this list)
* consider whether we need an abstraction around "an optimization process" (in terms of stopping criteria, reporting losses, etc.), or whether there are sufficiently strong conventions in PyTorch land for such tasks
* integrate Eli's FFT param changes (mostly for simplification)
* make a table of PyTorch interpretability tools for the readme?
* do we need image viewing helpers and io helpers, or should we throw those out?
* can we integrate paper references more closely with the code?
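
For reference, a minimal sketch of the existing Captum convention the first roadmap item points at: one class per technique, invoked through a single `attribute` call. The toy model and inputs here are illustrative, not from this PR.

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Captum's current convention: one class per technique, one entry point.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
technique = IntegratedGradients(model)

inputs = torch.randn(1, 4, requires_grad=True)
attributions = technique.attribute(inputs, target=0)
print(attributions.shape)  # torch.Size([1, 4])
```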

captum/optim/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
from typing import Dict, Optional, Union, Callable, Iterable
from typing_extensions import Protocol

import torch
import torch.nn as nn

# Note: torch.Tensor (the type), not torch.tensor (the factory function).
ParametersForOptimizers = Iterable[Union[torch.Tensor, Dict[str, torch.Tensor]]]


class HasLoss(Protocol):
    def loss(self) -> torch.Tensor:
        ...


class Parameterized(Protocol):
    parameters: ParametersForOptimizers


class Objective(Parameterized, HasLoss):
    def cleanup(self):
        pass


ModuleOutputMapping = Dict[nn.Module, Optional[torch.Tensor]]

StopCriteria = Callable[[int, Objective, torch.optim.Optimizer], bool]
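
For context, a minimal sketch (not part of this commit) of how these Protocols are meant to compose: a driver loop that optimizes an Objective until a StopCriteria fires. The `n_steps` helper and `optimize` function are illustrative assumptions, not APIs from this PR.

import torch

# Uses Objective and StopCriteria as defined above.


def n_steps(n: int) -> StopCriteria:
    # Illustrative StopCriteria: stop after a fixed number of steps.
    def stop(step: int, objective: Objective, optimizer: torch.optim.Optimizer) -> bool:
        return step >= n
    return stop


def optimize(objective: Objective, stop: StopCriteria = n_steps(512), lr: float = 0.025):
    # Hypothetical driver loop showing how the pieces fit together.
    optimizer = torch.optim.Adam(objective.parameters, lr=lr)
    step = 0
    while not stop(step, objective, optimizer):
        optimizer.zero_grad()
        loss = objective.loss()  # HasLoss
        loss.backward()
        optimizer.step()
        step += 1
    objective.cleanup()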

captum/optim/_scrap_and_testing.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import requests
from PIL import Image
from IPython.display import display

from clarity.pytorch.inception_v1 import googlenet
from lucid.misc.io import show, load, save
from lucid.modelzoo.other_models import InceptionV1

# get a test image
img_url = (
    "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png"
)
img_tf = load(img_url)
img_pt = torch.as_tensor(img_tf.transpose(2, 0, 1))[None, ...]
img_pil = Image.open(requests.get(img_url, stream=True).raw)

# instantiate ported model
net = googlenet(pretrained=True)

# get predictions
out = net(img_pt)
logits = out.detach().numpy()[0]
top_k = np.argsort(-logits)[:5]

# load labels
labels = load(InceptionV1.labels_path, split=True)

# show predictions (note: these are raw logits, not softmax probabilities)
for i, k in enumerate(top_k):
    prediction = logits[k]
    label = labels[k]
    print(f"{i}: {label} ({prediction*100:.2f}%)")

# transforms


# def build_grid(source_size, target_size):
#     k = float(target_size) / float(source_size)
#     direct = (
#         torch.linspace(0, k, target_size)
#         .unsqueeze(0)
#         .repeat(target_size, 1)
#         .unsqueeze(-1)
#     )
#     full = torch.cat([direct, direct.transpose(1, 0)], dim=2).unsqueeze(0)
#     return full.cuda()


# def random_crop_grid(x, grid):
#     d = x.size(2) - grid.size(1)
#     grid = grid.repeat(x.size(0), 1, 1, 1).cuda()
#     # Add random shifts by x
#     grid[:, :, :, 0] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze(
#         -1
#     ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
#     # Add random shifts by y
#     grid[:, :, :, 1] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze(
#         -1
#     ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)
#     return grid


# # We want to crop a 80x80 image randomly for our batch
# # Building central crop of 80 pixel size
# grid_source = build_grid(224, 80)
# # Make random shift for each batch
# grid_shifted = random_crop_grid(batch, grid_source)
# # Sample using grid sample
# sampled_batch = F.grid_sample(batch, grid_shifted)


from clarity.pytorch.transform import RandomSpatialJitter, RandomUpsample

# crop = torchvision.transforms.RandomCrop(
#     224, padding=34, pad_if_needed=True, padding_mode="reflect"
# )
jitter = RandomSpatialJitter(16)
ups = RandomUpsample()
for i in range(10):
    cropped = ups(img_pt)
    show(cropped.numpy()[0].transpose(1, 2, 0))
    # display(cropped)


# result = param().cpu().detach().numpy()[0].transpose(1, 2, 0)
# loss_curve = objective.history

# 2019-11-21 notes from PyTorch team
# Set up model
# net = googlenet(pretrained=True)
# parameterization = Image()  # TODO: make size adjustable, currently hardcoded
# input_image = parameterization()

# writer = SummaryWriter()
# writer.add_graph(net, (input_image,))
# writer.close()

# Specify target module / "objective"
# target_module = net.mixed3b._pool_reduce[1]
# target_channel = 54
# hook = OutputHook(target_module)  # TODO: investigate detach on rerun
# parameterization = Image()  # TODO: make size adjustable, currently hardcoded
# optimizer = optim.Adam(parameterization.parameters, lr=0.025)

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# net = net.to(device)
# parameterization = parameterization.to(device)
# for i in range(1000):
#     optimizer.zero_grad()

#     # forward pass through entire net
#     input_image = parameterization()
#     with suppress(AbortForwardException):
#         _ = net(input_image.to(device))

#     # activations were stored during forward pass
#     assert hook.saved_output is not None
#     loss = -hook.saved_output[:, target_channel, :, :].sum()  # channel 13

#     loss.backward()
#     optimizer.step()

#     if i % 100 == 0:
#         print("Loss: ", -loss.cpu().detach().numpy())
#         url = show(
#             parameterization.raw_image.cpu()
#             .detach()
#             .numpy()[0]
#             .transpose(1, 2, 0)
#         )

# traced_net = torch.jit.trace(net, example_inputs=(input_image,))
# print(traced_net.graph)
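
The notes above rely on an OutputHook / AbortForwardException pair that is not defined anywhere in this commit. A minimal sketch of how such a hook could be built from standard PyTorch forward hooks; the names and behavior here are assumptions, not this PR's API.

import torch.nn as nn


class AbortForwardException(Exception):
    """Raised to cut the forward pass short once the target activation is saved."""
    pass


class OutputHook:
    """Illustrative forward hook that stores one module's output on each forward pass."""

    def __init__(self, module: nn.Module, abort: bool = True):
        self.saved_output = None
        self.abort = abort
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        self.saved_output = output
        if self.abort:
            # Caller wraps the forward pass in contextlib.suppress(AbortForwardException),
            # as in the commented loop above, to skip the layers past the target.
            raise AbortForwardException()

    def remove(self):
        self.handle.remove()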

captum/optim/io/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .io import show

captum/optim/io/fixtures.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
import torch

# TODO: use imageio to redo load and avoid TF dependency
from lucid.misc.io import load

DOG_CAT_URL = (
    "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png"
)


def image(url: str = DOG_CAT_URL):
    img_np = load(url)
    return torch.as_tensor(img_np.transpose(2, 0, 1))
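
A sketch of the lucid-free loader the TODO points at, using PIL and requests (both already used in _scrap_and_testing.py) rather than imageio. It assumes an RGB image and mimics lucid load()'s float32 HWC output in [0, 1]; `image_via_pil` is a hypothetical name, not part of this commit.

import numpy as np
import requests
import torch
from PIL import Image


def image_via_pil(url: str = DOG_CAT_URL) -> torch.Tensor:
    # Hypothetical replacement for lucid's load().
    pil_img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    img_np = np.asarray(pil_img, dtype=np.float32) / 255.0  # HWC, [0, 1]
    return torch.as_tensor(img_np.transpose(2, 0, 1))  # CHW, like image() above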

captum/optim/io/formatters.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
from io import BytesIO

import torch
from torchvision import transforms

from IPython import display, get_ipython


def tensor_jpeg(tensor: torch.Tensor):
    if tensor.dim() == 3:
        pil_image = transforms.ToPILImage()(tensor.cpu().detach()).convert("RGB")
        buffer = BytesIO()
        pil_image.save(buffer, format="jpeg")
        data = buffer.getvalue()
        return data
    else:
        # Returning None lets IPython fall back to the default text repr;
        # returning the tensor itself would not be valid image/jpeg data.
        return None


def register_formatters():
    jpeg_formatter = get_ipython().display_formatter.formatters["image/jpeg"]
    jpeg_formatter.for_type(torch.Tensor, tensor_jpeg)
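
Illustrative usage in a notebook (not part of this commit): once the formatter is registered, a bare CHW float tensor at the end of a cell renders as a JPEG.

import torch
from captum.optim.io.formatters import register_formatters

register_formatters()  # requires a running IPython/Jupyter session
t = torch.rand(3, 64, 64)  # CHW float tensor in [0, 1], as ToPILImage expects
t  # as the last expression of a notebook cell, this now displays as a JPEG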

captum/optim/io/io.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# TODO: redo show using display or register handler for jupyter display directly
# maybe we could even have subtypes of tensors that are "ImageTensors" or "ActivationTensors" etc
from lucid.misc.io import show as lucid_show


def show(thing):
    if len(thing.shape) == 3:
        numpy_thing = thing.cpu().detach().numpy().transpose(1, 2, 0)
    elif len(thing.shape) == 4:
        numpy_thing = thing.cpu().detach().numpy()[0].transpose(1, 2, 0)
    else:
        # guard against an unbound numpy_thing for other ranks
        raise ValueError(f"show() expects a CHW or NCHW tensor, got shape {thing.shape}")
    lucid_show(numpy_thing)

captum/optim/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .inception_v1 import googlenet

captum/optim/models/conv2d.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _calc_same_pad(i, k, s, d):
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


class Conv2dSame(nn.Conv2d):
    """Tensorflow-like 'SAME' convolution wrapper for 2D convolutions."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation,
            groups, bias)

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0])
        pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1])
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
#     padding = kwargs.pop('padding', '')
#     kwargs.setdefault('bias', False)
#     if isinstance(padding, str):
#         # for any string padding, the padding will be calculated for you, one of three ways
#         padding = padding.lower()
#         if padding == 'same':
#             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
#             if _is_static_pad(kernel_size, **kwargs):
#                 # static case, no extra overhead
#                 padding = _get_padding(kernel_size, **kwargs)
#                 return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
#             else:
#                 # dynamic padding
#                 return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
#         elif padding == 'valid':
#             # 'VALID' padding, same as padding=0
#             return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
#         else:
#             # Default to PyTorch style 'same'-ish symmetric padding
#             padding = _get_padding(kernel_size, **kwargs)
#             return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
#     else:
#         # padding was specified as a number or pair
#         return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)


# class MixedConv2d(nn.Module):
#     """ Mixed Grouped Convolution
#     Based on MDConv and GroupedConv in MixNet impl:
#     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
#     """

#     def __init__(self, in_channels, out_channels, kernel_size=3,
#                  stride=1, padding='', dilated=False, depthwise=False, **kwargs):
#         super(MixedConv2d, self).__init__()

#         kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
#         num_groups = len(kernel_size)
#         in_splits = _split_channels(in_channels, num_groups)
#         out_splits = _split_channels(out_channels, num_groups)
#         for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
#             d = 1
#             # FIXME make compat with non-square kernel/dilations/strides
#             if stride == 1 and dilated:
#                 d, k = (k - 1) // 2, 3
#             conv_groups = out_ch if depthwise else 1
#             # use add_module to keep key space clean
#             self.add_module(
#                 str(idx),
#                 conv2d_pad(
#                     in_ch, out_ch, k, stride=stride,
#                     padding=padding, dilation=d, groups=conv_groups, **kwargs)
#             )
#         self.splits = in_splits

#     def forward(self, x):
#         x_split = torch.split(x, self.splits, 1)
#         x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
#         x = torch.cat(x_out, 1)
#         return x


# # helper method
# def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
#     assert 'groups' not in kwargs  # only use 'depthwise' bool arg
#     if isinstance(kernel_size, list):
#         # We're going to use only lists for defining the MixedConv2d kernel groups;
#         # ints, tuples, and other iterables will continue to pass to normal conv and specify h, w.
#         return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
#     else:
#         depthwise = kwargs.pop('depthwise', False)
#         groups = out_chs if depthwise else 1
#         return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
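
A quick illustrative check (not part of this commit) that Conv2dSame reproduces TF's SAME behavior, where the output spatial size is ceil(input / stride):

import torch

conv = Conv2dSame(in_channels=3, out_channels=8, kernel_size=5, stride=2)
x = torch.randn(1, 3, 224, 224)
print(conv(x).shape)  # torch.Size([1, 8, 112, 112]); 112 == ceil(224 / 2)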
