diff --git a/Project.toml b/Project.toml index d300c1475..9a732804b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.61" +version = "0.7.62" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 3b3996373..b65448e04 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -103,3 +103,52 @@ function rrule(::typeof(fill), value::Any, dims::Int...) end return fill(value, dims), fill_pullback end + +##### +##### `repeat` +##### + +function rrule(::typeof(repeat), x::AbstractVector, m::Integer) + function repeat_pullback(Ȳ) + return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length(x), :); dims=2); dims=2), DoesNotExist()) + end + return repeat(x, m), repeat_pullback +end + +function rrule(::typeof(repeat), x::AbstractVecOrMat, m::Integer, n::Integer=1) + function repeat_pullback(Ȳ) + Ȳ′ = reshape(Ȳ, size(x, 1), m, size(x, 2), n) + return (NO_FIELDS, reshape(sum(Ȳ′; dims=(2,4)), size(x)), DoesNotExist(), DoesNotExist()) + end + return repeat(x, m, n), repeat_pullback + end + +function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs))) + function repeat_pullback(Ȳ) + Ȳ′ = zero(xs) + S = size(xs) + for (dest_idx, val) ∈ pairs(IndexCartesian(), Ȳ) + src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim ∈ 1:length(S)] + Ȳ′[src_idx...] += val + end + return (NO_FIELDS, Ȳ′) + end + return repeat(xs; inner=inner, outer=outer), repeat_pullback +end + +function rrule(::typeof(repeat), x::AbstractArray{<:Real, 0}, m::Integer) + repeat_pullback(Ȳ) = (NO_FIELDS, similar(x, eltype(Ȳ)) .= sum(Ȳ), DoesNotExist()) + return repeat(x, m), repeat_pullback +end + +function frule((_, Δx), ::typeof(repeat), x, m::Integer) + return repeat(x, m), repeat(Δx, m) +end + +function frule((_, Δxs), ::typeof(repeat), xs; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs))) + return repeat(xs; inner=inner, outer=outer), repeat(Δxs; inner=inner, outer=outer) +end + +function frule((_, Δx), ::typeof(repeat), x::AbstractArray{<:Real,0}, m::Integer) + return repeat(x, m), repeat(fill(Δx,m)) +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 16e6e265d..f6692fcfd 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -35,3 +35,23 @@ end test_rrule(fill, 44.0, 4; check_inferred=false) test_rrule(fill, 2.0, (3, 3, 3) ⊢ DoesNotExist()) end + +@testset "repeat" begin + @testset "rrule" begin + test_rrule(repeat, randn(5), 3) + test_rrule(repeat, randn(5), 3, 3) + test_rrule(repeat, randn(3, 3), 2) + test_rrule(repeat, randn(5, 5), 2,5) + test_rrule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3))) + test_rrule(repeat, fill(4.0), 3) + end + + @testset "frule" begin + test_frule(repeat, randn(5), 3) + test_frule(repeat, randn(5), 3,3) + test_frule(repeat, randn(3, 3), 2) + test_frule(repeat, randn(3, 3), 2,5) + test_frule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3))) + test_frule(repeat, fill(4.0), 3) + end +end