DeepPoly transformer for fc, still has bugs
Haowe committed Nov 21, 2022
1 parent 6e7ba51 commit 3297c50
Showing 3 changed files with 214 additions and 2 deletions.
150 changes: 150 additions & 0 deletions code/deeppoly.py
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepPoly:
    def __init__(self, size, low_bound, up_bound):
        self.lowbound = low_bound
        self.upbound = up_bound
        # initial symbolic bounds: identity weights plus a zero constant column,
        # i.e. each input neuron is bounded exactly by itself
        self.sym_lowbound = torch.cat([torch.diag(torch.ones(size)), torch.zeros(size).unsqueeze(1)], dim=1)
        self.sym_upbound = self.sym_lowbound
        self.history = []
        self.layers = 0
        self.is_relu = False

    def save(self):
        """ Save constraints for the backsubstitution """
        low_bound = torch.cat([self.lowbound, torch.ones(1)])
        up_bound = torch.cat([self.upbound, torch.ones(1)])
        if self.is_relu:
            # ReLU layers already store their bounds as (slope, bias) row pairs
            sym_lowbound = self.sym_lowbound
            sym_upbound = self.sym_upbound
        else:
            # other layers: append a row that passes the constant term through
            keep_bias = torch.zeros(1, self.sym_lowbound.shape[1])
            keep_bias[0, -1] = 1
            sym_lowbound = torch.cat([self.sym_lowbound, keep_bias], dim=0)
            sym_upbound = torch.cat([self.sym_upbound, keep_bias], dim=0)
        # layer num
        self.layers += 1
        # record each layer
        self.history.append((sym_lowbound, sym_upbound, low_bound, up_bound, self.is_relu))
        return self
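
    # For reference, the symbolic bounds follow a [W | b] layout: row i holds the
    # coefficients of x_i over the previous layer, with the constant term in the
    # last column. An illustrative sketch (values assumed, not produced here):
    #     dp = DeepPoly(2, torch.tensor([-1., 0.]), torch.tensor([1., 2.]))
    #     dp.sym_lowbound
    #     # tensor([[1., 0., 0.],
    #     #         [0., 1., 0.]])  -> each input is bounded exactly by itself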

    def compute_verify_result(self, true_label):
        self.save()
        # sym_lowbound has one row per output class, so this builds
        # (num_classes - 1) difference constraints: logit_true - logit_other
        # for every class other than the true one
        n = self.sym_lowbound.shape[0] - 1
        unit = torch.diag(torch.ones(n))
        weights = torch.cat((-unit[:, :true_label], torch.ones(n, 1), -unit[:, true_label:], torch.zeros(n, 1)), dim=1)

        # backsubstitute the constraints through every recorded layer
        for i in range(self.layers, 0, -1):
            weights = self.resolve(weights, i - 1, lower=True)

        return weights
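
    # A hedged sketch of the constraint matrix built above, assuming 3 output
    # classes and true_label = 1 (each row encodes logit_true - logit_other,
    # plus a zero constant column; verification succeeds when every resolved
    # lower bound is positive):
    #     row 0: x_1 - x_0  ->  [-1.,  1.,  0.,  0.]
    #     row 1: x_1 - x_2  ->  [ 0.,  1., -1.,  0.]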

    def resolve(self, constraints, layer, lower=True):
        """
        lower = True: return the lower bound
        lower = False: return the upper bound
        """
        # split the constraint coefficients by sign: a lower bound substitutes
        # lower bounds where coefficients are positive and upper bounds where
        # they are negative
        pos_coeff = F.relu(constraints)
        neg_coeff = F.relu(-constraints)
        layer_info = self.history[layer]
        is_relu = layer_info[-1]
        if layer == 0:
            # layer_info[2], layer_info[3]: concrete lower and upper bound
            low_bound, up_bound = layer_info[2], layer_info[3]
        else:
            # layer_info[0], layer_info[1]: symbolic lower and upper bound
            low_bound, up_bound = layer_info[0], layer_info[1]
        if not lower:
            low_bound, up_bound = up_bound, low_bound
        if is_relu:
            # ReLU bounds are stored as (slope row, bias row) pairs
            low_diag, low_bias = low_bound[0], low_bound[1]
            up_diag, up_bias = up_bound[0], up_bound[1]
            low_bias = torch.cat([low_bias, torch.ones(1)])
            up_bias = torch.cat([up_bias, torch.ones(1)])

            m1 = torch.cat([pos_coeff[:, :-1] * low_diag, torch.matmul(pos_coeff, low_bias).unsqueeze(1)], dim=1)
            m2 = torch.cat([neg_coeff[:, :-1] * up_diag, torch.matmul(neg_coeff, up_bias).unsqueeze(1)], dim=1)
            return m1 - m2
        else:
            return torch.matmul(pos_coeff, low_bound) - torch.matmul(neg_coeff, up_bound)
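
    # The sign split above is the core backsubstitution step: a lower bound of
    # c . x substitutes lower bounds where c is positive and upper bounds where
    # c is negative. A tiny numeric check with assumed values (the last entry of
    # c is the constant term, so its "bound" column is 1):
    #     c   = torch.tensor([[2., -3., 1.]])
    #     low = torch.tensor([0., 1., 1.])
    #     up  = torch.tensor([4., 5., 1.])
    #     F.relu(c) @ low - F.relu(-c) @ up   # tensor([-14.]) = 2*0 - 3*5 + 1*1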



class DPReLU(nn.Module):
    def __init__(self, size):
        super(DPReLU, self).__init__()
        self.in_features = size
        self.out_features = size
        # learnable lower-bound slope; DeepPoly needs 0 <= alpha <= 1 for soundness
        self.alpha = torch.nn.Parameter(torch.ones(size))

    def forward(self, x):
        x.save()
        low, up = x.lowbound, x.upbound
        mask_1, mask_2 = low.ge(0), up.le(0)   # always-active / always-inactive neurons
        mask_3 = ~(mask_1 | mask_2)            # crossing neurons: low < 0 < up

        # case 1: low >= 0, ReLU is the identity
        slope_low_1 = torch.ones_like(low)
        bias_low_1 = torch.zeros_like(low)
        slope_up_1 = torch.ones_like(up)
        bias_up_1 = torch.zeros_like(up)
        # case 2: up <= 0, ReLU is constant zero
        slope_low_2 = torch.zeros_like(low)
        bias_low_2 = torch.zeros_like(low)
        slope_up_2 = torch.zeros_like(up)
        bias_up_2 = torch.zeros_like(up)
        # case 3: low < 0 < up, lower bound alpha * x, upper bound the chord
        # through (low, 0) and (up, up)
        denom = (up - low).clamp(min=1e-12)    # avoid 0/0 on masked-out entries
        slope_low_3 = self.alpha
        bias_low_3 = torch.zeros_like(low)
        slope_up_3 = up / denom
        bias_up_3 = -up * low / denom

        curr_slb = slope_low_1 * mask_1 + slope_low_2 * mask_2 + slope_low_3 * mask_3
        curr_slb_bias = bias_low_1 * mask_1 + bias_low_2 * mask_2 + bias_low_3 * mask_3
        curr_sub = slope_up_1 * mask_1 + slope_up_2 * mask_2 + slope_up_3 * mask_3
        curr_sub_bias = bias_up_1 * mask_1 + bias_up_2 * mask_2 + bias_up_3 * mask_3

        x.lowbound = F.relu(low)
        x.upbound = F.relu(up)
        x.sym_lowbound = torch.cat([curr_slb.unsqueeze(0), curr_slb_bias.unsqueeze(0)], dim=0)
        x.sym_upbound = torch.cat([curr_sub.unsqueeze(0), curr_sub_bias.unsqueeze(0)], dim=0)
        x.is_relu = True
        return x
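
# For reference, in the crossing case the upper relaxation used above is the
# chord through (low, 0) and (up, up). A quick check with assumed bounds:
#     low, up = -1.0, 3.0
#     slope_up = up / (up - low)         # 0.75
#     bias_up = -up * low / (up - low)   # 0.75
#     slope_up * low + bias_up           # 0.0 = relu(low)
#     slope_up * up + bias_up            # 3.0 = relu(up)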


class DPLinear(nn.Module):
    def __init__(self, nested: nn.Linear):
        super(DPLinear, self).__init__()
        self.weight = nested.weight.detach()
        self.bias = nested.bias.detach()
        self.in_features = nested.in_features
        self.out_features = nested.out_features

    def forward(self, x):
        x.save()
        # append bias as last column
        init_slb = torch.cat([self.weight, self.bias.unsqueeze(1)], dim=1)
        x.lowbound = init_slb
        x.upbound = init_slb
        x.sym_lowbound = init_slb
        x.sym_upbound = init_slb
        # backsubstitute through all recorded layers to refresh the concrete bounds
        for i in range(x.layers, 0, -1):
            x.lowbound = x.resolve(x.lowbound, i - 1, lower=True)
            x.upbound = x.resolve(x.upbound, i - 1, lower=False)
        x.is_relu = False
        return x
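
# A hedged end-to-end sketch of how these pieces are meant to compose; the toy
# network and input box below are illustrative only, not part of the verifier.
if __name__ == "__main__":
    fc1, fc2 = nn.Linear(2, 2), nn.Linear(2, 2)
    toy_net = nn.Sequential(DPLinear(fc1), DPReLU(2), DPLinear(fc2))
    dp = DeepPoly(2, torch.tensor([-0.1, -0.1]), torch.tensor([0.1, 0.1]))
    res = toy_net(dp).compute_verify_result(true_label=0)
    print((res > 0).all())  # verified iff every logit difference is provably positive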


48 changes: 46 additions & 2 deletions code/verifier.py
@@ -2,11 +2,16 @@
import csv
import torch
import torch.nn.functional as F
from networks import get_network, get_net_name, NormalizedResnet
import torch.nn as nn
from torch import optim
from networks import get_network, get_net_name, NormalizedResnet, FullyConnected, Normalization
from deeppoly import DeepPoly, DPReLU, DPLinear


DEVICE = 'cpu'
DTYPE = torch.float32
LR = 0.05
NUM_ITER = 1000

def transform_image(pixel_values, input_dim):
    normalized_pixel_values = torch.tensor([float(p) / 255.0 for p in pixel_values])
@@ -48,8 +53,46 @@ def get_net(net, net_name):
    return net


# convert networks to verifiable networks
def verifiable(net, pixels):
    layers = [module for module in net.modules() if type(module) not in [FullyConnected, nn.Sequential]]
    verifiable_net = []

    for layer in layers:
        if type(layer) == nn.ReLU:
            if len(verifiable_net) == 0:
                # a ReLU as the very first layer acts on the raw pixels
                verifiable_net.append(DPReLU(pixels.numel()))
            else:
                verifiable_net.append(DPReLU(verifiable_net[-1].out_features))
        elif type(layer) == nn.Linear:
            verifiable_net.append(DPLinear(layer))

    return nn.Sequential(*verifiable_net)
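
# Note: only nn.ReLU and nn.Linear are converted here; conv and residual layers
# (see project_log) are skipped for now, so this path only supports fully
# connected nets.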


def normalize(value, inputs):
    # pick the dataset statistics from the input shape
    if inputs.shape == torch.Size([1, 1, 28, 28]):
        norm = Normalization(DEVICE, 'mnist')
    else:
        norm = Normalization(DEVICE, 'cifar10')
    return norm(value).view(-1)

def analyze(net, inputs, eps, true_label):
    low_bound = normalize((inputs - eps).clamp(0, 1), inputs)
    up_bound = normalize((inputs + eps).clamp(0, 1), inputs)
    verifiable_net = verifiable(net, inputs)
    # optimize the DPReLU slopes to tighten the certified bounds
    optimizer = optim.Adam(verifiable_net.parameters(), lr=LR)
    for i in range(NUM_ITER):
        optimizer.zero_grad()
        verifier_output = verifiable_net(DeepPoly(low_bound.shape[0], low_bound, up_bound))
        res = verifier_output.compute_verify_result(true_label)
        if (res > 0).all():
            return True
        # penalize the worst violated constraint
        loss = torch.log(-res[res < 0]).max()
        loss.backward()
        optimizer.step()

    return False
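
# The loss above penalizes only violated constraints: -res[res < 0] collects the
# violation magnitudes and the log of the worst one is minimized. Illustrative
# numbers (assumed): res = [0.3, -0.2, -1.5] gives loss = log(1.5) ~= 0.405, so
# the gradient focuses on the most-negative bound.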


def main():
@@ -63,6 +106,7 @@ def main():

    inputs, true_label, eps = get_spec(args.spec, dataset)
    net = get_net(args.net, net_name)
    print("net: ", net)

    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
18 changes: 18 additions & 0 deletions project_log
@@ -0,0 +1,18 @@
Implement a function (class) that applies DeepPoly to ReLU
Use the above function (class) to convert each network into a verifiable network
Use the verifiable networks to verify each network
=========================================


Done:


------------------------------------------------
To do:
Implement the function (class) that applies DeepPoly to ReLU in deeppoly.py
- do linear and conv neurons also need DeepPoly applied?
Use the above function (class) in verifier.py to convert each network into a verifiable network
- FullyConnected
- Conv
- NormalizedResnet
Use the verifiable networks in verifier.py to verify each network
