Skip to content

Commit

Permalink
update DPBasicBlocks
Browse files Browse the repository at this point in the history
  • Loading branch information
MekAkUActOR committed Nov 28, 2022
1 parent d56f49b commit 301fe16
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 22 deletions.
66 changes: 63 additions & 3 deletions code/deeppoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import resnet

class DeepPoly:
def __init__(self, size, lb, ub):
Expand Down Expand Up @@ -208,7 +209,7 @@ def forward(self, x):


class DPConv(nn.Module):
def __init__(self, nested: nn.Conv2d, in_feature):
def __init__(self, nested: nn.Conv2d, in_feature):
super(DPConv, self).__init__()
self.weight = nested.weight.detach()
self.bias = nested.bias.detach() # tensor[out_channels]
Expand All @@ -223,10 +224,11 @@ def __init__(self, nested: nn.Conv2d, in_feature):
# print("DPConv create", in_feature)
# print("DPConv create", self.in_channels)
# print("DPConv create", self.out_channels)
img_height = math.floor(math.sqrt(in_feature / self.in_channels))
self.in_features = (
self.in_channels,
math.floor(math.sqrt(in_feature / self.in_channels)),
math.floor(math.sqrt(in_feature / self.in_channels)),
img_height,
img_height,
)
# print("DPConv create", self.in_features)
# print("DPConv create", self.kernel_size)
Expand Down Expand Up @@ -345,3 +347,61 @@ def pad_image(temp_input, padding):
# std = torch.std(x, unbiased=False)
# one_channel_batchnorm = (x - mean) / std
# return one_channel_batchnorm

# if __name__ == "__main__":
# test = get_network('cpu', 'net4')
# print(test)
# # Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
# # Flatten(start_dim=1, end_dim=-1)
# a = DPConv2d(torch.nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), (1, 28, 28))
# # print(a.weight)
# print(a.bias)
# print(type(a.bias))
# print(a.bias.shape)
#
# print(a.weight.shape)
# b = torch.cat([a.weight, a.bias.unsqueeze(3)], dim=0)
# print(b.shape)

class DPBasicBlock(nn.Module):
def __init__(self, nested: resnet.BasicBlock, in_feature):
super(DPBasicBlock, self).__init__()
self.in_planes = nested.in_planes
self.planes = nested.planes
self.stride = nested.stride
self.bn = nested.bn
self.kernel = nested.kernel
self.expansion = nested.expansion
self.paths = []
for modu in nested.modules():
if type(modu) == nn.Sequential:
self.paths.append(modu)
self.in_feature = in_feature
img_height = math.floor(math.sqrt(in_feature / self.in_planes))
self.in_features = (
self.in_planes,
img_height,
img_height,
)
temp_img = torch.zeros(self.in_features)
output_shape = self.paths[0](temp_img).shape
self.out_features = output_shape[0] * output_shape[1] * output_shape[2]

def forward(self, x):
lb = []
up = []
for path in self.paths:
in_feature = self.in_feature
for layer in path:
if type(layer) == nn.Conv2d:
dp_conv = DPConv(layer, in_feature)
dp_conv(x)
in_feature = dp_conv.out_features
elif type(layer) == nn.ReLU:
dp_relu = DPReLU(in_feature)
dp_relu(x)
in_feature = dp_relu.out_features
lb

return x

92 changes: 73 additions & 19 deletions code/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
from networks import get_network, get_net_name, NormalizedResnet, FullyConnected, Normalization
from deeppoly import DeepPoly, DPReLU, DPLinear, DPConv
from networks import get_network, get_net_name, NormalizedResnet, FullyConnected, Conv, Normalization
from resnet import ResNet, BasicBlock
from deeppoly import DeepPoly, DPReLU, DPLinear, DPConv, DPBasicBlock


DEVICE = 'cpu'
Expand Down Expand Up @@ -179,22 +180,60 @@ def get_net(net, net_name):

# convert networks to verifiable networks
def verifiable(net, pixels):
layers = [module for module in net.modules() if type(module) not in [FullyConnected, nn.Sequential]]
verifiable_net = []
if type(net) == NormalizedResnet:
layers = [module for module in net.modules() if type(module) is ResNet]
for modu in layers:
for layer in modu:
if type(layer) == nn.ReLU:
if len(verifiable_net) == 0:
verifiable_net.append(DPReLU(len(pixels)))
else:
verifiable_net.append(DPReLU(verifiable_net[-1].out_features))
elif type(layer) == nn.Linear:
verifiable_net.append(DPLinear(layer))
elif type(layer) == nn.Conv2d:
if len(verifiable_net) == 0:
verifiable_net.append(DPConv(layer, len(pixels)))
else:
verifiable_net.append(DPConv(layer, verifiable_net[-1].out_features))
elif type(layer) == nn.Sequential:
for lay in layer:
if type(lay) == nn.ReLU:
if len(verifiable_net) == 0:
verifiable_net.append(DPReLU(len(pixels)))
else:
verifiable_net.append(DPReLU(verifiable_net[-1].out_features))
elif type(lay) == nn.Linear:
verifiable_net.append(DPLinear(lay))
elif type(lay) == nn.Conv2d:
if len(verifiable_net) == 0:
verifiable_net.append(DPConv(lay, len(pixels)))
else:
verifiable_net.append(DPConv(lay, verifiable_net[-1].out_features))
elif type(lay) == BasicBlock:
if len(verifiable_net) == 0:
verifiable_net.append(DPBasicBlock(lay, len(pixels)))
else:
verifiable_net.append(DPBasicBlock(lay, verifiable_net[-1].out_features))

for layer in layers:
if type(layer) == nn.ReLU:
if len(verifiable_net) == 0:
verifiable_net.append(DPReLU(len(pixels)))
else:
verifiable_net.append(DPReLU(verifiable_net[-1].out_features))
elif type(layer) == nn.Linear:
verifiable_net.append(DPLinear(layer))
elif type(layer) == nn.Conv2d:
if len(verifiable_net) == 0:
verifiable_net.append(DPConv(layer, len(pixels)))
else:
verifiable_net.append(DPConv(layer, verifiable_net[-1].out_features))

else:
layers = [module for module in net.modules() if type(module) not in [FullyConnected, Conv, NormalizedResnet, nn.Sequential]]
for layer in layers:
# print(layer)
if type(layer) == nn.ReLU:
if len(verifiable_net) == 0:
verifiable_net.append(DPReLU(len(pixels)))
else:
verifiable_net.append(DPReLU(verifiable_net[-1].out_features))
elif type(layer) == nn.Linear:
verifiable_net.append(DPLinear(layer))
elif type(layer) == nn.Conv2d:
if len(verifiable_net) == 0:
verifiable_net.append(DPConv(layer, len(pixels)))
else:
verifiable_net.append(DPConv(layer, verifiable_net[-1].out_features))

return nn.Sequential(*verifiable_net)

Expand Down Expand Up @@ -275,17 +314,32 @@ def main():

inputs, true_label, eps = get_spec(args.spec, dataset)
net = get_net(args.net, net_name)
# print(net)
# print([module for module in net.modules()])
if type(net) == NormalizedResnet:
print(net)
for module in net.modules():
if type(module) == ResNet:
print("-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --")
for layer in module:
print(layer)
for layer in module:
if type(layer) == nn.Sequential:
for modu in layer:
if type(modu) == BasicBlock:
for mo in modu.modules():
if type(mo) == nn.Sequential:
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
print(mo)

outs = net(inputs)
pred_label = outs.max(dim=1)[1].item()
assert pred_label == true_label
# '''
'''
if analyze(net, inputs, eps, true_label):
print('verified')
else:
print('not verified')
# '''
'''

if __name__ == '__main__':
main()

0 comments on commit 301fe16

Please sign in to comment.