Add repeat rules from Zygote #460
Codecov Report
@@ Coverage Diff @@
## master #460 +/- ##
==========================================
+ Coverage 98.48% 98.58% +0.10%
==========================================
Files 21 21
Lines 2175 2198 +23
==========================================
+ Hits 2142 2167 +25
+ Misses 33 31 -2
Continue to review full report at Codecov.
Looks really good! Some minor nitpicking follows:
I suspect `test_rrule(repeat, rand(4, 5), 2)` would fail (because there is an optional argument in the rule and the new ChainRulesTestUtils version checks that the number of tangents matches the number of arguments).
There are also a few tests in
https://github.com/FluxML/Zygote.jl/blob/1082ebd3aced63b99c4b6c2956a122ce6a37f97d/test/gradcheck.jl#L232-L239
but from a quick look it seems we've covered them?
In #383 (comment) there's a complaint that the rule taking I think you can do the general case via broadcasting, I wrote out one implementation here https://gist.github.com/mcabbott/80ac43cca3bee8f57809155a5240519f which you should feel free to steal. It's not very carefully checked, though, so it might be worth inventing some evil test cases.
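For readers without the gist open, here is a rough sketch of how a general `inner`/`outer` pullback could avoid a per-element loop. It is not the gist's broadcasting implementation; it uses a reshape-and-sum formulation instead, and the name `repeat_pullback` is made up for illustration. It assumes `inner` and `outer` each have exactly `ndims(x)` entries and that `x` has at least one dimension.

```julia
# Sketch only: `repeat_pullback` is a hypothetical helper, not ChainRules or Zygote API.
# Assumes `inner` and `outer` each have exactly ndims(x) entries and ndims(x) >= 1.
function repeat_pullback(dy::AbstractArray, x::AbstractArray;
                         inner=ntuple(_ -> 1, ndims(x)),
                         outer=ntuple(_ -> 1, ndims(x)))
    N = ndims(x)
    # View each axis d of dy as three axes: (inner[d], size(x, d), outer[d]).
    split_size = Tuple(Iterators.flatten((inner[d], size(x, d), outer[d]) for d in 1:N))
    # Axes 3d-2 (inner copies) and 3d (outer copies) are summed away.
    copy_axes = Tuple(Iterators.flatten((3d - 2, 3d) for d in 1:N))
    dy_split = reshape(dy, split_size)
    return reshape(sum(dy_split; dims=copy_axes), size(x))
end
```

Calling `repeat_pullback(dy, x; inner=(2, 1), outer=(1, 3))` for a 2×3 `x` should accumulate into each entry of `x` the cotangents of all six of its copies. Cases where `inner` or `outer` is longer than `ndims(x)` are deliberately left out and would be good candidates for the evil test cases.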
Yeah I hit this when I was working with it. I should have included a comment on this. What's the resolution here? If I'm reading this correctly this is a general problem: any function with omitted positional arguments with defaults will fail under this test. I don't think I see any other examples of positional defaults being set within the What is the intended convention here: to return tangents for all specified arguments or all possible arguments (i.e. Alternatively, we can write these out explicitly, dropping the defaults. That leads to a bit of duplication, but it's fine for these small functions.
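To make the "write these out explicitly" option concrete, here is a minimal sketch assuming ChainRulesCore 1.x names (`rrule`, `NoTangent`). The function `tile` is a hypothetical stand-in with a defaulted positional argument, not anything in this PR. Each arity gets its own rule, so every pullback returns one tangent for the function plus one per argument actually passed.

```julia
using ChainRulesCore

# Hypothetical function with a defaulted positional argument.
tile(xs::AbstractVector, m::Integer=2) = repeat(xs, m)

# Two-argument rule: three tangents (function, xs, m).
function ChainRulesCore.rrule(::typeof(tile), xs::AbstractVector, m::Integer)
    n = length(xs)
    function tile_pullback(dy)
        # Column k of the n×m reshape holds the cotangent of the k-th copy of xs.
        dxs = vec(sum(reshape(dy, n, m); dims=2))
        return NoTangent(), dxs, NoTangent()
    end
    return tile(xs, m), tile_pullback
end

# One-argument rule: reuse the rule above, but return only two tangents.
function ChainRulesCore.rrule(::typeof(tile), xs::AbstractVector)
    y, pullback = rrule(tile, xs, 2)
    return y, dy -> pullback(dy)[1:2]  # drop the tangent for the defaulted `m`
end
```

The duplication is small here because the one-argument rule just forwards to the two-argument one; whether that matches the convention ChainRulesTestUtils expects is the open question in this thread.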
I only know of a single case, solved here:
@mcabbott as far as I understand it is also slow on Zygote, so we can merge this and the performance will be the same? And then follow up with a better implementation later? Would be good to mention the status on the main issue
Is there a different main issue besides the one I linked? I just think it might be less hassle to test & merge one PR than two. Zygote has a lot of cobbled-together quick-fix code, some of it (like this) accidentally two orders of magnitude slower than it should be. Re optional arguments, that seems like a bug, is there an issue tracking it? Slurped arguments do seem to work (as in
Thanks for your input both. I've added a couple of commits to deal with the comments:
The zero array ones are broken for now though, in two ways: For
For
Re zero-arrays, it's possible there are issues with whether test_rrule can handle these. It also doesn't seem so obvious what the rule should return. In lots of other contexts, zero-arrays tend to go away, perhaps they should here too?
Yeah With a local fix for this, it looks like these two This is the 'fix': |
# These dispatch but rrule needs to be fixed to zero-arrays
# These dispatch but probably needs
# https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
Are these fixed on your FiniteDifferences branch with the rand_tangent fix?
I believe so. We can uncomment these in a follow-up for the general `rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)` method, along with a FiniteDifferences bump.
Adds rrules for repeat. These are essentially copied from Zygote as suggested in #383 and hence taken from here
Doesn't include missing rules. At least from what I can tell, this is:
One change from the original code in Zygote: the pullback closes over the size of `xs` rather than `xs` itself. Also adds an `unthunk`.
Note, I saw #405 after looking at this (that caught the above), but not sure if that is being actively worked on.
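As an illustration of that change, here is a minimal sketch, assuming ChainRulesCore 1.x names (`NoTangent`, `unthunk`, `@thunk`). `repeat_rrule_sketch` is a hypothetical standalone function written for this example, not the rule added in this PR. The pullback captures only integers and a size tuple, so it does not keep `xs` alive, and it calls `unthunk` on the incoming cotangent before reshaping it.

```julia
using ChainRulesCore

# Hypothetical stand-in for a rule on repeat(xs, m, n); not the PR's actual code.
function repeat_rrule_sketch(xs::AbstractVecOrMat, m::Integer, n::Integer)
    d1, d2 = size(xs, 1), size(xs, 2)
    sz = size(xs)  # only sizes are captured below, never `xs` itself
    function repeat_pullback(ȳ)
        # The cotangent may arrive as a thunk, so materialise it first.
        ȳ4 = reshape(unthunk(ȳ), d1, m, d2, n)
        # Axes 2 and 4 index the copies along rows and columns; sum them away.
        x̄ = reshape(sum(ȳ4; dims=(2, 4)), sz)
        return NoTangent(), x̄, NoTangent(), NoTangent()
    end
    return repeat(xs, m, n), repeat_pullback
end
```

For example, `y, pb = repeat_rrule_sketch([1.0 2.0; 3.0 4.0], 2, 3)` followed by `pb(@thunk(ones(size(y))))` exercises the `unthunk` path.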