update learning method and some dpbasicblocks
MekAkUActOR committed Dec 14, 2022
1 parent 89658b9 commit eb63008
Showing 3 changed files with 38 additions and 63 deletions.
83 changes: 25 additions & 58 deletions code/deeppoly.py
@@ -7,17 +7,12 @@
class DeepPoly:
def __init__(self, size, lb, ub):
        self.slb = torch.cat([torch.diag(torch.ones(size)), torch.zeros(size).unsqueeze(1)], dim=1)
-        # print("Initialzing slb,", self.slb.shape)
        self.sub = self.slb
-        # print("Initialzing sub,", self.sub.shape)
        self.lb = lb
-        # print("Initialzing lb,", self.lb.shape, self.lb)
        self.ub = ub
-        # print("Initialzing ub,",self.ub.shape, self.ub)

self.history = []
self.layers = 0
-        self.is_relu = False

def save(self):
# print("save ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
@@ -26,32 +21,20 @@ def save(self):
# print("save lb,", lb.shape)
ub = torch.cat([self.ub, torch.ones(1)])
# print("save ub,", ub.shape)
-        if self.is_relu:
-            # relu layer
-            slb = self.slb
-            sub = self.sub
-        else:
-            # other layers
-            keep_bias = torch.zeros(1, self.slb.shape[1])
-            keep_bias[0, self.slb.shape[1] - 1] = 1
-            # print("save keep_bias,", keep_bias.shape)
-            slb = torch.cat([self.slb, keep_bias], dim=0)
-            sub = torch.cat([self.sub, keep_bias], dim=0)

# print("save slb,", slb.shape)
# print("save sub,", sub.shape)
+        keep_bias = torch.zeros(1, self.slb.shape[1])
+        keep_bias[0, self.slb.shape[1] - 1] = 1
+        slb = torch.cat([self.slb, keep_bias], dim=0)
+        sub = torch.cat([self.sub, keep_bias], dim=0)

# layer num
self.layers += 1
# print("save layers,", self.layers)
# record each layer
-        self.history.append((slb, sub, lb, ub, self.is_relu))
+        self.history.append((slb, sub, lb, ub))
return self

def compute_verify_result(self, true_label):
self.save()
n = self.slb.shape[0] - 1
# print("slb,", n)
unit = torch.diag(torch.ones(n))
weights = torch.cat((-unit[:, :true_label], torch.ones(n, 1), -unit[:, true_label:], torch.zeros(n, 1)), dim=1)
# print("compute_verify_result,", weights.shape)
@@ -73,11 +56,8 @@ def resolve(self, constrains, layer, lower=True):
"""
# distinguish the sign of the coefficients of the constraints
        pos_coeff = F.relu(constrains)
-        # print("resolve pos_coeff,", layer, pos_coeff.shape)
        neg_coeff = -F.relu(-constrains)
-        # print("resolve neg_coeff,", layer, neg_coeff.shape)
        layer_info = self.history[layer]
-        is_relu = layer_info[-1]
if layer == 0:
# layer_info[2],layer_info[3]: concrete lower and upper bound
lb, ub = layer_info[2], layer_info[3]
@@ -86,38 +66,21 @@ def forward(self, x):
lb, ub = layer_info[0], layer_info[1]
if not lower:
lb, ub = ub, lb
# print("resolve lb,", layer, lb.shape)
# print("resolve ub,", layer, ub.shape)
if is_relu:
lb_diag, lb_bias = lb[0], lb[1]
ub_diag, ub_bias = ub[0], ub[1]
# print("resolve relu lb_diag,", layer, lb_diag.shape, lb_diag)
# print("resolve relu ub_diag,", layer, ub_diag.shape, ub_diag)
lb_bias = torch.cat([lb_bias, torch.ones(1)])
ub_bias = torch.cat([ub_bias, torch.ones(1)])
# print("resolve relu lb_bias,", layer, lb_bias.shape, lb_bias)
# print("resolve relu ub_bias,", layer, ub_bias.shape, ub_bias)

m1 = torch.cat([pos_coeff[:, :-1] * lb_diag, torch.matmul(pos_coeff, lb_bias).unsqueeze(1)], dim=1)
m2 = torch.cat([neg_coeff[:, :-1] * ub_diag, torch.matmul(neg_coeff, ub_bias).unsqueeze(1)], dim=1)
# print("resolve relu m1,", layer, m1.shape)
# print("resolve relu m2,", layer, m2.shape)
return m1 + m2
else:
m1 = torch.matmul(pos_coeff, lb)
m2 = torch.matmul(neg_coeff, ub)
# print("resolve linear m1,", layer, m1.shape)
# print("resolve linear m2,", layer, m2.shape)
return m1 + m2
+        m1 = torch.matmul(pos_coeff, lb)
+        m2 = torch.matmul(neg_coeff, ub)
+        return m1 + m2


class DPReLU(nn.Module):
def __init__(self, in_features):
super(DPReLU, self).__init__()
self.in_features = in_features
self.out_features = in_features
        # self.alpha = torch.nn.Parameter(torch.ones(in_features))
-        self.alpha = torch.nn.Parameter(torch.rand(in_features) * 0.7854)
        # self.alpha = torch.nn.Parameter(torch.ones(in_features) * 0.5)
+        # self.alpha = torch.nn.Parameter(torch.rand(in_features) * 0.7854)
+        self.alpha = torch.nn.Parameter(torch.rand(in_features))
        # self.alpha = torch.nn.Parameter(torch.normal(mean=0.5, std=torch.ones(in_features)*0.5))
        # self.alpha.data.clamp_(0, 1)
        self.alpha.requires_grad = True

def forward(self, x):
@@ -144,7 +107,8 @@ def forward(self, x):
low < 0 < up
'''
# print("ALPHA: ", self.alpha)
-        slope_low_3 = torch.tan(self.alpha)
+        # slope_low_3 = torch.tan(self.alpha)
+        slope_low_3 = self.alpha
bias_low_3 = slope_low_3 * low - slope_low_3 * low
slope_up_3 = (F.relu(up) - F.relu(low)) / (up - low)
bias_up_3 = F.relu(up) - slope_up_3 * up
@@ -156,9 +120,9 @@ def forward(self, x):

x.lb = F.relu(low)
x.ub = F.relu(up)
-        x.slb = torch.cat([curr_slb.unsqueeze(0), curr_slb_bias.unsqueeze(0)], dim=0)
-        x.sub = torch.cat([curr_sub.unsqueeze(0), curr_sub_bias.unsqueeze(0)], dim=0)
-        x.is_relu = True
+        x.slb = torch.cat([torch.diag(curr_slb), curr_slb_bias.unsqueeze(1)], dim=1)
+        x.sub = torch.cat([torch.diag(curr_sub), curr_sub_bias.unsqueeze(1)], dim=1)
+        # x.is_relu = True
return x


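The learning-method change is visible in this hunk: the lower-bound slope is now the raw parameter alpha, clamped to [0, 1] in verifier.py, instead of tan(alpha) with alpha in [0, 0.7854], where 0.7854 is approximately pi/4 so tan() also sweeps slopes from 0 to 1; the direct form just avoids the tan() gradient. Note that bias_low_3 above reduces to zero, which is consistent: the lower line y = alpha*x passes through the origin. A one-neuron sketch of the crossing case, with assumed bounds:

import torch
import torch.nn.functional as F

low, up = torch.tensor(-1.0), torch.tensor(2.0)  # crossing ReLU: low < 0 < up
alpha = torch.tensor(0.3)                        # learned lower-bound slope in [0, 1]

slope_low = alpha                                   # lower bound: y >= alpha * x, bias 0
slope_up = (F.relu(up) - F.relu(low)) / (up - low)  # upper bound: chord from (low, 0) to (up, up)
bias_up = F.relu(up) - slope_up * up
print(slope_up.item(), bias_up.item())  # 0.666..., 0.666...: y <= (2/3)*x + 2/3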
@@ -178,7 +142,7 @@ def forward(self, x):
x.kernel_size = self.kernel_size
x.stride = self.stride
x.padding = self.padding
-        x.is_relu = False
+        # x.is_relu = False
return x


@@ -204,7 +168,7 @@ def forward(self, x):
for i in range(x.layers, 0, -1):
x.lb = x.resolve(x.lb, i - 1, lower=True)
x.ub = x.resolve(x.ub, i - 1, lower=False)
-        x.is_relu = False
+        # x.is_relu = False
return x


@@ -237,10 +201,11 @@ def __init__(self, nested: nn.Conv2d, in_feature):
self.out_features = self.out_channels * \
math.floor((self.in_features[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1) * \
math.floor((self.in_features[2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1)
+        self.weightMtrx = self.get_sparse_conv_weights_matrix(self.weight, self.in_features, self.bias, self.stride[0], self.padding[0])

def forward(self, x):
x.save()
-        init_slb = self.get_sparse_conv_weights_matrix(self.weight, self.in_features, self.bias, self.stride[0], self.padding[0])
+        init_slb = self.weightMtrx
# print("Conv init_slb", init_slb.shape, init_slb)
x.lb = init_slb
x.ub = init_slb
@@ -251,7 +216,7 @@ def forward(self, x):
x.ub = x.resolve(x.ub, i - 1, lower=False)
# print("Conv init_slb x.lb", x.lb)
# print("Conv init_slb x.ub", x.ub)
-        x.is_relu = False
+        # x.is_relu = False
return x

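Hoisting get_sparse_conv_weights_matrix into __init__ (as self.weightMtrx, used by the forward above) builds the convolution's equivalent affine matrix once per layer instead of on every forward pass, which matters now that forward re-runs on every optimization epoch. The repo's construction is not shown in this hunk; a generic and deliberately naive way to obtain such a matrix is to push basis vectors through the convolution. A sketch under that assumption, not the repo's implementation:

import torch
import torch.nn.functional as F

def conv_as_matrix(weight, bias, in_shape, stride=1, padding=0):
    # Dense W with conv(x).flatten() == W @ x.flatten() + b_full, built by
    # convolving each standard-basis "image" (slow, but handy for checking).
    c, h, w = in_shape
    eye = torch.eye(c * h * w).reshape(-1, c, h, w)
    cols = F.conv2d(eye, weight, stride=stride, padding=padding)
    W = cols.reshape(c * h * w, -1).T
    b_full = F.conv2d(torch.zeros(1, c, h, w), weight, bias,
                      stride=stride, padding=padding).flatten()
    return W, b_full

weight = torch.randn(2, 1, 3, 3)              # 2 out channels, 1 in channel, 3x3 kernel
W, b = conv_as_matrix(weight, torch.zeros(2), (1, 5, 5), padding=1)
print(W.shape, b.shape)                       # torch.Size([50, 25]) torch.Size([50])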
@staticmethod
@@ -372,6 +337,8 @@ def __init__(self, nested: resnet.BasicBlock, in_feature):
self.bn = nested.bn
self.kernel = nested.kernel
self.expansion = nested.expansion
+        self.path_a = nested.path_a
+        self.path_b = nested.path_b
self.paths = []
for modu in nested.modules():
if type(modu) == nn.Sequential:
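DPBasicBlock now receives the two residual branches directly as nested.path_a and nested.path_b (exposed in resnet.py below) instead of relying only on scanning modules() for nn.Sequential children. For intuition about what the block's transformer must ultimately do, the concrete-bound view of a residual merge a + b is elementwise addition of the branch bounds. An interval-only sketch with made-up numbers, not the symbolic DeepPoly version:

import torch

# Branch bounds for path_a(x) and path_b(x) on two neurons.
lb_a, ub_a = torch.tensor([0.0, -1.0]), torch.tensor([1.0, 2.0])
lb_b, ub_b = torch.tensor([0.5, 0.0]), torch.tensor([1.5, 1.0])

# Sound bounds for path_a(x) + path_b(x): add lower to lower, upper to upper.
lb_sum, ub_sum = lb_a + lb_b, ub_a + ub_b
print(lb_sum, ub_sum)  # tensor([0.5000, -1.0000]) tensor([2.5000, 3.0000])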
2 changes: 2 additions & 0 deletions code/resnet.py
@@ -106,6 +106,8 @@ def __init__(
layers_a.append(nn.BatchNorm2d(self.expansion * planes))
path_a = nn.Sequential(*layers_a)
self.out_dim = in_dim
+        self.path_a = path_a
+        self.path_b = path_b
super(BasicBlock, self).__init__(path_a, path_b)

def _getShapeConv(
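The resnet.py side of the change simply keeps named handles to the two branches that BasicBlock hands to its nn.Sequential parent. A minimal stand-in for the pattern (a hypothetical TwoPathBlock, not the repo's class), registering the handles after the parent constructor so the module assignment is legal:

import torch.nn as nn

class TwoPathBlock(nn.Sequential):
    # Hypothetical stand-in: two parallel paths, exposed as attributes so
    # downstream code (e.g. a DeepPoly transformer) can grab them by name.
    def __init__(self):
        path_a = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.BatchNorm2d(4))
        path_b = nn.Sequential(nn.Identity())
        super().__init__(path_a, path_b)
        self.path_a = path_a
        self.path_b = path_b

block = TwoPathBlock()
print(list(block._modules.keys()))  # ['0', '1', 'path_a', 'path_b'] -- same objects, two names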
16 changes: 11 additions & 5 deletions code/verifier.py
@@ -13,9 +13,10 @@

DEVICE = 'cpu'
DTYPE = torch.float32
-LR = 0.05
-num_iter = 25
+LR = 0.7
+num_iter = 10
+lr_decay = 0.8
+lr_destep = 1

# class DeepPoly:
# def __init__(self, lb, ub, lexpr, uexpr) -> None:
@@ -284,7 +285,9 @@ def analyze(net, inputs, eps, true_label):

verifiable_net = verifiable(net, inputs.reshape(-1))
optimizer = optim.Adam(verifiable_net.parameters(), lr=LR)
-    for i in range(num_iter):
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_destep, gamma=lr_decay)
+    for epoch in range(num_iter):
+        # print(scheduler.get_last_lr())
optimizer.zero_grad()
verifier_output = verifiable_net(DeepPoly(low_bound.shape[0], low_bound, up_bound))
res = verifier_output.compute_verify_result(true_label)
@@ -295,8 +298,11 @@ def analyze(net, inputs, eps, true_label):
optimizer.step()
for p in verifiable_net.parameters():
if p.requires_grad:
-                p.data.clamp_(0, 0.7854)
+                p.data.clamp_(0, 1)
+        if scheduler.get_last_lr()[0] > 0.2:
+            scheduler.step()

optimizer.zero_grad()
verifier_output = verifiable_net(DeepPoly(low_bound.shape[0], low_bound, up_bound))
res = verifier_output.compute_verify_result(true_label)
if (res > 0).all():
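The new learning method trades many small steps (LR 0.05 over 25 iterations) for a few large, decaying ones: Adam starts at 0.7 and StepLR multiplies the rate by 0.8 each epoch until it falls to the 0.2 floor, while alpha is clamped to its new [0, 1] range after every step. A self-contained sketch of the schedule, with a dummy objective standing in for the loss the verifier builds from compute_verify_result:

import torch
import torch.optim as optim

alpha = torch.nn.Parameter(torch.rand(5))
optimizer = optim.Adam([alpha], lr=0.7)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

for epoch in range(10):
    optimizer.zero_grad()
    loss = torch.relu(0.5 - alpha).sum()  # stand-in objective, not the verifier's
    loss.backward()
    optimizer.step()
    alpha.data.clamp_(0, 1)               # keep the learned ReLU slopes in [0, 1]
    if scheduler.get_last_lr()[0] > 0.2:  # decay 0.7 -> 0.56 -> 0.448 -> ... until ~0.2
        scheduler.step()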
