diff --git a/Project.toml b/Project.toml
index 0060000bb..e792543b0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRules"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "0.8.17"
+version = "0.8.18"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl
index 1853dac8a..f23957ebc 100644
--- a/src/rulesets/Base/array.jl
+++ b/src/rulesets/Base/array.jl
@@ -20,6 +20,91 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
     return reshape(A, dims...), reshape_pullback
 end
 
+#####
+##### `repeat`
+#####
+
+# Reverse rule for `repeat(xs; inner, outer)`. Every output element is a copy of one
+# element of `xs`, so the pullback adds each output cotangent onto its source element.
+function rrule(
+    ::typeof(repeat),
+    xs::AbstractArray;
+    inner=ntuple(_ -> 1, ndims(xs)),
+    outer=ntuple(_ -> 1, ndims(xs))
+)
+    function repeat_pullback(ȳ)
+        dY = unthunk(ȳ)
+        Δ′ = zero(xs)
+        S = size(xs)
+
+        # Loop through each element of dY, calculate source indices, accumulate into Δ′
+        for (dest_idx, val) in pairs(IndexCartesian(), dY)
+            # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
+            # wrap around based on original size S. The index is built as a tuple so
+            # that no `Vector` is allocated for every element of dY.
+            src_idx = ntuple(length(S)) do dim
+                mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim])
+            end
+            Δ′[src_idx...] += val
+        end
+        return (NoTangent(), Δ′)
+    end
+
+    return repeat(xs; inner=inner, outer=outer), repeat_pullback
+end
+
+# Reverse rule for `repeat(xs, m)` on a vector. The cotangent is viewed as a
+# `length(xs) × m` matrix whose columns are summed to give the cotangent of `xs`.
+function rrule(::typeof(repeat), xs::AbstractVector, m::Integer)
+    d1 = size(xs, 1)
+
+    function repeat_pullback(ȳ)
+        # ȳ may arrive as a thunk; unwrap it before reshaping.
+        dY = unthunk(ȳ)
+        # Pass the known repeat count instead of `:` so that no dimension has to be
+        # inferred, which is ill-defined when `xs` is empty.
+        Δ′ = dropdims(sum(reshape(dY, d1, m); dims=2); dims=2)
+        return (NoTangent(), Δ′, NoTangent())
+    end
+
+    return repeat(xs, m), repeat_pullback
+end
+
+# Reverse rule for `repeat(xs, m, n)`, which tiles `xs` `m` times along the first
+# dimension and `n` times along the second.
+function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer)
+    d1, d2 = size(xs, 1), size(xs, 2)
+    sz = size(xs)
+
+    function repeat_pullback(ȳ)
+        # ȳ may arrive as a thunk; unwrap it before reshaping.
+        dY = unthunk(ȳ)
+        # Split each output dimension into (source index, tile index), sum over tiles.
+        ȳ′ = reshape(dY, d1, m, d2, n)
+        # Reshape to `size(xs)` so that a vector `xs` gets a vector cotangent.
+        Δ′ = reshape(sum(ȳ′; dims=(2, 4)), sz)
+        return (NoTangent(), Δ′, NoTangent(), NoTangent())
+    end
+
+    return repeat(xs, m, n), repeat_pullback
+end
+
+# Reverse rule for `repeat(xs, m)` on inputs not covered by the more specific
+# `AbstractVector` method above.
+function rrule(T::typeof(repeat), xs::AbstractVecOrMat, m::Integer)
+    # Workaround use of positional default (i.e. repeat(xs, m, n = 1))
+    y, full_pb = rrule(T, xs, m, 1)
+
+    function repeat_pullback(ȳ)
+        # Drop the tangent of the implicit `n = 1`. Destructuring keeps the result a
+        # plain 3-tuple without indexing the tuple by a range.
+        ∂self, ∂xs, ∂m, _ = full_pb(ȳ)
+        return (∂self, ∂xs, ∂m)
+    end
+
+    return y, repeat_pullback
+end
+
 #####
 ##### `hcat`
 #####
diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl
index a3e5234cf..eb8e8cd0f 100644
--- a/test/rulesets/Base/array.jl
+++ b/test/rulesets/Base/array.jl
@@ -4,6 +4,41 @@
     test_rrule(reshape, rand(4, 5), 2, :)
 end
 
+@testset "repeat" begin
+    test_rrule(repeat, rand(4, ))
+    test_rrule(repeat, rand(4, ), 2)
+    test_rrule(repeat, rand(4, 5))
+    test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
+    test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
+
+    if VERSION>=v"1.6"
+        # repeat([1 2; 3 4], inner=(2,4), outer=(1,1,1,3)) fails for v<1.6
+        test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3)))
+    end
+    test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.5")
+    test_rrule(repeat, rand(4, 5), 2, 3)
+
+    # zero-dimensional arrays: broken
+    @test_broken rrule(repeat, fill(1.0), 2) !== nothing
+    @test_broken rrule(repeat, fill(1.0), 2, 3) !== nothing
+
+    # These dispatch but probably need
+    # https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179
+    # test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
+    # test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,))
+
+    @test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4]
+    @test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4]
+
+    # Thunked cotangents are accepted by the positional methods
+    @test rrule(repeat, [1,2,3], 4)[2](@thunk(ones(12)))[2] == [4,4,4]
+    @test rrule(repeat, [1 2; 3 4], 2, 3)[2](@thunk(ones(4, 6)))[2] == [6 6; 6 6]
+
+    # Empty input; a vector repeated along two dimensions keeps a vector cotangent
+    @test rrule(repeat, Float64[], 3)[2](Float64[])[2] == Float64[]
+    @test rrule(repeat, [1,2,3], 2, 2)[2](ones(6, 2))[2] == [4,4,4]
+end
+
 @testset "hcat" begin
     test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1")
     test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1")