diff --git a/Project.toml b/Project.toml index ac53c6fe7..c25bcada3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.8.1" +version = "1.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 1cc67e239..24d0b559c 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -68,14 +68,13 @@ end ##### ##### `repeat` ##### - function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs))) + project_Xs = ProjectTo(xs) + S = size(xs) function repeat_pullback(ȳ) dY = unthunk(ȳ) Δ′ = zero(xs) - S = size(xs) - # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in pairs(IndexCartesian(), dY) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then @@ -83,39 +82,25 @@ function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs) src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] Δ′[src_idx...] += val end - return (NoTangent(), Δ′) + x̄ = project_Xs(Δ′) + return (NoTangent(), x̄) end return repeat(xs; inner = inner, outer = outer), repeat_pullback end -function rrule(::typeof(repeat), xs::AbstractVector, m::Integer) +function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) 
- d1 = size(xs, 1) + project_Xs = ProjectTo(xs) + S = size(xs) function repeat_pullback(ȳ) - Δ′ = dropdims(sum(reshape(ȳ, d1, :); dims=2); dims=2) - return (NoTangent(), Δ′, NoTangent()) - end - - return repeat(xs, m), repeat_pullback -end - -function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer) - d1, d2 = size(xs, 1), size(xs, 2) - function repeat_pullback(ȳ) - ȳ′ = reshape(ȳ, d1, m, d2, n) - return NoTangent(), reshape(sum(ȳ′; dims=(2,4)), (d1, d2)), NoTangent(), NoTangent() + dY = unthunk(ȳ) + size2ndims = ntuple(d -> isodd(d) ? get(S, 1+d÷2, 1) : get(counts, d÷2, 1), 2*ndims(dY)) + reduced = sum(reshape(dY, size2ndims); dims = ntuple(d -> 2d, ndims(dY))) + x̄ = project_Xs(reshape(reduced, S)) + return (NoTangent(), x̄, map(_->NoTangent(), counts)...) end - - return repeat(xs, m, n), repeat_pullback -end - -function rrule(T::typeof(repeat), xs::AbstractVecOrMat, m::Integer) - - # Workaround use of positional default (i.e. repeat(xs, m, n = 1))) - y, full_pb = rrule(T, xs, m, 1) - repeat_pullback(ȳ) = full_pb(ȳ)[1:3] - return y, repeat_pullback + return repeat(xs, counts...), repeat_pullback end ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index a643b64f7..233a85125 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -43,27 +43,35 @@ end end @testset "repeat" begin + test_rrule(repeat, rand(4, )) - test_rrule(repeat, rand(4, ), 2) test_rrule(repeat, rand(4, 5)) test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),)) test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3))) + test_rrule(repeat, rand(4, ), 2; check_inferred=VERSION>=v"1.6") + test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.6") + test_rrule(repeat, rand(4, 5), 2, 3; check_inferred=VERSION>=v"1.6") + test_rrule(repeat, rand(1,2,3), 2,3,4; check_inferred=VERSION>v"1.6") + test_rrule(repeat, rand(0,2,3), 2,0,4; check_inferred=VERSION>v"1.6") + test_rrule(repeat, rand(1,1,1,1), 2,3,4,5; 
check_inferred=VERSION>v"1.6") + + + if VERSION>=v"1.6" - # repeat([1 2; 3 4], inner=(2,4), outer=(1,1,1,3)) fails for v<1.6 + # These are cases where repeat itself fails in earlier versions test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3))) - end - test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.5") - test_rrule(repeat, rand(4, 5), 2, 3) + test_rrule(repeat, rand(1,2,3), 2,3) + test_rrule(repeat, rand(1,2,3), 2,3,4,2) + test_rrule(repeat, fill(1.0), 2) + test_rrule(repeat, fill(1.0), 2, 3) - # zero-arrays: broken - @test_broken rrule(repeat, fill(1.0), 2) !== nothing - @test_broken rrule(repeat, fill(1.0), 2, 3) !== nothing + # These fail for other v1.0 related issues (add!!) + # v"1.0": fill(1.0) + fill(1.0) != fill(2.0) + # v"1.6": fill(1.0) + fill(1.0) == fill(2.0) # Expected + test_rrule(repeat, fill(1.0); fkwargs = (inner=2,)) + test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,)) - # These dispatch but probably needs - # https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179 - # test_rrule(repeat, fill(1.0); fkwargs = (inner=2,)) - # test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,)) + end @test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4] @test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4]