-
-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Rootfinding in Callbacks #350
Conversation
src/callbacks.jl
Outdated
@@ -503,6 +503,10 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte | |||
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") | |||
end | |||
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might as well put this in the loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set the atol and rtol to 0
src/callbacks.jl
Outdated
@@ -503,6 +503,10 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte | |||
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") | |||
end | |||
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) | |||
sign_bottom_θ = sign(zero_func(bottom_θ)) | |||
while sign(zero_func(Θ)) != sign_bottom_θ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
going 1 by 1 can be really slow. It would be better to just put the tolerance to like 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by tolerance? Just discard all the callback tolerances and put atol = 0
in find_zero
call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's no good, still the same error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how many times is it looping?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just once.
Am I really right about this? Because the tests fail when I limit the loop. That means tests used to pass without the loop. It may also mean that our tests were not good enough. |
src/callbacks.jl
Outdated
@@ -435,7 +435,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) | |||
end | |||
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") | |||
end | |||
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100)) | |||
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and rtol?
It looks like when you need zero tolerances, Roots.jl sidesteps the algorithm choice: Let's set the 4 tolerances to zero and remove the algorithm choice and see what we get. I would be surprised if this causes a major performance difference, since most of the condition calls are likely due to checking if there is a root, not the few times that a root is found. However, this should make it a lot more robust and guarantee a 1 floating point number difference from the zero. |
I guess it worked. I will compare the ncondition and see if there is any noticable regression. And running OrdinaryDiffEq tests now. |
I don't know if we need the loop, since that should make it so prevfloat is always one to the left. Maybe have it only possibly prevfloat once, and otherwise error? Since we will want feedback on how well this actually works, and that shouldn't increase the cost in any noticeable way to do one more condition call. |
I just bumped a
On this branch:
Only the fourth call, bouncing ball example, becomes double in count. |
It's double because the bouncing ball example is such a simple ODE that it literally has an event every single step. We saw that in this paper: https://www.sciencedirect.com/science/article/abs/pii/S0965997818310251 I think it's safe to say that's an edge case, where if you literally have a root every single step then using the more robust method uses twice as many condition calls, where the condition is trivial 🤷♂ . We're probably fine. |
Okay. For the looping, I guess single prevfloat should be enough, otherwise error. Glad this error was caught, but didn't expected this error to show up in such simple problem. |
this is the beauty of having users who write test cases which break when floating point is exact. Being off by 1 floating point number is probably the reason for half of our event handling or adaptive time stepping tests 🤣. Let's make sure we get this test case added as a downstream test. |
Downstream tests are now failing. |
That means the rootfinding still isn't accurate enough to control a bouncing ball 🤣 I guess we should have some loop then. 10 times is good enough. I think that we should have our own rootfinding that works on finding a number less than the actual root. |
wait, how far off is it? The docs there literally say it should be "exact" 🤷♂ |
This reverts commit 8f1a03b.
After adding the loop, I added a |
The PSOS event test mostly gives 0 prevfloat_idx but many times it is a 4 digit integer. |
Wait, with PSOS the bisection halts very early? |
It should be "impossible" for it to be that far away with Bisection. That would be worth to build into an MWE. |
@kanav99 where did we leave this? |
The issue was that the result of root was not consistent in terms of sign. We need something which either gives us sign of the lower limit of exact Root or the upper limit. Roots.jl doesn't give me this consistency. I saw your mention and yeah, if we get an interval instead of exact value, that should solve our issue. |
@@ -555,10 +581,11 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter) | |||
iter = 1 | |||
while sign(zero_func(bottom_θ)) == sign_top && iter < 12 | |||
bottom_θ *= 5 | |||
iter += 1 | |||
end | |||
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we still need to do all of this, or if we just immediately throw an error if we see this.
I guess scope of the PR is complete. Some AD issues need to be addressed. I will open a seperate issue for that probably. |
Let's get some more AD tests before committing to it. That's a pretty core feature that we can't break. |
Ref: SciML/DifferentialEquations.jl#516 (comment)
Fixes SciML/DifferentialEquations.jl#516 , Fixes SciML/DifferentialEquations.jl#551 , Fixes bifurcationkit/BifurcationKit.jl#9