Skip to content

Commit 651a545

Browse files
Merge pull request #1034 from JuliaDiffEq/beefy_ad_tests
beef up the AD tests for ContinuousCallback
2 parents e2531a9 + dc926e3 commit 651a545

File tree

1 file changed

+81
-14
lines changed

1 file changed

+81
-14
lines changed

test/interface/ad_tests.jl

+81-14
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,100 @@ function f(du,u,p,t)
66
du[2] = p[2]
77
end
88

9-
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(println("Stopped.");integrator.p[2]=zero(integrator.p[2])))
10-
function test_f(p)
11-
prob = ODEProblem(f,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
12-
integrator = init(prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb)
13-
step!(integrator)
14-
solve!(integrator).u[end]
9+
for x in 0:0.001:5
10+
called = false
11+
function test_f(p)
12+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
13+
prob = ODEProblem(f,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
14+
integrator = init(prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb)
15+
step!(integrator)
16+
solve!(integrator).u[end]
17+
end
18+
p = [2.0, x]
19+
called = false
20+
findiff = Calculus.finite_difference_jacobian(test_f,p)
21+
@test called
22+
called = false
23+
fordiff = ForwardDiff.jacobian(test_f,p)
24+
@test called
25+
@test findiff fordiff
1526
end
16-
p = [2.0, 1.0]
17-
findiff = Calculus.finite_difference_jacobian(test_f,p)
18-
fordiff = ForwardDiff.jacobian(test_f,p)
19-
@test findiff fordiff
2027

2128
function f2(du,u,p,t)
22-
du[1] = -u[1]
29+
du[1] = -u[2]
2330
du[2] = p[2]
2431
end
2532

33+
for x in 2.1:0.001:5
34+
called = false
35+
function test_f2(p)
36+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
37+
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
38+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
39+
step!(integrator)
40+
solve!(integrator).u[end]
41+
end
42+
p = [2.0, x]
43+
findiff = Calculus.finite_difference_jacobian(test_f2,p)
44+
@test called
45+
called = false
46+
fordiff = ForwardDiff.jacobian(test_f2,p)
47+
@test called
48+
@test findiff fordiff
49+
end
50+
51+
#=
52+
#x = 2.0 is an interesting case
53+
54+
x = 2.0
55+
2656
function test_f2(p)
57+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(@show(x,integrator.t);called=true;integrator.p[2]=zero(integrator.p[2])))
2758
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
28-
integrator = init(prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb)
59+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
2960
step!(integrator)
3061
solve!(integrator).u[end]
3162
end
32-
p = [2.0, 1.0]
63+
64+
p = [2.0, x]
3365
findiff = Calculus.finite_difference_jacobian(test_f2,p)
66+
@test called
67+
called = false
3468
fordiff = ForwardDiff.jacobian(test_f2,p)
35-
@test findiff fordiff
69+
@test called
70+
71+
# At that value, it shouldn't be called, but a small perturbation will make it called, so finite difference is wrong!
72+
=#
73+
74+
for x in 1.0:0.001:2.5
75+
function lotka_volterra(du,u,p,t)
76+
x, y = u
77+
α, β, δ, γ = p
78+
du[1] = dx = α*x - β*x*y
79+
du[2] = dy = -δ*y + γ*x*y
80+
end
81+
u0 = [1.0,1.0]
82+
tspan = (0.0,10.0)
83+
p = [x,1.0,3.0,1.0]
84+
prob = ODEProblem(lotka_volterra,u0,tspan,p)
85+
sol = solve(prob,Tsit5())
86+
87+
called=false
88+
function test_lotka(p)
89+
cb = ContinuousCallback((u,t,i) -> u[1]-2.5, (integrator)->(called=true;integrator.p[4]=1.5))
90+
prob = ODEProblem(lotka_volterra,eltype(p).([1.0,1.0]),eltype(p).((0.0,10.0)),copy(p))
91+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
92+
step!(integrator)
93+
solve!(integrator).u[end]
94+
end
95+
96+
findiff = Calculus.finite_difference_jacobian(test_lotka,p)
97+
@test called
98+
called = false
99+
fordiff = ForwardDiff.jacobian(test_lotka,p)
100+
@test called
101+
@test findiff fordiff
102+
end
36103

37104
# Gradients and Hessians
38105

0 commit comments

Comments
 (0)