Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Rootfinding in Callbacks #350

Merged
merged 15 commits into from
Feb 15, 2020
Merged

Fix Rootfinding in Callbacks #350

merged 15 commits into from
Feb 15, 2020

Conversation

kanav99
Copy link
Contributor

@kanav99 kanav99 commented Oct 13, 2019

src/callbacks.jl Outdated
@@ -503,6 +503,10 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might as well put this in the loop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set the atol and rtol to 0

src/callbacks.jl Outdated
@@ -503,6 +503,10 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100))
sign_bottom_θ = sign(zero_func(bottom_θ))
while sign(zero_func(Θ)) != sign_bottom_θ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going 1 by 1 can be really slow. It would be better to just put the tolerance to like 0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by tolerance? Just discard all the callback tolerances and put atol = 0 in find_zero call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's no good, still the same error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how many times is it looping?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just once.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 14, 2019

In the first callback event, the ball goes from +ve to -ve, so essentially zero_func(θ) should be positive as we prevfloat the root. But actually it's negetive. A single prevfloat is not able to pull θ back to positive. Issue is over here. But it's still a mystery why it runs fine on debugging.

Am I really right about this? Because the tests fail when I limit the loop. That means tests used to pass without the loop. It may also mean that our tests were not good enough.

src/callbacks.jl Outdated
@@ -435,7 +435,7 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter)
end
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = callback.abstol/100))
Θ = prevfloat(find_zero(zero_func, (bottom_θ,top_Θ), Roots.AlefeldPotraShi(), atol = 0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and rtol?

@ChrisRackauckas
Copy link
Member

It looks like when you need zero tolerances, Roots.jl sidesteps the algorithm choice:

https://github.com/JuliaMath/Roots.jl/blob/2921a2dba2023de3f25681383009b24cde3370b2/src/bracketing.jl#L318-L324

https://github.com/JuliaMath/Roots.jl/blob/2cb8b737de06496ee45cd2b67b6eae1f57dd4532/src/find_zero.jl#L151-L192

Let's set the 4 tolerances to zero and remove the algorithm choice and see what we get. I would be surprised if this causes a major performance difference, since most of the condition calls are likely due to checking if there is a root, not the few times that a root is found. However, this should make it a lot more robust and guarantee a 1 floating point number difference from the zero.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

I guess it worked. I will compare the ncondition and see if there is any noticable regression. And running OrdinaryDiffEq tests now.
Also, do we need the loop now?

@ChrisRackauckas
Copy link
Member

I don't know if we need the loop, since that should make it so prevfloat is always one to the left. Maybe have it only possibly prevfloat once, and otherwise error? Since we will want feedback on how well this actually works, and that shouldn't increase the cost in any noticeable way to do one more condition call.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

I just bumped a println(sol.destats.ncondition) after every solve call here.
On master:

132939
177512
2227183
55409

On this branch:

138424
184492
2308394
118606

Only the fourth call, bouncing ball example, becomes double in count.

@ChrisRackauckas
Copy link
Member

It's double because the bouncing ball example is such a simple ODE that it literally has an event every single step. We saw that in this paper:

https://www.sciencedirect.com/science/article/abs/pii/S0965997818310251

I think it's safe to say that's an edge case, where if you literally have a root every single step then using the more robust method uses twice as many condition calls, where the condition is trivial 🤷‍♂ . We're probably fine.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

Okay. For the looping, I guess single prevfloat should be enough, otherwise error. Glad this error was caught, but didn't expected this error to show up in such simple problem.

@ChrisRackauckas
Copy link
Member

this is the beauty of having users who write test cases which break when floating point is exact. Being off by 1 floating point number is probably the reason for half of our event handling or adaptive time stepping tests 🤣.

Let's make sure we get this test case added as a downstream test.

@ChrisRackauckas
Copy link
Member

Downstream tests are now failing.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

That means the rootfinding still isn't accurate enough to control a bouncing ball 🤣 I guess we should have some loop then. 10 times is good enough. I think that we should have our own rootfinding that works on finding a number less than the actual root.

@ChrisRackauckas
Copy link
Member

wait, how far off is it? The docs there literally say it should be "exact" 🤷‍♂

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

After adding the loop, I added a @show prevfloat_idx and it is mostly 0 and just 1 once. Note that we do prevfloat once outside the loop too. Its actually pretty accurate but still it fails 😛.

@kanav99
Copy link
Contributor Author

kanav99 commented Oct 15, 2019

The PSOS event test mostly gives 0 prevfloat_idx but many times it is a 4 digit integer.

@ChrisRackauckas
Copy link
Member

Wait, with PSOS the bisection halts very early?

@ChrisRackauckas
Copy link
Member

It should be "impossible" for it to be that far away with Bisection. That would be worth to build into an MWE.

@ChrisRackauckas
Copy link
Member

@kanav99 where did we leave this?

@kanav99
Copy link
Contributor Author

kanav99 commented Feb 4, 2020

The issue was that the result of root was not consistent in terms of sign. We need something which either gives us sign of the lower limit of exact Root or the upper limit. Roots.jl doesn't give me this consistency. I saw your mention and yeah, if we get an interval instead of exact value, that should solve our issue.

@@ -555,10 +581,11 @@ function find_callback_time(integrator,callback::ContinuousCallback,counter)
iter = 1
while sign(zero_func(bottom_θ)) == sign_top && iter < 12
bottom_θ *= 5
iter += 1
end
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we still need to do all of this, or if we just immediately throw an error if we see this.

@kanav99
Copy link
Contributor Author

kanav99 commented Feb 14, 2020

I guess scope of the PR is complete. Some AD issues need to be addressed. I will open a seperate issue for that probably.
Please squash before merging, history of this PR is messed up 😛

@ChrisRackauckas
Copy link
Member

Let's get some more AD tests before committing to it. That's a pretty core feature that we can't break.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants