-
Notifications
You must be signed in to change notification settings - Fork 0
/
maxSens.jl
84 lines (74 loc) · 2.82 KB
/
maxSens.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
MaxSens(resolution::Float64, tight::Bool)
MaxSens performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network.
# Problem requirement
1. Network: any depth, any activation that is monotone
2. Input: `Hyperrectangle` or `HPolytope`
3. Output: `AbstractPolytope`
# Return
`ReachabilityResult`
# Method
First partition the input space into small grid cells according to `resolution`.
Then use interval arithmetic to compute the reachable set for each cell.
Two versions of interval arithmetic is implemented with indicator `tight`.
Default `resolution` is `1.0`. Default `tight = false`.
# Property
Sound but not complete.
# Reference
[W. Xiang, H.-D. Tran, and T. T. Johnson,
"Output Reachable Set Estimation and Verification for Multi-Layer Neural Networks,"
*ArXiv Preprint ArXiv:1708.03322*, 2017.](https://arxiv.org/abs/1708.03322)
"""
@with_kw struct MaxSens
resolution::Float64 = 1.0
tight::Bool = false
end
# This is the main function
function solve(solver::MaxSens, problem::Problem) #original
result = true
delta = solver.resolution
lower, upper = low(problem.input), high(problem.input)
n_hypers_per_dim = BigInt.(max.(ceil.(Int, (upper-lower) / delta), 1))
# preallocate work arrays
local_lower, local_upper, CI = similar(lower), similar(lower), similar(lower)
for i in 1:prod(n_hypers_per_dim)
n = i
for j in firstindex(CI):lastindex(CI)
n, CI[j] = fldmod1(n, n_hypers_per_dim[j])
end
@. local_lower = lower + delta * (CI - 1)
@. local_upper = min(local_lower + delta, upper)
hyper = Hyperrectangle(low = local_lower, high = local_upper)
reach = forward_network(solver, problem.network, hyper)
if !issubset(reach, problem.output)
result = false
end
end
if result
return BasicResult(:holds)
end
return BasicResult(:violated)
end
# This function is called by forward_network
function forward_layer(solver::MaxSens, L::Layer, input::Hyperrectangle)
(W, b, act) = (L.weights, L.bias, L.activation)
center = zeros(size(W, 1))
gamma = zeros(size(W, 1))
for j in 1:size(W, 1)
node = Node(W[j,:], b[j], act)
center[j], gamma[j] = forward_node(solver, node, input)
end
return Hyperrectangle(center, gamma)
end
function forward_node(solver::MaxSens, node::Node, input::Hyperrectangle)
output = node.w' * input.center + node.b
deviation = sum(abs.(node.w) .* input.radius)
β = node.act(output) # TODO expert suggestion for variable name. beta? β? O? x?
βmax = node.act(output + deviation)
βmin = node.act(output - deviation)
if solver.tight
return ((βmax + βmin)/2, (βmax - βmin)/2)
else
return (β, max(abs(βmax - β), abs(βmin - β)))
end
end