Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.8.17"
version = "0.8.18"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
46 changes: 46 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
return reshape(A, dims...), reshape_pullback
end

#####
##### `repeat`
#####

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))

function repeat_pullback(ȳ)
dY = unthunk(ȳ)
Δ′ = zero(xs)
S = size(xs)

# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in pairs(IndexCartesian(), dY)
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val
end
return (NoTangent(), Δ′)
end

return repeat(xs; inner = inner, outer = outer), repeat_pullback
end

function rrule(::typeof(repeat), xs::AbstractVector, m::Integer)

d1 = size(xs, 1)
function repeat_pullback(ȳ)
Δ′ = dropdims(sum(reshape(ȳ, d1, :); dims=2); dims=2)
return (NoTangent(), Δ′, NoTangent())
end

return repeat(xs, m), repeat_pullback
end

function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer=1)

d1, d2 = size(xs, 1), size(xs, 2)
function repeat_pullback(ȳ)
ȳ′ = reshape(ȳ, d1, m, d2, n)
return NoTangent(), reshape(sum(ȳ′; dims=(2,4)), (d1, d2)), NoTangent(), NoTangent()
end

return repeat(xs, m, n), repeat_pullback
end

#####
##### `hcat`
#####
Expand Down
8 changes: 8 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
test_rrule(reshape, rand(4, 5), 2, :)
end

@testset "repeat" begin
test_rrule(repeat, rand(4, ))
test_rrule(repeat, rand(4, ), 2)
test_rrule(repeat, rand(4, 5))
test_rrule(repeat, rand(4, 5); fkwargs = (inner = (2,1), outer = (2,2)))
test_rrule(repeat, rand(4, 5), 2, 3)
end

@testset "hcat" begin
test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1")
test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1")
Expand Down