Skip to content

Commit

Permalink
DeepPoly transformer for fc, some problem with DPReLU
Browse files Browse the repository at this point in the history
  • Loading branch information
Haowe committed Nov 21, 2022
1 parent 3297c50 commit 67d6103
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 43 deletions.
80 changes: 40 additions & 40 deletions code/deeppoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,38 @@


class DeepPoly:
def __init__(self, size, low_bound, up_bound):
self.lowbound = low_bound
self.upbound = up_bound
self.sym_lowbound = torch.cat([torch.diag(torch.ones(size)), torch.zeros(size).unsqueeze(1)], dim=1)
self.sym_upbound = self.sym_lowbound
def __init__(self, size, lb, ub):
self.slb = torch.cat([torch.diag(torch.ones(size)), torch.zeros(size).unsqueeze(1)], dim=1)
self.sub = self.slb
self.lb = lb
self.ub = ub
self.history = []
self.layers = 0
self.is_relu = False

def save(self):
""" Save constraints for the backsubstitution """
low_bound = torch.cat([self.lowbound, torch.ones(1)])
up_bound = torch.cat([self.upbound, torch.ones(1)])
""" Save all constraints for the back substitution """
lb = torch.cat([self.lb, torch.ones(1)])
ub = torch.cat([self.ub, torch.ones(1)])
if self.is_relu:
sym_lowbound = self.sym_lowbound
sym_upbound = self.sym_upbound
# relu layer
slb = self.slb
sub = self.sub
else:
# other layers
keep_bias = torch.zeros(1, self.sym_lowbound.shape[1])
keep_bias[0, self.sym_lowbound.shape[1] - 1] = 1
sym_lowbound = torch.cat([self.sym_lowbound, keep_bias], dim=0)
sym_upbound = torch.cat([self.sym_upbound, keep_bias], dim=0)
# layer num
keep_bias = torch.zeros(1, self.slb.shape[1])
keep_bias[0, self.slb.shape[1] - 1] = 1
slb = torch.cat([self.slb, keep_bias], dim=0)
sub = torch.cat([self.sub, keep_bias], dim=0)
# layer num
self.layers += 1
# record each layer
self.history.append((sym_lowbound, sym_upbound, low_bound, up_bound, self.is_relu))
self.history.append((slb, sub, lb, ub, self.is_relu))
return self

def compute_verify_result(self, true_label):
self.save()
n = self.sym_lowbound.shape[0] - 1 # why
n = self.slb.shape[0] - 1
unit = torch.diag(torch.ones(n))
weights = torch.cat((-unit[:, :true_label], torch.ones(n, 1), -unit[:, true_label:], torch.zeros(n, 1)), dim=1)

Expand All @@ -56,39 +57,38 @@ def resolve(self, constrains, layer, lower=True):
is_relu = layer_info[-1]
if layer == 0:
# layer_info[2],layer_info[3]: concrete lower and upper bound
low_bound, up_bound = layer_info[2], layer_info[3]
lb, ub = layer_info[2], layer_info[3]
else:
# layer_info[0],layer_info[1]: symbolic lower and upper bound
low_bound, up_bound = layer_info[0], layer_info[1]
lb, ub = layer_info[0], layer_info[1]
if not lower:
low_bound, up_bound = up_bound, low_bound
lb, ub = ub, lb
if is_relu:
low_diag, low_bias = low_bound[0], low_bound[1]
up_diag, up_bias = up_bound[0], up_bound[1]
low_bias = torch.cat([low_bias, torch.ones(1)])
up_bias = torch.cat([up_bias, torch.ones(1)])
lb_diag, lb_bias = lb[0], lb[1]
ub_diag, ub_bias = ub[0], ub[1]
lb_bias = torch.cat([lb_bias, torch.ones(1)])
ub_bias = torch.cat([ub_bias, torch.ones(1)])

m1 = torch.cat([pos_coeff[:, :-1] * low_diag, torch.matmul(pos_coeff, low_bias).unsqueeze(1)], dim=1)
m2 = torch.cat([neg_coeff[:, :-1] * up_diag, torch.matmul(neg_coeff, up_bias).unsqueeze(1)], dim=1)
m1 = torch.cat([pos_coeff[:, :-1] * lb_diag, torch.matmul(pos_coeff, lb_bias).unsqueeze(1)], dim=1)
m2 = torch.cat([neg_coeff[:, :-1] * ub_diag, torch.matmul(neg_coeff, ub_bias).unsqueeze(1)], dim=1)
return m1 - m2
else:
return torch.matmul(pos_coeff, low_bound) - torch.matmul(neg_coeff, up_bound)
return torch.matmul(pos_coeff, lb) - torch.matmul(neg_coeff, ub)



class DPReLU(nn.Module):
def __init__(self, size):
def __init__(self, in_features):
super(DPReLU, self).__init__()
self.in_features = size
self.out_features = size
self.alpha = torch.nn.Parameter(torch.ones(size))
self.in_features = in_features
self.out_features = in_features
self.alpha = torch.nn.Parameter(torch.ones(in_features) * 0.5)

def forward(self, x):
x.save()
low, up = x.lowbound, x.upbound
low, up = x.lb, x.ub
mask_1, mask_2 = low.ge(0), up.le(0)
mask_3 = ~(mask_1 | mask_2)
print("DPReLU: ", low.shape, up.shape, self.alpha.shape)

'''
low > 0
Expand All @@ -107,8 +107,8 @@ def forward(self, x):
'''
low < 0 < up
'''
slope_low_3 = self.alpha
bias_low_3 = self.alpha * low - self.alpha * low
slope_low_3 = torch.tan(self.alpha)
bias_low_3 = slope_low_3 * low - slope_low_3 * low
slope_up_3 = (F.relu(up) - F.relu(low)) / (up - low)
bias_up_3 = F.relu(up) - slope_up_3 * up

Expand All @@ -117,17 +117,17 @@ def forward(self, x):
curr_sub = slope_up_1 * mask_1 + slope_up_2 * mask_2 + slope_up_3 * mask_3
curr_sub_bias = bias_up_1 * mask_1 + bias_up_2 * mask_2 + bias_up_3 * mask_3

x.lowbound = F.relu(low)
x.upbound = F.relu(up)
x.sym_lowbound = torch.cat([curr_slb.unsqueeze(0), curr_slb_bias.unsqueeze(0)], dim=0)
x.sym_upbound = torch.cat([curr_sub.unsqueeze(0), curr_sub_bias.unsqueeze(0)], dim=0)
x.lb = F.relu(low)
x.ub = F.relu(up)
x.slb = torch.cat([curr_slb.unsqueeze(0), curr_slb_bias.unsqueeze(0)], dim=0)
x.sub = torch.cat([curr_sub.unsqueeze(0), curr_sub_bias.unsqueeze(0)], dim=0)
x.is_relu = True
return x


class DPLinear(nn.Module):
def __init__(self, nested: nn.Linear):
super(DPLinear, self).__init__()
super().__init__()
self.weight = nested.weight.detach()
self.bias = nested.bias.detach()
self.in_features = nested.in_features
Expand Down
13 changes: 10 additions & 3 deletions code/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

DEVICE = 'cpu'
DTYPE = torch.float32
LR = 0.05
LR = 0.005
num_iter = 1000

def transform_image(pixel_values, input_dim):
Expand Down Expand Up @@ -73,9 +73,11 @@ def verifiable(net, pixels):
def nomalize(value, inputs):
if inputs.shape == torch.Size([1, 1, 28, 28]):
norm = Normalization(DEVICE, 'mnist')
norm_val = norm(value).view(-1)
else:
norm = Normalization(DEVICE, 'cifar10')
return norm(value).view(-1)
norm_val = norm(value).reshape(-1)
return norm_val

def analyze(net, inputs, eps, true_label):
low_bound = nomalize((inputs - eps).clamp(0, 1), inputs)
Expand All @@ -92,7 +94,12 @@ def analyze(net, inputs, eps, true_label):
loss.backward()
optimizer.step()

return False
verifier_output = verifiable_net(DeepPoly(low_bound.shape[0], low_bound, up_bound))
res = verifier_output.compute_verify_result(true_label)
if (res > 0).all():
return True
else:
return False


def main():
Expand Down

0 comments on commit 67d6103

Please sign in to comment.