update learning method and some dpbasicblocks
MekAkUActOR committed Dec 14, 2022
1 parent d25b11c commit ffeffa7
Showing 1 changed file with 28 additions and 50 deletions.
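
The headline change in the diff below is that the DeepPoly transformer's back-substitution history becomes a list of per-path histories: `history = [[]]` instead of `history = []`, with `save(path=0)` and `resolve(..., path=0)` indexing into `history[path]`, and `DPBasicBlock.forward` now calling `x.save()` and duplicating `x.history` for the two residual paths. A minimal, simplified sketch of that bookkeeping (hypothetical class and method names, not the repository's exact code):

class PathHistory:
    def __init__(self):
        # Before this commit: one flat list of (slb, sub, lb, ub) records.
        # After: one record list per residual path, selected by `path`.
        self.history = [[]]

    def save(self, record, path=0):
        # Append this layer's symbolic/concrete bounds to the chosen path.
        self.history[path].append(record)

    def resolve(self, layer, path=0):
        # Back-substitution reads one path's recorded constraints.
        return self.history[path][layer]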
78 changes: 28 additions & 50 deletions code/deeppoly.py
@@ -11,16 +11,13 @@ def __init__(self, size, lb, ub):
self.lb = lb
self.ub = ub

self.history = []
self.history = [[]]
self.layers = 0

def save(self):
# print("save ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
def save(self, path=0):
# """ Save all constraints for the back substitution """
lb = torch.cat([self.lb, torch.ones(1)])
# print("save lb,", lb.shape)
ub = torch.cat([self.ub, torch.ones(1)])
# print("save ub,", ub.shape)
keep_bias = torch.zeros(1, self.slb.shape[1])
keep_bias[0, self.slb.shape[1] - 1] = 1
slb = torch.cat([self.slb, keep_bias], dim=0)
@@ -29,26 +26,22 @@ def save(self):
# layer num
self.layers += 1
# record each layer
self.history.append((slb, sub, lb, ub))
self.history[path].append((slb, sub, lb, ub))
return self

def compute_verify_result(self, true_label):
self.save()
n = self.slb.shape[0] - 1
unit = torch.diag(torch.ones(n))
weights = torch.cat((-unit[:, :true_label], torch.ones(n, 1), -unit[:, true_label:], torch.zeros(n, 1)), dim=1)
# print("compute_verify_result,", weights.shape)
# print("===========================================================================")

for i in range(self.layers, 0, -1):
weights = self.resolve(weights, i - 1, lower=True)
# print("----------------------------------------------------------")
# print("compute_verify_result,", weights.shape, weights)

return weights

# TODO: implement Conv in resolve
def resolve(self, constrains, layer, lower=True):
def resolve(self, constrains, layer, lower=True, path=0):
# print("resolve >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
"""
lower = True: return the lower bound
@@ -57,7 +50,7 @@ def resolve(self, constrains, layer, lower=True):
# distinguish the sign of the coefficients of the constraints
pos_coeff = F.relu(constrains)
neg_coeff = -F.relu(-constrains)
layer_info = self.history[layer]
layer_info = self.history[path][layer]
if layer == 0:
# layer_info[2],layer_info[3]: concrete lower and upper bound
lb, ub = layer_info[2], layer_info[3]
@@ -169,18 +162,14 @@ def __init__(self, nested: nn.Linear):
def forward(self, x):
x.save()
# append bias as last column
# print("Linear weight", self.weight.shape)
# print("Linear bias", self.bias.shape)
init_slb = torch.cat([self.weight, self.bias.unsqueeze(1)], dim=1)
# print("Linear init_slb", init_slb.shape)
x.lb = init_slb
x.ub = init_slb
x.slb = init_slb
x.sub = init_slb
for i in range(x.layers, 0, -1):
x.lb = x.resolve(x.lb, i - 1, lower=True)
x.ub = x.resolve(x.ub, i - 1, lower=False)
# x.is_relu = False
return x


@@ -197,19 +186,12 @@ def __init__(self, nested: nn.Conv2d, in_feature):
# self.padding_mode = nested.padding_mode
# self.dilation = nested.dilation
# self.groups = nested.groups
# print("DPConv create", in_feature)
# print("DPConv create", self.in_channels)
# print("DPConv create", self.out_channels)
img_height = math.floor(math.sqrt(in_feature / self.in_channels))
self.in_features = (
self.in_channels,
img_height,
img_height,
)
# print("DPConv create", self.in_features)
# print("DPConv create", self.kernel_size)
# print("DPConv create", self.stride)
# print("DPConv create", self.padding)
self.out_features = self.out_channels * \
math.floor((self.in_features[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1) * \
math.floor((self.in_features[2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1)
@@ -218,17 +200,13 @@ def __init__(self, nested: nn.Conv2d, in_feature):
def forward(self, x):
x.save()
init_slb = self.weightMtrx
# print("Conv init_slb", init_slb.shape, init_slb)
x.lb = init_slb
x.ub = init_slb
x.slb = init_slb
x.sub = init_slb
for i in range(x.layers, 0, -1):
x.lb = x.resolve(x.lb, i - 1, lower=True)
x.ub = x.resolve(x.ub, i - 1, lower=False)
# print("Conv init_slb x.lb", x.lb)
# print("Conv init_slb x.ub", x.ub)
# x.is_relu = False
return x

@staticmethod
@@ -239,8 +217,8 @@ def get_sparse_conv_weights_matrix(conv_weights, in_feature, bias, stride=1, pad
inputs: [(pixels: c_in * input_height * input_width)]
bias: [(c_out)]
"""
c_out = conv_weights.shape[0]
c_in = conv_weights.shape[1]
# c_out = conv_weights.shape[0]
# c_in = conv_weights.shape[1]
kernel_size = conv_weights.shape[2]
input_height = in_feature[1]
input_width = in_feature[2]
@@ -260,8 +238,6 @@ def get_sparse_conv_weights_matrix(conv_weights, in_feature, bias, stride=1, pad
input_pad_h,
input_pad_w,
stride)
# a test
# kernel_matrix = torch.randn((output_height * output_width, input_pad_h * input_pad_w))
dense_matrix = dense_weights(kernel_matrix, matrix_mask)
matrix_line_lst.append(dense_matrix)
line_matrix = torch.cat(matrix_line_lst, dim=1)
@@ -367,25 +343,27 @@ def __init__(self, nested: resnet.BasicBlock, in_feature):
self.out_features = output_shape[0] * output_shape[1] * output_shape[2]

def forward(self, x):
lbs = []
ubs = []
slbs = []
subs = []
for path in self.paths:
in_feature = self.in_feature
slb = 1
sub = 1
for layer in path:
if type(layer) == nn.Conv2d:
dp_conv = DPConv(layer, in_feature)
dp_conv(x)
in_feature = dp_conv.out_features
elif type(layer) == nn.ReLU:
dp_relu = DPReLU(in_feature)
dp_relu(x)
in_feature = dp_relu.out_features
lbs.append(x.lb)
ubs.append(x.ub)
x.save()
x.history = x.history + x.history
# lbs = []
# ubs = []
# slbs = []
# subs = []
# for path in self.paths:
# in_feature = self.in_feature
# slb = 1
# sub = 1
# for layer in path:
# if type(layer) == nn.Conv2d:
# dp_conv = DPConv(layer, in_feature)
# dp_conv(x)
# in_feature = dp_conv.out_features
# elif type(layer) == nn.ReLU:
# dp_relu = DPReLU(in_feature)
# dp_relu(x)
# in_feature = dp_relu.out_features
# lbs.append(x.lb)
# ubs.append(x.ub)

return x
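
In `compute_verify_result`, each row of `weights` encodes one difference logit[true_label] - logit[j] for j != true_label, with a trailing zero coefficient for the bias column; the network is certified when the resolved lower bound of every row is non-negative (that final check lies outside this diff). A standalone sketch of just that construction, using a hypothetical helper name and the same column layout as the code above:

import torch

def verify_weights(num_classes: int, true_label: int) -> torch.Tensor:
    # Rows: logit[true_label] - logit[j] for every j != true_label.
    n = num_classes - 1                        # number of competing classes
    unit = torch.diag(torch.ones(n))
    return torch.cat((-unit[:, :true_label],   # -1 for classes before true_label
                      torch.ones(n, 1),        # +1 for the true class
                      -unit[:, true_label:],   # -1 for classes after true_label
                      torch.zeros(n, 1)),      # zero coefficient for the bias column
                     dim=1)

# Example: 3 classes, true label 1 gives
# [[-1.,  1.,  0.,  0.],
#  [ 0.,  1., -1.,  0.]]
print(verify_weights(3, 1))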
