Skip to content

Conversation

@AlexRobson
Copy link
Member

@AlexRobson AlexRobson commented Jun 29, 2021

Adds rrules for repeat. These are essentially copied from Zygote as suggested in #383 and hence taken from here

Doesn't include missing rules. At least from what I can tell, this is:

function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)

One change from the original code in Zygote: closes instead of the value of the size of xs rather than xs itself. Also adds an unthunk.

Note, I saw #405 after looking at this (that caught the above), but not sure if that is being actively worked on.

@codecov-commenter
Copy link

codecov-commenter commented Jun 30, 2021

Codecov Report

Merging #460 (e18cec7) into master (bbb88f7) will increase coverage by 0.10%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #460      +/-   ##
==========================================
+ Coverage   98.48%   98.58%   +0.10%     
==========================================
  Files          21       21              
  Lines        2175     2198      +23     
==========================================
+ Hits         2142     2167      +25     
+ Misses         33       31       -2     
Impacted Files Coverage Δ
src/rulesets/Base/array.jl 100.00% <100.00%> (ø)
src/rulesets/LinearAlgebra/factorization.jl 96.78% <0.00%> (-0.30%) ⬇️
src/rulesets/Base/evalpoly.jl 100.00% <0.00%> (+2.15%) ⬆️
src/rulesets/Base/nondiff.jl 100.00% <0.00%> (+33.33%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update bbb88f7...e18cec7. Read the comment docs.

@AlexRobson AlexRobson changed the title WIP: Add repeat rules from Zygote Add repeat rules from Zygote Jun 30, 2021
Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really good! Some minor nitpicking follows:

I suspect

test_rrule(repeat, rand(4, 5), 2)

would fail (because there is an optional argument in the rule and the new ChainRulesTestUtils version checks that the number of tangents matches the number of arguments)

There are also a few tests in
https://github.com/FluxML/Zygote.jl/blob/1082ebd3aced63b99c4b6c2956a122ce6a37f97d/test/gradcheck.jl#L232-L239
but from a quick look it seems we've covered them?

@mcabbott
Copy link
Member

In #383 (comment) there's a complaint that the rule taking inner, outer keywords is pretty slow. And the fast rule is only written for a few specific ndims.

I think you can do the general case via broadcasting, I wrote out one implementation here https://gist.github.com/mcabbott/80ac43cca3bee8f57809155a5240519f which you should feel free to steal. But not very carefully checked, it might be worth inventing some evil test cases.

@AlexRobson
Copy link
Member Author

AlexRobson commented Jun 30, 2021

I suspect

test_rrule(repeat, rand(4, 5), 2)

would fail (because there is an optional argument in the rule and the new ChainRulesTestUtils version checks that the number of tangents matches the number of arguments)

Yeah I hit this when I was working with it. I should have included a commeont on this. What's the resolution here - If i'm reading this correctly this is a general problem - any function with omitted positional arguments with defaults will fail under this test. I don't think I see any other examples of positional defaults being set within the rrules, so perhaps it has just not come up?

What is the intended convention here: to return tangents for all specified arguments or all possible arguments (i.e. for foo(x, y = 5) = x + 5), should this return 1 or 2 tangents (ignoring ∂self) if invoked by e.g. foo(3), given it'll always end up dispatching to the full expression.

Alternatively, we can write out these explicitely, dropping the defaults. Leads to a bit of duplication, but fine for these small functions.

function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer)...
function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer)...
    ...
    return repeat(xs, m, 1), ...
end

@mzgubic
Copy link
Member

mzgubic commented Jun 30, 2021

I only know of a single case, solved here:
https://github.com/JuliaDiff/ChainRules.jl/pull/436/files

@mcabbott as far as I understand it is also slow on Zygote, so we can merge this and the performance will be the same? And then follow up with a better implementation later? Would be good to mention the status on the main issue

@mcabbott
Copy link
Member

mcabbott commented Jun 30, 2021

Is there a different main issue besides the one I linked?

I just think it might be less hassle to test & merge one PR than two. Zygote has a lot of cobbled-together quick-fix code, some of it (like this) accidentally two orders of magnitude slower than it should be.

Re optional arguments, that seems like a bug, is there an issue tracking it? Slurped arguments do seem to work (as in repeat([1], 2,3,4,5)). Edit -- OK I see, it's not a bug, it's just that an optional argument writes two methods, and each must return as many tangents as it has arguments.

@AlexRobson
Copy link
Member Author

Thanks for your input both. I've added a couple of commits to deal with the comments:

  • I've used the same approach as used in @mzgubic branch to handle the positional arguments. It is essentially as descrbed above by dropping the positional arguments and using separate methods for them.
  • I've added the extra tests as suggested by @mcabbott.

The zero array ones are broken for now though, in two ways:

For rrule(repeat, fill(1.0), 2) and friends:

  • Right now this implementation uses AbstractVector and AbbstractVecOrMat so they are not dispatching. I believe this should be fixed when we add function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...).

For rrule(repeat, fill(1.0); fkwargs = (inner=2,))) and friends:

  • This does dispatch, but this is hitting another issue, which I think means the closure in the rrule would need tweaking. Specifically, test_rrule ends up dispatching to _test_add!!_behaviour(acc::Float64, val::Array{Float64, 0} and then failing with MethodError: no method matching +(::Float64, ::Array{Float64, 0}). Basically, playing around with it a bit locally, I think the dimensions in the output of the pullback are off with zero arrays.

@mcabbott
Copy link
Member

mcabbott commented Jun 30, 2021

Re zero-arrays, it's possible there are issues with whether test_rrule can handle these. It also doesn't seem so obvious what the rule should return -- in lots of other contexts, zero-arrays tend to go away, perhaps they should here too?

julia> gradient(x -> x .+ 1, fill(2))
(1,)

julia> gradient(x -> sum(vcat(x, x)), fill(3))
(2,)

julia> gradient(x -> sum(repeat(x, inner=3)), fill(4))
(fill(3),)

@AlexRobson
Copy link
Member Author

Yeah fill functions are seemingly erratic in how they treated. FWIW, I think in this case it's due to how rand_tangent behaves. In this case, it gets converted to a Float64, rather than remaining as an Array{Float64, 0}.

using Test
using FiniteDifferences

x = rand(2,2)
@test typeof(rand_tangent(x)) == typeof(x) # Passes
x = rand(1,)
@test typeof(rand_tangent(x)) == typeof(x) # Passes
x = fill(rand())
@test typeof(rand_tangent(x)) == typeof(x) # Fails

With a local fix for this, it looks like these two rrules are now passing:

test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,))

This is the 'fix':

rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(only(x)))

# These dispatch but rrule needs to be fixed to zero-arrays
# These dispatch but probably needs
# https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these fixed on your FiniteDifferences branch with the rand_tangent fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so. Can uncoment these when doing a follow up for the function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) general method along with a FiniteDifferences bump. ?

@AlexRobson AlexRobson merged commit 93eb113 into master Jul 1, 2021
@AlexRobson AlexRobson deleted the ar/add_repeat branch July 1, 2021 11:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants