# https://github.com/DequanWang/tent/blob/master/tent.py
from copy import deepcopy

import torch
import torch.nn as nn
import torch.jit


class Tent(nn.Module):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    """
    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.model_state, self.optimizer_state = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x):
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
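
# Illustration of the episodic flag: with episodic=True, reset() restores the
# saved model/optimizer state before every batch, so each batch is adapted from
# the original source model; with episodic=False, updates accumulate across
# batches, i.e. the continual adaptation mentioned in __init__ above.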


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from logits."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)
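
# Worked example (illustrative, not from the original file): for a uniform
# two-class logit row [0., 0.], softmax gives [0.5, 0.5] and the entropy is
# -2 * 0.5 * log(0.5) ≈ 0.693 nats, while a confident row such as [10., -10.]
# has entropy near 0; minimizing the mean of this quantity therefore pushes
# the model toward confident predictions on the test batch.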


@torch.enable_grad()  # ensure grads in a possibly no-grad context during testing
def forward_and_adapt(x, model, optimizer):
    """Forward and adapt model on batch of data.

    Measure entropy of the model prediction, take gradients, and update params.
    """
    # forward
    outputs = model(x)
    # adapt
    loss = softmax_entropy(outputs).mean(0)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return outputs
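
# Note: although backward() runs over the whole graph, only the parameters that
# were passed to the optimizer (and have requires_grad=True) receive updates;
# with configure_model() below, those are just the normalization layers' affine
# parameters.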


def collect_params(model):
    """Collect the affine scale + shift parameters from batch norms.

    Walk the model's modules and collect all batch normalization parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            for np, p in m.named_parameters():
                if np in ['weight', 'bias']:  # weight is scale, bias is shift
                    params.append(p)
                    names.append(f"{nm}.{np}")
    return params, names
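
# Example (illustrative): for a typical ResNet, collect_params returns only the
# gamma/beta tensors of its norm layers, with names like "layer1.0.bn1.weight";
# these, rather than model.parameters(), are what should be handed to the
# optimizer that Tent wraps.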


def copy_model_and_optimizer(model, optimizer):
    """Copy the model and optimizer states for resetting after adaptation."""
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    return model_state, optimizer_state


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore the model and optimizer states from copies."""
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)


def configure_model(model):
    """Configure model for use with tent."""
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what tent updates
    model.requires_grad_(False)
    # configure norm for tent updates: enable grad + force batch statistics
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
            m.requires_grad_(True)
    return model
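
# Note: setting track_running_stats=False and clearing running_mean/running_var
# makes BatchNorm2d normalize every incoming batch with its own statistics, in
# train and eval mode alike, which is what lets the normalization track the
# shifted test distribution.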


def check_model(model):
    """Check model for compatibility with tent."""
    is_training = model.training
    assert is_training, "tent needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "tent needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "tent should not update all params: " \
                               "check which require grad"
    has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()])
    assert has_bn, "tent needs normalization for its optimization"
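

# A minimal, self-contained usage sketch (not part of the original file): the
# toy model, Adam optimizer, and hyperparameters below are illustrative
# assumptions; the wiring (configure_model -> check_model -> collect_params ->
# optimizer -> Tent) follows the functions defined above.
if __name__ == "__main__":
    toy_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    toy_model = configure_model(toy_model)   # train mode, grads only on norms
    check_model(toy_model)                   # sanity-check the configuration
    params, param_names = collect_params(toy_model)
    print(param_names)                       # e.g. ['1.weight', '1.bias'] here
    optimizer = torch.optim.Adam(params, lr=1e-3)  # illustrative choice
    tented = Tent(toy_model, optimizer, steps=1, episodic=False)
    batch = torch.randn(4, 3, 32, 32)        # stand-in for a test batch
    logits = tented(batch)                   # adapts on the batch, returns logits
    print(logits.shape)                      # torch.Size([4, 10])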