Skip to content

Commit

Permalink
Less clever sum rrule for the default case
Browse files Browse the repository at this point in the history
The sum rrule relies on broadcasting to figure out the
result shape of the cotangent. However, this has two
disadvantages:

1. We need to keep the original array around
2. Broadcasting machinery is complicated and tough on (higher-order) AD

This adds a special case for `dims=:`, which simply stores
the dimensions of the original array and uses `fill`
in the pullback, which has a simple rrule and is thus
much easier to AD.
  • Loading branch information
Keno committed Aug 2, 2021
1 parent 7593339 commit 2d72b50
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,35 @@ function frule((_, ẋ), ::typeof(sum), x; dims=:)
return sum(x; dims=dims), sum(ẋ; dims=dims)
end

function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
# Internal helper for filling while maintaining array type.
# TODO: Ideallty we'd only need typeof(x) here, but Base doesn't have the
# interfaces for that.
function _typed_fill(x, ȳ, axes)
fill!(similar(x, typeof(ȳ), axes), ȳ)
end

function rrule(::typeof(_typed_fill), x, ȳ, axes)
function _typed_fill_pullback(Ȳ)
return (NoTangent(), NoTangent(), sum(Ȳ), NoTangent())
end
return _typed_fill(x, ȳ, axes), _typed_fill_pullback
end

function sum_rrule(x, dims::Colon)
y = sum(x; dims=dims)
let xdims=size(x)
function sum_pullback(ȳ)
= InplaceableThunk(
x -> x .+= ȳ,
@thunk(_typed_fill(x, ȳ, xdims...)),
)
return (NoTangent(), x̄)
end
y, sum_pullback
end
end

function sum_rrule(x, dims)
y = sum(x; dims=dims)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
Expand All @@ -19,6 +47,10 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
return y, sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
return sum_rrule(x, dims)
end

# Can't map over Adjoint/Transpose Vector
function rrule(
config::RuleConfig{>:HasReverseMode},
Expand Down

0 comments on commit 2d72b50

Please sign in to comment.