Less clever sum rrule for the default case #493

Closed
wants to merge 1 commit into from

Conversation

Keno
Contributor

@Keno Keno commented Aug 2, 2021

The sum rrule relies on broadcasting to figure out the
result shape of the cotangent. However, this has two
disadvantages:

  1. We need to keep the original array around
  2. Broadcasting machinery is complicated and tough on (higher-order) AD

This adds a special case for dims=:, which simply stores
the dimensions of the original array and uses fill
in the pullback, which has a simple rrule and is thus
much easier to AD.
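
As a rough illustration of the `dims=:` special case described above, here is a minimal sketch in plain Julia with no ChainRulesCore dependency. The name `sum_with_pullback` is hypothetical and is not the code in this PR; it only shows that the pullback can capture `size(x)` instead of `x` and rebuild the cotangent with `fill`.

```julia
# Hypothetical helper, not the PR's actual rule: a hand-written
# "forward pass + pullback" pair for sum(x) with dims=:.
function sum_with_pullback(x::AbstractArray)
    y = sum(x)
    sz = size(x)                      # only the shape is captured, not `x`
    sum_pullback(dy) = fill(dy, sz)   # every element receives the same cotangent
    return y, sum_pullback
end

x = [1.0 2.0 3.0; 4.0 5.0 6.0]
y, back = sum_with_pullback(x)
dx = back(1.0)                        # 2×3 Matrix{Float64} of ones
```

Note that `fill` always returns a plain `Array`, so this sketch would not preserve other array types.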

@oxinabox
Member

oxinabox commented Aug 2, 2021

(I will leave this PR to @mcabbott to review and approve. I have unsubscribed. Ping me if I am wanted)

@mcabbott
Member

mcabbott commented Aug 2, 2021

The downside of fill is that it's not going to preserve a CuArray. If the type is the same you can do fill!(similar(typeof(x), axes(x)), dy), but CR allows for the eltype to vary, and then you need x for fill!(similar(x, typeof(dy), axes(x)), dy) I think.
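
The two `fill!(similar(...))` variants mentioned above can be compared on an ordinary `Array`. This is only a sketch: the variable names are illustrative, and a plain `Matrix` cannot show the CuArray behaviour, only the difference in element type.

```julia
x  = Float32[1 2 3; 4 5 6]
dy = 2.5                      # Float64, so it differs from eltype(x)

# Same-eltype path: only typeof(x) and axes(x) are needed, not `x` itself.
same = fill!(similar(typeof(x), axes(x)), dy)

# Varying-eltype path: needs `x` so that `similar` can swap the eltype
# while keeping the container type.
wide = fill!(similar(x, typeof(dy), axes(x)), dy)

eltype(same)   # Float32
eltype(wide)   # Float64
```

In the first variant `dy` is converted to `Float32`, which is harmless here but would fail for a cotangent that cannot be represented in the original eltype.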

With this ProjectTo story, HWFloats like x::AbstractArray{Float32} imply that the eltype is fixed -- this is known in the forward pass, so could in principle have a happy path. But a trait or something to tell you this isn't really set up yet.

@Keno
Contributor Author

Keno commented Aug 2, 2021

Well, we can't use fill! in the actual rule because Diffractor doesn't like mutation, but we can define a helper function that does that and just give it an appropriate rrule. In theory there should have been a fill-like array constructor, but JuliaLang/julia#24595 was never finished.

@codecov-commenter

codecov-commenter commented Aug 2, 2021

Codecov Report

Merging #493 (be67fbc) into master (7593339) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head be67fbc differs from pull request most recent head 2d72b50. Consider uploading reports for the commit 2d72b50 to get more accurate results

@@           Coverage Diff           @@
##           master     #493   +/-   ##
=======================================
  Coverage   98.11%   98.12%           
=======================================
  Files          22       22           
  Lines        2340     2350   +10     
=======================================
+ Hits         2296     2306   +10     
  Misses         44       44           
Impacted Files Coverage Δ
src/rulesets/Base/mapreduce.jl 99.12% <100.00%> (+0.08%) ⬆️


@mcabbott
Member

mcabbott commented Aug 2, 2021

It would certainly be easy to make a broadfill(x, dy) which does what the rule now does, and has its own gradient. Still closes over x.

For the scalar case, it does seem a little wasteful to make a constant dense array at all. Zygote presently makes a Fill, though it avoids this for CuArrays. Ideally this would work everywhere, right?

julia> cu(ones(5))' .+ Fill(3f0, 2) .+ 6
2×5 CuArray{Float32, 2}:
 10.0  10.0  10.0  10.0  10.0
 10.0  10.0  10.0  10.0  10.0

Edit -- on 1.7 another possible lazy structure is a view, this appears to work:

julia> view(cu(Float32[1 2 3]), 0 .* (1:4) .+ 1, :)
4×3 view(::CuArray{Float32, 2}, StepRangeLen(1, 0, 4), :) with eltype Float32:
 1.0  2.0  3.0
 1.0  2.0  3.0
 1.0  2.0  3.0
 1.0  2.0  3.0

julia> sqrt.(ans) isa CuArray
true

@mcabbott
Member

mcabbott commented Aug 3, 2021

Here's one attempt to write this. The fill-like function needn't get axes/size as well as x, but if it gets dims then it can cater to both the complete sum and the partial one. I think this ought to handle arrays of arrays, structured matrices, CuArrays, etc, but haven't tested much.

function rrule(::typeof(sum), x::AbstractArray; dims=:)
    project = ProjectTo(x)
    y = sum(x; dims=dims)
    function sum_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # protect this from broadcasting, when `x` is an array of arrays
        dy_safe = dims isa Colon ? Ref(dy) : dy
        x_thunk = InplaceableThunk(
            dx -> dx .+= dy_safe,
            @thunk project(_unsum(x, dy_safe, dims))
        )
        return (NoTangent(), x_thunk)
    end
    return y, sum_pullback
end

_unsum(shape, value, dims) = broadcast(last∘tuple, shape, value)

function rrule(::typeof(_unsum), shape, value, dims)
    _unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent())
    return _unsum(shape, value, dims), _unsum_pullback
end
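
The broadcast inside `_unsum` can be tried in isolation. The sketch below assumes the helper being broadcast is the composition `last∘tuple` (a function that ignores its first argument and returns its second); the name `unsum_demo` is hypothetical and the example uses only Base.

```julia
# `last∘tuple` returns its second argument, so broadcasting it over `x`
# and `dy` expands `dy` to the shape (and container type) of `x`.
unsum_demo(x, dy) = broadcast(last∘tuple, x, dy)   # hypothetical name

x = [1.0 2.0 3.0; 4.0 5.0 6.0]

# Complete sum: the cotangent is a scalar.
full = unsum_demo(x, 10.0)        # 2×3 matrix filled with 10.0

# Partial sum over dims=1: the cotangent has size (1, 3).
dy = sum(x; dims=1)               # [5.0 7.0 9.0]
part = unsum_demo(x, dy)          # each row equals [5.0 7.0 9.0]
```

Because the result comes out of broadcasting over `x`, the same expression covers both the complete and the partial sum, at the cost of still closing over `x`.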

Edit -- not quite right, but mcabbott@a65f00d is a version with tests.

@Keno
Contributor Author

Keno commented Aug 3, 2021

That rule may be the best we can do for the moment. Wanna take over this PR?

@mcabbott
Member

mcabbott commented Aug 3, 2021

Sure. Can I push to this branch, or better to call it a new PR from my fork?

Will open another thread to discuss ways around capturing etc, afterwards.

@Keno
Contributor Author

Keno commented Aug 3, 2021

Either way, you can also just change the branch of this PR.

@mcabbott
Member

mcabbott commented Aug 3, 2021

Ok, I went the familiar route, and #494 is the new PR.

@Keno Keno closed this Aug 3, 2021