Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.61"
version = "0.7.62"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
49 changes: 49 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,52 @@ function rrule(::typeof(fill), value::Any, dims::Int...)
end
return fill(value, dims), fill_pullback
end

#####
##### `repeat`
#####

function rrule(::typeof(repeat), x::AbstractVector, m::Integer)
function repeat_pullback(Ȳ)
return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length(x), :); dims=2); dims=2), DoesNotExist())
end
return repeat(x, m), repeat_pullback
end

function rrule(::typeof(repeat), x::AbstractVecOrMat, m::Integer, n::Integer=1)
function repeat_pullback(Ȳ)
Ȳ′ = reshape(Ȳ, size(x, 1), m, size(x, 2), n)
return (NO_FIELDS, reshape(sum(Ȳ′; dims=(2,4)), size(x)), DoesNotExist(), DoesNotExist())
end
return repeat(x, m, n), repeat_pullback
end

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
function repeat_pullback(Ȳ)
Ȳ′ = zero(xs)
S = size(xs)
for (dest_idx, val) ∈ pairs(IndexCartesian(), Ȳ)
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim ∈ 1:length(S)]
Ȳ′[src_idx...] += val
end
return (NO_FIELDS, Ȳ′)
end
return repeat(xs; inner=inner, outer=outer), repeat_pullback
end

function rrule(::typeof(repeat), x::AbstractArray{<:Real, 0}, m::Integer)
repeat_pullback(Ȳ) = (NO_FIELDS, similar(x, eltype(Ȳ)) .= sum(Ȳ), DoesNotExist())
return repeat(x, m), repeat_pullback
end

function frule((_, Δx), ::typeof(repeat), x, m::Integer)
return repeat(x, m), repeat(Δx, m)
end

function frule((_, Δxs), ::typeof(repeat), xs; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
return repeat(xs; inner=inner, outer=outer), repeat(Δxs; inner=inner, outer=outer)
end

function frule((_, Δx), ::typeof(repeat), x::AbstractArray{<:Real,0}, m::Integer)
return repeat(x, m), repeat(fill(Δx,m))
end
Comment on lines +144 to +154
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the rules not covered by #460, which only did rrules. PR could be rebased to still add these.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will take a look at this around Wednesday this week after finishing my late JuliaCon vid. Sorry I lost track of them

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No rush, was just closing stale PRs & thought this one should get a note for whenever someone looks.

20 changes: 20 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,23 @@ end
test_rrule(fill, 44.0, 4; check_inferred=false)
test_rrule(fill, 2.0, (3, 3, 3) ⊢ DoesNotExist())
end

@testset "repeat" begin
@testset "rrule" begin
test_rrule(repeat, randn(5), 3)
test_rrule(repeat, randn(5), 3, 3)
test_rrule(repeat, randn(3, 3), 2)
test_rrule(repeat, randn(5, 5), 2,5)
test_rrule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
test_rrule(repeat, fill(4.0), 3)
end

@testset "frule" begin
test_frule(repeat, randn(5), 3)
test_frule(repeat, randn(5), 3,3)
test_frule(repeat, randn(3, 3), 2)
test_frule(repeat, randn(3, 3), 2,5)
test_frule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
test_frule(repeat, fill(4.0), 3)
end
end