"""
Copyright to SAR Authors, ICLR 2023 Oral (notable-top-5%)
built upon on Tent code.
"""
from copy import deepcopy

import torch
import torch.nn as nn
import torch.jit
import math
import numpy as np


def update_ema(ema, new_data):
    if ema is None:
        return new_data
    else:
        with torch.no_grad():
            return 0.9 * ema + (1 - 0.9) * new_data


class SAR(nn.Module):
    """SAR online adapts a model by Sharpness-Aware and Reliable entropy minimization during testing.

    Once SARed, a model adapts itself by updating on every forward.
    """
    def __init__(self, model, optimizer, steps=1, episodic=False,
                 margin_e0=0.4 * math.log(1000), reset_constant_em=0.2):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "SAR requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        self.margin_e0 = margin_e0  # margin E_0 for reliable entropy minimization, Eqn. (2)
        self.reset_constant_em = reset_constant_em  # threshold e_m for the model recovery scheme
        self.ema = None  # moving average of the model's output entropy, used as the recovery criterion

        # note: if the model is never reset, as for continual adaptation,
        # then skipping the state copy would save memory
        self.model_state, self.optimizer_state = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def forward(self, x):
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs, ema, reset_flag = forward_and_adapt_sar(
                x, self.model, self.optimizer,
                self.margin_e0, self.reset_constant_em, self.ema)
            if reset_flag:
                self.reset()
            self.ema = ema  # update the moving average value of the loss

        return outputs

    def reset(self):
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        self.ema = None
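

# Usage pattern (illustrative sketch, not part of the original file): wrap a
# configured model and a SAM-style optimizer, then run inference as usual; every
# forward call also adapts the model. A full runnable wiring sketch is given at
# the bottom of this file.
#   sar_model = SAR(model, optimizer)   # optimizer must expose first_step/second_step
#   outputs = sar_model(x)              # adapts on this batch, then returns its predictions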


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of softmax distribution from logits."""
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)
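

# Illustrative note (added sketch, not in the original file): softmax_entropy
# returns one entropy value per sample, so logits of shape (B, num_classes) give
# a result of shape (B,). SAR compares these values against the reliability
# margin E_0 = 0.4 * ln(num_classes), e.g. with hypothetical logits:
#   logits = torch.randn(8, 1000)              # batch of 8 samples, 1000 classes
#   ent = softmax_entropy(logits)              # shape (8,)
#   reliable = ent < 0.4 * math.log(1000)      # boolean mask behind the filter in Eqn. (2)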


@torch.enable_grad()  # ensure grads in a possible no-grad context during testing
def forward_and_adapt_sar(x, model, optimizer, margin, reset_constant, ema):
    """Forward and adapt the model on a batch of input data.

    Measure the entropy of the model predictions, take gradients, and update the parameters.
    """
    optimizer.zero_grad()
    # forward
    outputs = model(x)
    # adapt
    # filter reliable samples/gradients for further adaptation (first forward pass)
    entropys = softmax_entropy(outputs)
    filter_ids_1 = torch.where(entropys < margin)
    entropys = entropys[filter_ids_1]
    loss = entropys.mean(0)
    loss.backward()
    optimizer.first_step(zero_grad=True)  # compute \hat{\epsilon}(\Theta) for the first-order approximation, Eqn. (4)

    # second forward pass, at the perturbed weights \Theta + \hat{\epsilon}(\Theta)
    entropys2 = softmax_entropy(model(x))
    entropys2 = entropys2[filter_ids_1]
    loss_second_value = entropys2.clone().detach().mean(0)  # kept for logging/debugging; not used below
    # filter reliable samples again, since the weights have changed to \Theta + \hat{\epsilon}(\Theta)
    filter_ids_2 = torch.where(entropys2 < margin)
    loss_second = entropys2[filter_ids_2].mean(0)
    if not np.isnan(loss_second.item()):
        ema = update_ema(ema, loss_second.item())  # record moving-average loss values for model recovery

    # second backward pass: update the model weights using gradients taken at \Theta + \hat{\epsilon}(\Theta)
    loss_second.backward()
    optimizer.second_step(zero_grad=True)

    # model recovery: reset when the moving-average entropy collapses below the threshold e_m
    reset_flag = False
    if ema is not None:
        if ema < reset_constant:
            print(f"ema < {reset_constant}, now reset the model")
            reset_flag = True

    return outputs, ema, reset_flag
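

# The adaptation step above assumes an optimizer that exposes first_step() and
# second_step(), i.e. a SAM-style (sharpness-aware minimization) optimizer such
# as the one shipped with the SAR repository. The class below is only a minimal
# illustrative sketch of that interface (its name, rho, and the SGD base are
# assumptions, not the repository's sam.py): first_step climbs to the approximate
# worst-case weights within an L2 ball of radius rho, and second_step returns to
# the original weights and applies the base update there.
class _MinimalSAMSketch(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer_cls=torch.optim.SGD, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # the base optimizer performs the actual descent step after second_step
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # perturb the weights toward higher loss: \Theta -> \Theta + \hat{\epsilon}(\Theta)
        grad_norm = torch.norm(torch.stack([
            p.grad.norm(p=2)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]), p=2)
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                      # climb to the perturbed point
                self.state[p]["e_w"] = e_w       # remember the perturbation
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # undo the perturbation, then descend using gradients taken at the perturbed point
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or "e_w" not in self.state[p]:
                    continue
                p.sub_(self.state[p]["e_w"])     # back to \Theta
        self.base_optimizer.step()               # base update with sharpness-aware gradients
        if zero_grad:
            self.zero_grad()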


def collect_params(model):
    """Collect the affine scale + shift parameters from norm layers.

    Walk the model's modules and collect all normalization parameters.
    Return the parameters and their names.

    Note: other choices of parameterization are possible!
    """
    params = []
    names = []
    for nm, m in model.named_modules():
        # skip top layers for adaptation: layer4 for ResNets and blocks 9-11 for ViT-Base
        if 'layer4' in nm:
            continue
        if 'blocks.9' in nm:
            continue
        if 'blocks.10' in nm:
            continue
        if 'blocks.11' in nm:
            continue
        if 'norm.' in nm:
            continue
        if nm in ['norm']:
            continue

        if isinstance(m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
            for np, p in m.named_parameters():
                if np in ['weight', 'bias']:  # weight is scale, bias is shift
                    params.append(p)
                    names.append(f"{nm}.{np}")

    return params, names


def copy_model_and_optimizer(model, optimizer):
    """Copy the model and optimizer states for resetting after adaptation."""
    model_state = deepcopy(model.state_dict())
    optimizer_state = deepcopy(optimizer.state_dict())
    return model_state, optimizer_state


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore the model and optimizer states from copies."""
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)


def configure_model(model):
    """Configure model for use with SAR."""
    # train mode, because SAR optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what SAR updates
    model.requires_grad_(False)
    # configure norm layers for SAR updates: enable grad + force batch statistics (BN models only)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
        # LayerNorm and GroupNorm for ResNet-GN and ViT-LN models
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            m.requires_grad_(True)
    return model


def check_model(model):
    """Check model for compatibility with SAR."""
    is_training = model.training
    assert is_training, "SAR needs train mode: call model.train()"
    param_grads = [p.requires_grad for p in model.parameters()]
    has_any_params = any(param_grads)
    has_all_params = all(param_grads)
    assert has_any_params, "SAR needs params to update: " \
                           "check which require grad"
    assert not has_all_params, "SAR should not update all params: " \
                               "check which require grad"
    has_norm = any(isinstance(m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)) for m in model.modules())
    assert has_norm, "SAR needs normalization layer parameters for its optimization"
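

# End-to-end wiring sketch (illustrative, not part of the original file): the toy
# CNN, the hyperparameters, and the _MinimalSAMSketch optimizer defined above are
# assumptions; the SAR repository pairs these helpers with its own models and its
# sam.py optimizer. The steps below only demonstrate the intended call order:
# configure_model -> collect_params -> SAM-style optimizer -> SAR wrapper -> forward.
if __name__ == "__main__":
    torch.manual_seed(0)

    # a tiny BatchNorm CNN standing in for e.g. a ResNet; with 10 classes every
    # sample stays below the default entropy margin, so the demo exercises both SAM steps
    net = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )

    net = configure_model(net)            # train mode, grads only on norm layers
    params, names = collect_params(net)   # affine scale/shift of the norm layers
    check_model(net)                       # sanity checks from above

    optimizer = _MinimalSAMSketch(params, base_optimizer_cls=torch.optim.SGD,
                                  rho=0.05, lr=0.001, momentum=0.9)
    adapt_model = SAR(net, optimizer)      # margin/reset defaults as in the paper

    x = torch.randn(16, 3, 32, 32)         # a stand-in test batch
    outputs = adapt_model(x)               # adapts the norm parameters, then predicts
    print(outputs.shape)                   # torch.Size([16, 10])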