Skip to content

Commit

Permalink
update DPBasicBlocks and timer
Browse files Browse the repository at this point in the history
  • Loading branch information
MekAkUActOR committed Nov 28, 2022
1 parent 301fe16 commit 89658b9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
11 changes: 8 additions & 3 deletions code/deeppoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,14 @@ def __init__(self, nested: resnet.BasicBlock, in_feature):
self.out_features = output_shape[0] * output_shape[1] * output_shape[2]

def forward(self, x):
lb = []
up = []
lbs = []
ubs = []
slbs = []
subs = []
for path in self.paths:
in_feature = self.in_feature
slb = 1
sub = 1
for layer in path:
if type(layer) == nn.Conv2d:
dp_conv = DPConv(layer, in_feature)
Expand All @@ -401,7 +405,8 @@ def forward(self, x):
dp_relu = DPReLU(in_feature)
dp_relu(x)
in_feature = dp_relu.out_features
lb
lbs.append(x.lb)
ubs.append(x.ub)

return x

2 changes: 1 addition & 1 deletion code/evaluate
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

rm $1/res.txt
for net in {1..10}
for net in {1..7}
do
echo Evaluating network net${net}...
for spec in `ls $1/net${net}/`
Expand Down
15 changes: 11 additions & 4 deletions code/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from resnet import ResNet, BasicBlock
from deeppoly import DeepPoly, DPReLU, DPLinear, DPConv, DPBasicBlock

import time


DEVICE = 'cpu'
DTYPE = torch.float32
Expand Down Expand Up @@ -304,6 +306,7 @@ def analyze(net, inputs, eps, true_label):


def main():
start = time.perf_counter_ns()
parser = argparse.ArgumentParser(description='Neural network verification using DeepPoly relaxation')
parser.add_argument('--net', type=str, required=True, help='Neural network architecture to be verified.')
parser.add_argument('--spec', type=str, required=True, help='Test case to verify.')
Expand All @@ -315,6 +318,7 @@ def main():
inputs, true_label, eps = get_spec(args.spec, dataset)
net = get_net(args.net, net_name)
# print([module for module in net.modules()])
'''
if type(net) == NormalizedResnet:
print(net)
for module in net.modules():
Expand All @@ -330,16 +334,19 @@ def main():
if type(mo) == nn.Sequential:
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
print(mo)
'''

outs = net(inputs)
pred_label = outs.max(dim=1)[1].item()
assert pred_label == true_label
'''
# '''
if analyze(net, inputs, eps, true_label):
print('verified')
end = time.perf_counter_ns()
print('', round(end - start)*0.000000001, 's, verified')
else:
print('not verified')
'''
end = time.perf_counter_ns()
print('', round(end - start)*0.000000001, 's, not verified')
# '''

if __name__ == '__main__':
main()

0 comments on commit 89658b9

Please sign in to comment.