-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Less clever sum rrule for the default case #493
Conversation
(I will leave this PR to @mcabbott to review and approve. I have unsubscribed. Ping me if I am wanted) |
The downside of With this |
Well, we can't use |
Codecov Report
@@ Coverage Diff @@
## master #493 +/- ##
=======================================
Coverage 98.11% 98.12%
=======================================
Files 22 22
Lines 2340 2350 +10
=======================================
+ Hits 2296 2306 +10
Misses 44 44
Continue to review full report at Codecov.
|
The sum rrule relies on broadcasting to figure out the result shape of the cotangent. However, this has two disadvantages: 1. We need to keep the original array around 2. Broadcasting machinery is complicated and tough on (higher-order) AD This adds a special case for `dims=:`, which simply stores the dimensions of the original array and uses `fill` in the pullback, which has a simple rrule and is thus much easier to AD.
It would certainly be easy to make a For the scalar case, it does seem a little wasteful to make a constant dense array at all. Zygote presently makes a
Edit -- on 1.7 another possible lazy structure is a view, this appears to work:
|
Here's one attempt to write this. The fill-like function needn't get axes/size as well as function rrule(::typeof(sum), x::AbstractArray; dims=:)
project = ProjectTo(x)
y = sum(x; dims=dims)
function sum_pullback(dy_raw)
dy = unthunk(dy_raw)
# protect this from broadcasting, when `x` is an array of arrays
dy_safe = dims isa Colon ? Ref(dy) : dy
x_thunk = InplaceableThunk(
dx -> dx .+= dy_safe,
@thunk project(_unsum(x, dy_safe, dims))
)
return (NoTangent(), x_thunk)
end
return y, sum_pullback
end
_unsum(shape, value, dims) = broadcast(last∘tuple, shape, value)
function rrule(::typeof(_unsum), shape, value, dims)
_unsum_pullback(dz) = (NoTangent(), NoTangent(), sum(unthunk(dz); dims=dims), NoTangent())
return _unsum(shape, value, dims), _unsum_pullback
end Edit -- not quite right, but mcabbott@a65f00d is a version with tests. |
That rule may be the best we can do for the moment. Wanna take over this PR? |
Sure. Can I push to this branch, or better to call it a new PR from my fork? Will open another thread to discuss ways around capturing etc, afterwards. |
Either way, you can also just change the branch of this PR. |
Ok, I went the familiar route, and #494 is the new PR. |
The sum rrule relies on broadcasting to figure out the
result shape of the cotangent. However, this has two
disadvantages:
This adds a special case for
dims=:
, which simply storesthe dimensions of the original array and uses
fill
in the pullback, which has a simple rrule and is thus
much easier to AD.