@AlexRobson
Member

@AlexRobson AlexRobson commented Jul 4, 2021

Follow up from #460

Adds the missing rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) method.

Uses the code provided by @mcabbott in the gist here
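
The gist itself is not reproduced in this thread. As a rough illustration of what a pullback for repeat(xs, counts...) has to compute, here is a minimal sketch in plain Julia. It is not the code from the gist or from this PR: the name repeat_with_pullback is hypothetical, and the sketch skips everything ChainRules-specific (tangent types, thunks, projection). The idea is that each output dimension d is size(xs, d) tiled counts[d] times, so the cotangent can be reshaped to separate the two and summed over the repetition axes.

```julia
# Hypothetical helper (not the PR's rrule): forward pass plus a pullback closure.
function repeat_with_pullback(xs::AbstractArray, counts::Integer...)
    y = repeat(xs, counts...)
    function pullback(dy::AbstractArray)
        n = ndims(dy)
        # Split every output dimension d into (size(xs, d), counts[d]);
        # dimensions without a matching count are repeated once.
        dims2 = ntuple(2n) do k
            d = (k + 1) ÷ 2
            isodd(k) ? size(xs, d) : (d <= length(counts) ? Int(counts[d]) : 1)
        end
        # Sum over the repetition axes (the even ones), then restore xs's shape.
        summed = sum(reshape(dy, dims2); dims = ntuple(d -> 2d, n))
        return reshape(summed, size(xs))
    end
    return y, pullback
end

xs = [1.0 2.0; 3.0 4.0]
y, pb = repeat_with_pullback(xs, 2, 3)
dxs = pb(ones(size(y)))   # every entry of xs appears 2 * 3 = 6 times in y
```

With a cotangent of all ones, each entry of dxs is simply the number of copies of that element in the output.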

Some comments:

  • There were some type inference issues when running on Julia 1.0. I don't know how important this is (a bunch of tests are already marked to check inference only on later versions), but to avoid regressing the existing tests I added some type hints in the closure. It's a bit gross, so there may be a better way to do it. EDIT: Now removed; the 1.0 tests are instead relaxed with respect to typing.
  • Bumps FiniteDifferences so rand_tangent{Array{T, 0}} is now doing the correct thing, which enables those tests. The forward pass of repeat behaves differently between 1.0 and 1.6 (the CI versions), so there is a bit of version checking. FD has been bumped, but the bump is actually no longer needed now that rand_tangent is in CRTU.

@AlexRobson AlexRobson force-pushed the ar/add_missing_repeat branch from 78b13b0 to 1cd6bb3 Compare July 5, 2021 09:28
@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jul 5, 2021
@codecov-commenter

codecov-commenter commented Jul 5, 2021

Codecov Report

Merging #466 (1a23643) into master (0d55c54) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #466      +/-   ##
==========================================
- Coverage   98.12%   98.12%   -0.01%     
==========================================
  Files          22       22              
  Lines        2350     2348       -2     
==========================================
- Hits         2306     2304       -2     
  Misses         44       44              
Impacted Files Coverage Δ
src/rulesets/Base/array.jl 100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mcabbott
Member

mcabbott commented Jul 6, 2021

now doing the correct thing

Meaning what?

In #460 (comment) I was trying to say that it seems a bit of an open question what should be returned for zero-arrays. They shouldn't error, but should they always be mapped to scalars, or always restored, or is a random mix OK?

@AlexRobson AlexRobson force-pushed the ar/add_missing_repeat branch from cb4bad2 to dad6dcb Compare July 8, 2021 14:10
@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jul 8, 2021
@AlexRobson
Member Author

now doing the correct thing

Meaning what?

In #460 (comment) I was trying to say that it seems a bit of an open question what should be returned for zero-arrays. They shouldn't error, but should they always be mapped to scalars, or always restored, or is a random mix OK?

So the issue I was encountering was in TestUtils and its use of rand_tangent, which essentially led to the test input to the pullback being a scalar rather than a zero-vector. I don't have the stack trace at hand, but once these zero-array tests were included as edge cases, it was essentially these type conversions that caused the errors. I believe we would either need to promote (somewhere) the zero-vector to a scalar or, as has been implemented here (or rather in FiniteDifferences), allow the zero-array to propagate.

I suppose my stating that it is doing the correct thing is a bit presumptuous. What the 'right' choice of zero-array idiom is, idk. Julia itself seems to have changed the behaviour between versions, i.e. in 1.0 fill(1.0) + fill(1.0) = 2.0, and in 1.6 it is fill(2.0). Similarly for multiplication: in 1.6 it propagates through (fill(1.0) * 5 == fill(5.0)), but not if you broadcast 🤷.

So yeah, I'm not certain on what behaviour is wanted. FWIW, I'd expect it to propagate: consistently treat a zero-vector as a vector, and a scalar as a scalar. But idk.
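
For readers less familiar with zero-dimensional arrays, the following is a small sketch of the distinction being discussed. The helper unwrap is hypothetical and only serves to compare results without caring whether an operation returned a scalar or a zero-array, which, per the comment above, depends on the Julia version and on whether broadcasting is used.

```julia
z = fill(1.0)   # a zero-dimensional array holding a single Float64
ndims(z)        # 0
size(z)         # ()
z[]             # 1.0, indexing with no indices extracts the element

# Hypothetical helper: return the element of a zero-array, pass anything else through.
unwrap(x::AbstractArray{<:Any,0}) = x[]
unwrap(x) = x

a = z + z    # scalar or zero-array, depending on the Julia version (see above)
b = z * 5    # same caveat
c = z .* 5   # broadcasting may give a different container than z * 5
```

Whatever container comes back, unwrap(a), unwrap(b) and unwrap(c) give the underlying numbers, which is the kind of normalisation the "promote the zero-vector to a scalar" option would amount to.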

@mcabbott
Member

mcabbott commented Jul 8, 2021

FWIW, I'd expect it to propagate: consistently treat a zero-vector as a vector

I'm torn actually. This has the plus side of being easy to state. And sometimes, a zero-array might be a sensible one-number container when you need something mutable and don't want Ref? (But why not Ref?)

On the minus side, it's a much heavier object than Ref or Float64. And it's a pain to preserve in all rules, since e.g. A[1,:,1] is a vector but A[1,1,1] is not a zero-array. If you try to restore them from a number, e.g. like this JuliaDiff/ChainRulesCore.jl@d0419b3, then fill(1.2) will always make an Array even if you had say a CuArray, which will not go well, whereas a number might.
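
A short sketch of the indexing point, using only Base Julia (the CuArray case is not shown, since it needs a GPU package): a colon index keeps a dimension, all-integer indexing returns the bare element, and wrapping that element again with fill always produces a plain Array.

```julia
A = rand(2, 3, 4)

v = A[1, :, 1]         # the colon keeps that dimension: a 3-element Vector
s = A[1, 1, 1]         # all-integer indexing returns the element itself, not a zero-array
z = view(A, 1, 1, 1)   # a view with all-integer indices is zero-dimensional

restored = fill(s)     # wrapping the number again gives an Array{Float64,0},
                       # regardless of what kind of array A was
```

So a rule that wants to preserve zero-arrays has to re-wrap results explicitly, and fill is only the right wrapper when the original container was an ordinary Array.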

@AlexRobson AlexRobson force-pushed the ar/add_missing_repeat branch from dad6dcb to a0c5207 Compare August 12, 2021 14:02
@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Aug 12, 2021
@AlexRobson AlexRobson changed the title WIP: Add missing generic repeat method Add missing generic repeat method Aug 12, 2021
@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Aug 12, 2021
@AlexRobson
Member Author

Coming back to this MR.

I've now added Project for xs, wrapped the calculation of x̄ in a thunk, and rebased onto master.

Responding to your comment (if it still applies!)

Regarding consistency (the behaviour of zero-arrays, Ref and so on, and the nuisance inconsistencies in whether something is treated as a vector or a real): this seems a broader question than repeat alone. However, if I understand correctly, where inconsistencies do emerge, Project should provide some robustness, at least for the output of the Xrules, as project can just convert the real back into the zero-array (or whatever). So if the pullback did happen to go from, e.g., a zero-array to a real number, project should at least make it robust to that.

However, there may be cases where these inconsistencies interfere with the internals of CRC (such as add!!) and break things. Zero-arrays do have a habit of going missing, as is commented (e.g. in 1.0: fill(1.0) + fill(1.0) = 2.0). This does seem to cause an issue here, because the add!! test ends up trying to add a zero-array to a real. I'm guessing from your comment that you may have seen something like this with CuArrays. If so, this seems to warrant a separate issue.
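
To make the "convert the real back into the zero-array" idea concrete, here is a minimal sketch in plain Julia. restore_like is a hypothetical stand-in for the projection step, not the ChainRulesCore API, and it uses fill, so it inherits the caveat raised earlier in the thread about always producing a plain Array.

```julia
# Hypothetical stand-in for projection: make the tangent dx match the primal x's "shape".
restore_like(x::AbstractArray{<:Any,0}, dx::Number) = fill(dx)              # number -> zero-array
restore_like(x::AbstractArray{<:Any,0}, dx::AbstractArray{<:Any,0}) = dx    # already a zero-array
restore_like(x::Number, dx::AbstractArray{<:Any,0}) = dx[]                  # zero-array -> number
restore_like(x::Number, dx::Number) = dx                                    # already a number

x = fill(1.0)                      # primal is a zero-array
dx = restore_like(x, 2.0)          # a pullback that "lost" the zero-array is repaired
acc = restore_like(x, dx[] + 3.0)  # accumulate on the elements, then re-wrap
```

Applying such a step at the rule's output keeps the returned tangent consistent with the primal, but it does not help code that mixes the two representations before the projection happens, which is the add!! situation described above.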

@AlexRobson AlexRobson requested a review from mzgubic August 13, 2021 09:17
Member

@mzgubic mzgubic left a comment

LGTM bar a few minor comments. Are you and @mcabbott happy with the Fill outcome?

@AlexRobson AlexRobson force-pushed the ar/add_missing_repeat branch from daf52c4 to 5fdd9eb Compare August 18, 2021 12:43
@AlexRobson AlexRobson merged commit 1fb2761 into master Aug 18, 2021
@AlexRobson AlexRobson deleted the ar/add_missing_repeat branch August 18, 2021 14:58