-
Notifications
You must be signed in to change notification settings - Fork 0
/
hullSearch.jl
82 lines (73 loc) · 2.33 KB
/
hullSearch.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
    HullSearch(tolerance::Float64)

Solver that checks a property by over-approximating the output reachable set of a
network, refining the input set by bisection where the over-approximation is too loose.

# Problem requirement
1. Network: any depth, any activation that is monotone
2. Input: `Hyperrectangle`
3. Output: `HPolytope`

# Return
`BasicResult`

# Method
Search and reachability. `solve` keeps bisecting input cells whose largest width exceeds
`tolerance`, and only continues with the halves that `isborder` reports as touching the
border of the original input. Default `tolerance` is `1.0`.

# Property
Sound but not complete.
"""
@with_kw struct HullSearch
    tolerance::Float64 = 1.0
end
"""
    solve(solver::HullSearch, problem::Problem)

Main entry point of the solver: decide whether the over-approximated reachable set of
`problem.network` is contained in `problem.output` for the cells explored from
`problem.input`.

Cells are processed first-in, first-out. For each cell the reachable set computed by
`forward_network` is tested with `issubset` against `problem.output`:

- if it is contained, the cell is discarded;
- otherwise, if `get_largest_width` of the cell exceeds `solver.tolerance`, the cell is
  bisected and only the halves for which `isborder(half, problem.input)` is true are
  queued (halves not touching the border of the original input are dropped);
- otherwise the search stops with a violation.

# Returns
- `BasicResult(:violated)` as soon as a cell that is not refined any further has a
  reachable set that is not a subset of `problem.output`.
- `BasicResult(:holds)` when the queue runs empty.
"""
function solve(solver::HullSearch, problem::Problem)
    domain = problem.input
    queue = Hyperrectangle[domain]
    while !isempty(queue)
        cell = popfirst!(queue)
        reach = forward_network(solver, problem.network, cell)

        # Nothing left to refine for a cell whose reachable set is already contained.
        issubset(reach, problem.output) && continue

        # A cell that is not wide enough to be split again counts as a violation.
        get_largest_width(cell) > solver.tolerance || return BasicResult(:violated)

        halves = bisect(cell)
        for k in 1:2
            half = halves[k]
            # Keep only the halves that still touch the border of the original input.
            isborder(half, domain) && push!(queue, half)
        end
    end
    return BasicResult(:holds)
end
"""
    isborder(x::Hyperrectangle, y::Hyperrectangle)

Return `true` if, in at least one dimension, the lower bound of `x` equals the lower
bound of `y` or the upper bound of `x` equals the upper bound of `y`; return `false`
otherwise.

The bounds are compared with exact `==`, so bounds that differ only by floating-point
rounding are not treated as equal.
"""
function isborder(x::Hyperrectangle, y::Hyperrectangle)
    lo_x, hi_x = low(x), high(x)
    lo_y, hi_y = low(y), high(y)
    return any(eachindex(lo_x)) do d
        lo_x[d] == lo_y[d] || hi_x[d] == hi_y[d]
    end
end
"""
    forward_layer(solver::HullSearch, L::Layer, input::Hyperrectangle)

Propagate the box `input` through layer `L`, one output row at a time.

For every row of `L.weights` a `Node` is built from that row, the matching entry of
`L.bias` and `L.activation`; `forward_node` then yields the center and radius of that
node's output interval. The per-node results are collected into a `Hyperrectangle`.
"""
function forward_layer(solver::HullSearch, L::Layer, input::Hyperrectangle)
    weights, bias, activation = L.weights, L.bias, L.activation
    n_out = size(weights, 1)
    centers = zeros(n_out)
    radii = zeros(n_out)
    for row in 1:n_out
        neuron = Node(weights[row, :], bias[row], activation)
        c, r = forward_node(solver, neuron, input)
        centers[row] = c
        radii[row] = r
    end
    return Hyperrectangle(centers, radii)
end
"""
    forward_node(solver::HullSearch, node::Node, input::Hyperrectangle)

Compute the output interval of a single node for the box `input`.

The pre-activation value at the box center is `node.w' * input.center + node.b`, and its
largest deviation over the box is `sum(abs.(node.w) .* input.radius)`. The activation
`node.act` is applied to the two ends of that pre-activation interval, which bounds the
output only if the activation is monotone.

# Returns
A tuple `(center, radius)` of the interval spanned by the two activated endpoints.
"""
function forward_node(solver::HullSearch, node::Node, input::Hyperrectangle)
    mid = node.w' * input.center + node.b
    spread = abs.(node.w) .* input.radius
    halfwidth = sum(spread)
    upper = node.act(mid + halfwidth)
    lower = node.act(mid - halfwidth)
    return ((upper + lower) / 2, (upper - lower) / 2)
end