From 605dac4094651e5f560dd6621113560a906db70e Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Mon, 14 Oct 2019 01:12:57 +0530 Subject: [PATCH 01/14] recursively run prevfloat in rootfind --- src/callbacks.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/callbacks.jl b/src/callbacks.jl index fed374f90..cf718c873 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -503,6 +503,10 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) + sign_bottom_θ = sign(zero_func(bottom_θ)) + while sign(zero_func(Θ)) != sign_bottom_θ + Θ = prevfloat(Θ) + end if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From d7e766da58173fc32cd06daee76bff6ef38df384 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Mon, 14 Oct 2019 22:02:19 +0530 Subject: [PATCH 02/14] limiting loop in rootfind --- src/callbacks.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index cf718c873..9a9f77273 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -431,10 +431,18 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) iter = 1 while sign(zero_func(bottom_θ)) == sign_top && iter < 12 bottom_θ *= 5 + iter += 1 end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) + sign_bottom_θ = sign(zero_func(bottom_θ)) + prevfloat_idx = 0 + while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 + Θ = prevfloat(Θ) + prevfloat_idx += 1 + end + prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -504,9 +512,12 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) sign_bottom_θ = sign(zero_func(bottom_θ)) - while sign(zero_func(Θ)) != sign_bottom_θ + prevfloat_idx = 0 + while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 Θ = prevfloat(Θ) + prevfloat_idx += 1 end + prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From 2fdc0e0a55f1678fdff8d26e17edd79af52bd361 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 14:45:40 +0530 Subject: [PATCH 03/14] set atol to 0 --- src/callbacks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 9a9f77273..d3790c48b 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -435,7 +435,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 @@ -510,7 +510,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 From de3bb6ecd0cd6a64f720b6b4b98dc04b15855bfa Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 15:18:51 +0530 Subject: [PATCH 04/14] set rtol to 0 --- src/callbacks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index d3790c48b..b44407a4b 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -435,7 +435,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0, rtol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 @@ -510,7 +510,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0, rtol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 From da47cfab043beb14687d180902c8c17b1c8c5e9e Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 19:11:34 +0530 Subject: [PATCH 05/14] set rootfinding to best accuracy --- src/callbacks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index b44407a4b..9ee51b9b2 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -435,7 +435,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0, rtol = 0)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 @@ -510,7 +510,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0, rtol = 0)) + Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) sign_bottom_θ = sign(zero_func(bottom_θ)) prevfloat_idx = 0 while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 From 8f1a03b6f07ab23bbe6cc5f1cdbefbdcef2b7e12 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 20:34:14 +0530 Subject: [PATCH 06/14] remove prevfloat loop --- src/callbacks.jl | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 9ee51b9b2..4653dd2e5 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -436,13 +436,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign_bottom_θ = sign(zero_func(bottom_θ)) - prevfloat_idx = 0 - while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 - Θ = prevfloat(Θ) - prevfloat_idx += 1 - end - prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") + sign(zero_func(Θ)) != sign(zero_func(bottom_θ)) && error("Rootfind was inaccurate. Please report the error.") integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -511,13 +505,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign_bottom_θ = sign(zero_func(bottom_θ)) - prevfloat_idx = 0 - while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 - Θ = prevfloat(Θ) - prevfloat_idx += 1 - end - prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") + sign(zero_func(Θ)) != sign(zero_func(bottom_θ)) && error("Rootfind was inaccurate. Please report the error.") if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From 08f0d44c3e24bed2a8b54dbc1853f28f96252dee Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 20:49:40 +0530 Subject: [PATCH 07/14] add test for prevfloat working --- test/downstream/event_detection_tests.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/downstream/event_detection_tests.jl b/test/downstream/event_detection_tests.jl index 2984e8835..28177853d 100644 --- a/test/downstream/event_detection_tests.jl +++ b/test/downstream/event_detection_tests.jl @@ -46,6 +46,7 @@ function condition(u,t,integrator) # Event when event_f(u,t) == 0 u[1] end function affect!(integrator) + @test integrator.u[1] >= 0 integrator.u[2] = -integrator.u[2] end cb2 = ContinuousCallback(condition,affect!) @@ -55,3 +56,24 @@ p = 9.8 prob = ODEProblem(f,u0,tspan,p) sol = solve(prob,Tsit5(),callback=cb2) @test minimum(sol') > -40 + +function vcondition!(out,u,t,integrator) + out[1] = u[1] + out[2] = u[2] +end + +function vaffect!(integrator, event_idx) + @test integrator.u[1] >= 0.0 + if event_idx == 1 + integrator.u[2] = -integrator.u[2] + else + integrator.p = 0.0 + end +end + +u0 = [50.0,0.0] +tspan = (0.0,15.0) +p = 9.8 +prob = ODEProblem(f,u0,tspan,p) +Vcb = VectorContinuousCallback(vcondition!,vaffect!, 2 , save_positions=(true,true)) +sol = solve(prob,Tsit5(), callback=Vcb) From 6033b8656a6d3eee5d19c5f5c0e9c8703c917087 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Tue, 15 Oct 2019 22:14:11 +0530 Subject: [PATCH 08/14] Revert "remove prevfloat loop" This reverts commit 8f1a03b6f07ab23bbe6cc5f1cdbefbdcef2b7e12. --- src/callbacks.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 4653dd2e5..9ee51b9b2 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -436,7 +436,13 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign(zero_func(Θ)) != sign(zero_func(bottom_θ)) && error("Rootfind was inaccurate. Please report the error.") + sign_bottom_θ = sign(zero_func(bottom_θ)) + prevfloat_idx = 0 + while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 + Θ = prevfloat(Θ) + prevfloat_idx += 1 + end + prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -505,7 +511,13 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign(zero_func(Θ)) != sign(zero_func(bottom_θ)) && error("Rootfind was inaccurate. Please report the error.") + sign_bottom_θ = sign(zero_func(bottom_θ)) + prevfloat_idx = 0 + while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 + Θ = prevfloat(Θ) + prevfloat_idx += 1 + end + prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From 990416aef9bc52011f124428a95202626179e768 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Thu, 13 Feb 2020 23:24:51 +0530 Subject: [PATCH 09/14] try fixing rootfinding --- src/callbacks.jl | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 2126e4dc9..4c35e4c48 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -522,6 +522,32 @@ function findall_events(affect!,affect_neg!,prev_sign,next_sign) findall(x-> ((prev_sign[x] < 0 && affect! !== nothing) || (prev_sign[x] > 0 && affect_neg! !== nothing)) && prev_sign[x]*next_sign[x]<=0, keys(prev_sign)) end +function find_zero_bracket(fs, x0; kwargs...) + + x = Roots.adjust_bracket(x0) + T = eltype(x[1]) + F = Roots.callable_function(fs) + method = Roots.Bisection() + state = Roots.init_state(method, F, x) + options = Roots.init_options(method, state; kwargs...) + + # check if tolerances are exactly 0 + # iszero_tol = iszero(options.xabstol) && iszero(options.xreltol) && iszero(options.abstol) && iszero(options.reltol) + + # if iszero_tol + # if T <: FloatNN + # return Roots.find_zero(F, x, Roots.BisectionExact(); kwargs...) + # else + # return find_zero(F, x, Roots.A42(); kwargs...) + # end + # end + + find_zero(method, F, options, state, Roots.NullTracks()) + + state.xn0, state.xn1 + +end + function find_callback_time(integrator,callback::ContinuousCallback,counter) event_occurred,interp_index,Θs,prev_sign,prev_sign_index,event_idx = determine_event_occurance(integrator,callback,counter) if event_occurred @@ -559,14 +585,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign_bottom_θ = sign(zero_func(bottom_θ)) - prevfloat_idx = 0 - while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 - Θ = prevfloat(Θ) - prevfloat_idx += 1 - end - prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = callback.abstol/100) integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -634,14 +653,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0)) - sign_bottom_θ = sign(zero_func(bottom_θ)) - prevfloat_idx = 0 - while sign(zero_func(Θ)) != sign_bottom_θ && prevfloat_idx < 10 - Θ = prevfloat(Θ) - prevfloat_idx += 1 - end - prevfloat_idx == 10 && error("Rootfind was inaccurate. Please report the error.") + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = callback.abstol/100) if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From ffde953b11c3c265bcf19b47278f7e1f00dd370c Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Thu, 13 Feb 2020 23:48:02 +0530 Subject: [PATCH 10/14] change tolerance to 0 --- src/callbacks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 4c35e4c48..2138def2f 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -585,7 +585,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = callback.abstol/100) + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0) integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -653,7 +653,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = callback.abstol/100) + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0) if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From 6535e9e923b7a585dde7fbc839753412d0ac2182 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Fri, 14 Feb 2020 01:06:36 +0530 Subject: [PATCH 11/14] Cleanup find_zero_bracket --- src/callbacks.jl | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 2138def2f..5f40e5794 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -522,30 +522,18 @@ function findall_events(affect!,affect_neg!,prev_sign,next_sign) findall(x-> ((prev_sign[x] < 0 && affect! !== nothing) || (prev_sign[x] > 0 && affect_neg! !== nothing)) && prev_sign[x]*next_sign[x]<=0, keys(prev_sign)) end +# Ugly implemetation of find_zero that returns the final bracket instead of upper limit +# until this gets included in Roots.jl function find_zero_bracket(fs, x0; kwargs...) x = Roots.adjust_bracket(x0) - T = eltype(x[1]) F = Roots.callable_function(fs) - method = Roots.Bisection() state = Roots.init_state(method, F, x) options = Roots.init_options(method, state; kwargs...) - # check if tolerances are exactly 0 - # iszero_tol = iszero(options.xabstol) && iszero(options.xreltol) && iszero(options.abstol) && iszero(options.reltol) - - # if iszero_tol - # if T <: FloatNN - # return Roots.find_zero(F, x, Roots.BisectionExact(); kwargs...) - # else - # return find_zero(F, x, Roots.A42(); kwargs...) - # end - # end - - find_zero(method, F, options, state, Roots.NullTracks()) + find_zero(Roots.Bisection(), F, options, state, Roots.NullTracks()) state.xn0, state.xn1 - end function find_callback_time(integrator,callback::ContinuousCallback,counter) From 7800631f9c29ac82688d9e2d7274b0e1b21a83c5 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Fri, 14 Feb 2020 01:53:20 +0530 Subject: [PATCH 12/14] fix typo and add example --- src/callbacks.jl | 3 +- test/downstream/event_detection_tests.jl | 50 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 5f40e5794..a20b4b714 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -528,10 +528,11 @@ function find_zero_bracket(fs, x0; kwargs...) x = Roots.adjust_bracket(x0) F = Roots.callable_function(fs) + method = Roots.Bisection() state = Roots.init_state(method, F, x) options = Roots.init_options(method, state; kwargs...) - find_zero(Roots.Bisection(), F, options, state, Roots.NullTracks()) + find_zero(method, F, options, state, Roots.NullTracks()) state.xn0, state.xn1 end diff --git a/test/downstream/event_detection_tests.jl b/test/downstream/event_detection_tests.jl index 28177853d..edf5029f2 100644 --- a/test/downstream/event_detection_tests.jl +++ b/test/downstream/event_detection_tests.jl @@ -77,3 +77,53 @@ p = 9.8 prob = ODEProblem(f,u0,tspan,p) Vcb = VectorContinuousCallback(vcondition!,vaffect!, 2 , save_positions=(true,true)) sol = solve(prob,Tsit5(), callback=Vcb) + +f = function (du,u,p,t) + du[1] = u[2] + du[2] = -9.81 +end + +function condition(u,t,integrator) # Event when event_f(u,t) == 0 + u[1] # Event when height crosses from positive to negative +end + +function affect_neg(integrator) + @test integrator.u[1] >= 0.0 + integrator.u[2] = -0.8*integrator.u[2] +end + +cb = ContinuousCallback(condition,nothing,affect_neg! = affect_neg) + +u0 = [1.0,0.0] +tspan = (0.0, 3.0) +prob = ODEProblem(f,u0,tspan) +sol = solve(prob,Tsit5(),saveat=0.01,callback=cb) + +f! = function (du,u,p,t) + du[1] = u[2] + du[2] = -p +end + +function condition!(out,u,t,integrator) + out[1] = u[1] + out[2] = u[2] +end + +function affect!(integrator, event_idx) + if event_idx == 1 + integrator.u[2] = -integrator.u[2] + else + integrator.p = 0.0 + end +end + +u0 = [50.0,0.0] +tspan = (0.0,15.0) + +begin + p = 9.8 + prob = ODEProblem(f!,u0,tspan,p) + Vcb = VectorContinuousCallback(condition!,affect!, 2 , save_positions=(true,true)) + sol = solve(prob,Tsit5(), callback=Vcb) +end + From 8feac3f37680a201db7fd2bc322b6dcd879f4cec Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Fri, 14 Feb 2020 02:39:21 +0530 Subject: [PATCH 13/14] Change algorithm for Rootfinding --- src/callbacks.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index a20b4b714..2e9fc9a1b 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -524,11 +524,10 @@ end # Ugly implemetation of find_zero that returns the final bracket instead of upper limit # until this gets included in Roots.jl -function find_zero_bracket(fs, x0; kwargs...) +function find_zero_bracket(fs, x0, method; kwargs...) x = Roots.adjust_bracket(x0) F = Roots.callable_function(fs) - method = Roots.Bisection() state = Roots.init_state(method, F, x) options = Roots.init_options(method, state; kwargs...) @@ -574,7 +573,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0) + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(); atol = 0, rtol = 0, xatol = 0, xrtol = 0) integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) @@ -642,7 +641,7 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), atol = 0, rtol = 0, xatol = 0, xrtol = 0) + Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(); atol = 0, rtol = 0, xatol = 0, xrtol = 0) if Θ < minΘ integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) end From 8da011fbef51225d5d9b9aceac2817859b35e677 Mon Sep 17 00:00:00 2001 From: Kanav Gupta Date: Sat, 15 Feb 2020 02:19:01 +0530 Subject: [PATCH 14/14] hopefully fix the rootfinding --- src/callbacks.jl | 46 ++++++++++++++++------------------------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 2e9fc9a1b..614ebdef4 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -522,20 +522,6 @@ function findall_events(affect!,affect_neg!,prev_sign,next_sign) findall(x-> ((prev_sign[x] < 0 && affect! !== nothing) || (prev_sign[x] > 0 && affect_neg! !== nothing)) && prev_sign[x]*next_sign[x]<=0, keys(prev_sign)) end -# Ugly implemetation of find_zero that returns the final bracket instead of upper limit -# until this gets included in Roots.jl -function find_zero_bracket(fs, x0, method; kwargs...) - - x = Roots.adjust_bracket(x0) - F = Roots.callable_function(fs) - state = Roots.init_state(method, F, x) - options = Roots.init_options(method, state; kwargs...) - - find_zero(method, F, options, state, Roots.NullTracks()) - - state.xn0, state.xn1 -end - function find_callback_time(integrator,callback::ContinuousCallback,counter) event_occurred,interp_index,Θs,prev_sign,prev_sign_index,event_idx = determine_event_occurance(integrator,callback,counter) if event_occurred @@ -550,31 +536,31 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) bottom_θ = typeof(integrator.t)(0) end if callback.rootfind && !isdiscrete(integrator.alg) - zero_func = (Θ) -> begin - abst = integrator.tprev+integrator.dt*Θ + zero_func = (Θnot) -> begin + abst = integrator.tprev+integrator.dt*(1-Θnot) return get_condition(integrator, callback, abst) end - if zero_func(top_Θ) == 0 + if zero_func(1-top_Θ) == 0 Θ = top_Θ else if integrator.event_last_time == counter && - abs(zero_func(bottom_θ)) < 100abs(integrator.last_event_error) && + abs(zero_func(1-bottom_θ)) < 100abs(integrator.last_event_error) && prev_sign_index == 1 # Determined that there is an event by derivative # But floating point error may make the end point negative - sign_top = sign(zero_func(top_Θ)) + sign_top = sign(zero_func(1-top_Θ)) bottom_θ += 2eps(typeof(bottom_θ)) iter = 1 - while sign(zero_func(bottom_θ)) == sign_top && iter < 12 + while sign(zero_func(1-bottom_θ)) == sign_top && iter < 12 bottom_θ *= 5 iter += 1 end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(); atol = 0, rtol = 0, xatol = 0, xrtol = 0) - integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) + Θ = 1 - find_zero(zero_func, (1-top_Θ, 1-bottom_θ), Roots.Bisection()) + integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(1-Θ),integrator.t+integrator.dt*Θ) end #Θ = prevfloat(...) # prevfloat guerentees that the new time is either 1 floating point @@ -617,33 +603,33 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte minΘ = nextfloat(top_Θ) min_event_idx = -1 for idx in event_idx - zero_func = (Θ) -> begin - abst = integrator.tprev+integrator.dt*Θ + zero_func = (Θnot) -> begin + abst = integrator.tprev+integrator.dt*(1-Θnot) return ArrayInterface.allowed_getindex(get_condition(integrator, callback, abst),idx) end - if zero_func(top_Θ) == 0 + if zero_func(1-top_Θ) == 0 Θ = top_Θ else if integrator.event_last_time == counter && (callback isa VectorContinuousCallback ? integrator.vector_event_last_time == event_idx : true) && - abs(zero_func(bottom_θ)) < 100abs(integrator.last_event_error) && + abs(zero_func(1-bottom_θ)) < 100abs(integrator.last_event_error) && prev_sign_index == 1 # Determined that there is an event by derivative # But floating point error may make the end point negative - sign_top = sign(zero_func(top_Θ)) + sign_top = sign(zero_func(1-top_Θ)) bottom_θ += 2eps(typeof(bottom_θ)) iter = 1 - while sign(zero_func(bottom_θ)) == sign_top && iter < 12 + while sign(zero_func(1-bottom_θ)) == sign_top && iter < 12 bottom_θ *= 5 iter += 1 end iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") end - Θ, _ = find_zero_bracket(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(); atol = 0, rtol = 0, xatol = 0, xrtol = 0) + Θ = 1 - find_zero(zero_func, (1-top_Θ,1-bottom_θ), Roots.Bisection()) if Θ < minΘ - integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ),integrator.t+integrator.dt*Θ) + integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(1-Θ),integrator.t+integrator.dt*Θ) end end if Θ < minΘ