[WIP] Add Zygote repeat rrules #405
Will add missing rrules and frules in another PR
Great, thanks! What do you mean by kwarg shenanigans? Could you add the tests as well please?
It looks like you are passing optional arguments rather than kwargs actually, kwargs need
kwargs are getting passed all the way down to isapprox which doesn't seem right.
Oh, I see. Check the
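For readers following the optional-argument versus keyword-argument distinction raised above, a minimal Julia sketch may help. The names `scale_positional` and `scale_keyword` are hypothetical and only illustrate the calling convention; they are not part of ChainRules or of this PR.

```julia
# Hypothetical functions, for illustration only.

# `factor` is an optional positional argument: it is passed by position.
scale_positional(x, factor=2) = x * factor

# `factor` is a keyword argument: it is declared after `;` and passed by name.
scale_keyword(x; factor=2) = x * factor

a = scale_positional(3, 5)       # positional call
b = scale_keyword(3; factor=5)   # keyword call

# Passing the keyword argument by position does not match any method.
threw = try
    scale_keyword(3, 5)
    false
catch err
    err isa MethodError
end
```

A rule written with a default positional argument therefore never sees values that a caller supplies by name, which is one plausible reading of the mismatch discussed here.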
mzgubic
Thanks, looks good overall! I've added a few style changes and a couple of minor suggestions.
test/rulesets/Base/array.jl
Outdated
    @testset "repeat" begin
        test_rrule(repeat, randn(5), 3)
        test_rrule(repeat, randn(3,3), 2)
        test_rrule(repeat, randn(5,4,3); fkwargs=(inner=(2,2,1), outer=(1,1,3)))
Should we add a few more tests? E.g. both matrices and vectors, each with one and two repetition counts. And similarly a few edge cases for the inner/outer one?
Perhaps we could also test Float64[] (an empty vector) and fill(4.0) (a zero-dimensional array).
What do you want the result to be for the empty array? repeat just returns an empty array for that.
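The two edge cases mentioned in this thread are different kinds of array, which a short sketch using only Base functions can make concrete. The variable names are illustrative; the calls mirror the `fill(4.0)` test that appears later in this review.

```julia
zero_dim = fill(4.0)      # zero-dimensional array holding one element
empty_vec = Float64[]     # one-dimensional array holding no elements

zero_dim_info = (ndims(zero_dim), length(zero_dim))
empty_vec_info = (ndims(empty_vec), length(empty_vec))

# Repeating an empty vector still gives an empty result, so a pullback
# for it has nothing to accumulate.
repeated_empty = repeat(empty_vec, 3)

# Repeating the zero-dimensional array, as in the tests under review.
repeated_zero_dim = repeat(zero_dim, 3)
```

The empty case mainly checks that the reshape in the pullback tolerates a zero-length dimension, while the zero-dimensional case checks that the rule does not assume `ndims(x) >= 1`.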
Co-authored-by: Miha Zgubic <[email protected]>
test/rulesets/Base/array.jl
Outdated
    test_rrule(repeat, randn(5), 3)
    test_rrule(repeat, randn(5), 3, 3)
    test_rrule(repeat, randn(3, 3), 2)
    test_rrule(repeat, randn(5, 5), 2,5)
    test_rrule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
    test_rrule(repeat, fill(4.0), 3)
    test_frule(repeat, randn(5), 3)
    test_frule(repeat, randn(5), 3,3)
    test_frule(repeat, randn(3, 3), 2)
    test_frule(repeat, randn(3, 3), 2,5)
    test_frule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
    test_frule(repeat, fill(4.0), 3)
Suggested change:
    test_rrule(repeat, randn(5), 3)
    test_rrule(repeat, randn(5), 3, 3)
    test_rrule(repeat, randn(3, 3), 2)
    test_rrule(repeat, randn(5, 5), 2,5)
    test_rrule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
    test_rrule(repeat, fill(4.0), 3)
    @testset "frule" begin
        test_frule(repeat, fill(4.0), 3)
        test_frule(repeat, randn(5), 3)
        test_frule(repeat, randn(5), 3, 3)
        test_frule(repeat, randn(3, 3), 2)
        test_frule(repeat, randn(3, 3), 2, 5)
        test_frule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
    end
and similarly for the rrule to keep things tidy
        return repeat(x, m), repeat_pullback
    end

    function rrule(::typeof(repeat), x::AbstractVecOrMat, m::Integer, n::Integer=1)
Since things like repeat(rand(Int8,2,2,2),1,1,2,2) are allowed, it might be nice to treat the general case? Untested, but my attempt is:
    function rrule(::typeof(repeat), x::AbstractArray, scales::Integer...)
        function repeat_pullback_1(dy)
            size2ndims = ntuple(d -> isodd(d) ? size(x,1+d÷2) : get(scales,d÷2,1), 2*ndims(dy))
            sumdy = sum(reshape(dy, size2ndims); dims = ntuple(d -> 2d, ndims(dy)))
            return (NO_FIELDS, reshape(sumdy, size(x)), map(_->DoesNotExist(), scales)...)
        end
        return repeat(x, scales...), repeat_pullback_1
    end
What this won't handle is repeat(1:2, 3,2,1,0) but perhaps nobody will write that!
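To see why the interleaved reshape in the attempt above recovers the gradient, here is a standalone sketch of just the reshape-and-sum step. It assumes nothing from ChainRulesCore: `repeat_adjoint` is a hypothetical helper name, and it takes the size of `x` as an argument instead of closing over `x`. It checks the result against the adjoint identity rather than against the PR's own tests.

```julia
# Hypothetical helper: only the reshape-and-sum step of the pullback,
# written without ChainRulesCore so it runs on its own.
function repeat_adjoint(dy::AbstractArray, xsize::Tuple, scales::Integer...)
    n = ndims(dy)
    # Interleave (extent of x along dim k, repetition count along dim k),
    # padding with 1 where x or scales has no such dimension.
    interleaved = ntuple(d -> isodd(d) ? get(xsize, 1 + d ÷ 2, 1) : get(scales, d ÷ 2, 1), 2n)
    # The even dimensions index the copies, so summing over them adds up
    # the cotangent of every copy of each element.
    summed = sum(reshape(dy, interleaved); dims=ntuple(d -> 2d, n))
    return reshape(summed, xsize)
end

x = reshape(collect(1.0:6.0), 2, 3)
dy = reshape(collect(1.0:24.0), 4, 6)   # same size as repeat(x, 2, 2)
dx = repeat_adjoint(dy, size(x), 2, 2)

# Adjoint identity: <dy, repeat(x)> should equal <dx, x>.
lhs = sum(dy .* repeat(x, 2, 2))
rhs = sum(dx .* x)

# More repetition counts than array dimensions.
v = [1.0, 2.0]
dv = repeat_adjoint(ones(6, 2), size(v), 3, 2)
```

In the last call each element of `v` appears 3 × 2 = 6 times in the output, so a cotangent of all ones sums to 6 per element.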
    function repeat_pullback(Ȳ)
        return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length(x), :); dims=2); dims=2), DoesNotExist())
Should this close over length(x) rather than x (to reduce memory requirements)?
Suggested change:
    length_x = length(x)
    function repeat_pullback(Ȳ)
        return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length_x, :); dims=2); dims=2), DoesNotExist())
And likewise with the size calls in the rules below.
This is a genuine question; I'm not sure what best practice here is (cc @oxinabox).
I believe this would be manually doing one thing we want an OpaqueClosure to be able to do automatically.
Also, maybe this should be mentioned in the "Writing Good Rules" docs?
Yeah, I am cautious about doing that since it makes code uglier.
So I wouldn't block a PR over it (it can always be optimized later)
But it does save memory.
We should discuss this, with pros and cons, in "Writing Good Rules".
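The memory trade-off discussed in this thread can be observed directly. The sketch below is an illustration, not the PR's code: the constructor names are hypothetical, the ChainRules-specific return values are left out so it runs on its own, and `Base.summarysize` is used as a rough measure of what each closure keeps reachable.

```julia
# Hypothetical constructors for the two pullback variants discussed above.

# Variant 1: the closure refers to `x`, so it keeps the whole array alive.
function pullback_capturing_x(x::AbstractVector)
    return dy -> dropdims(sum(reshape(dy, length(x), :); dims=2); dims=2)
end

# Variant 2: only the integer length is captured.
function pullback_capturing_length(x::AbstractVector)
    length_x = length(x)
    return dy -> dropdims(sum(reshape(dy, length_x, :); dims=2); dims=2)
end

x = randn(1_000)
pb_x = pullback_capturing_x(x)
pb_len = pullback_capturing_length(x)

dy = ones(3_000)              # cotangent of repeat(x, 3)
same_result = pb_x(dy) == pb_len(dy)

# Base.summarysize counts the memory reachable from each closure.
bytes_x = Base.summarysize(pb_x)
bytes_len = Base.summarysize(pb_len)
```

Both variants compute the same gradient; the difference is only in how long the input array has to stay alive after the forward pass.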
Closed by #460
These are as direct a translation as you can get from Zygote. The last rule, however, is not working because of kwargs shenanigans. Not sure how this should work.
Edit: Works towards #383
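For reference, the keyword form that the last rule targets behaves as follows in Base. This is a small sketch of `repeat` itself, not of the rule; the variable names are illustrative.

```julia
x = [1, 2]

inner_only = repeat(x; inner=2)        # each element repeated in place: [1, 1, 2, 2]
outer_only = repeat(x; outer=2)        # the whole array tiled: [1, 2, 1, 2]
both = repeat(x; inner=2, outer=2)     # inner repetition first, then tiling
```

A pullback for this form has to sum over both kinds of copies, which is why it is handled separately from the positional methods.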