Skip to content

Commit 7593339

Browse files
Some Array rules (#491)
* Bump minor version * Add undef Array constructor * construct Array from existing AbstractArray * vect implementation * Bump precision * Additional tests Co-authored-by: Michael Abbott <[email protected]> * Fix undef tests * Constraint vect implementation * Add float-only test * Link to non_differentiable tests issue * Type-stable `vect` implementation Co-authored-by: Michael Abbott <[email protected]> * type stable vect pullback * Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <[email protected]> * Test Union{Number, AbstractArray} * Don't test inference below 1.6 * Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <[email protected]> * Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <[email protected]> * Style fix * Style fix * Update src/rulesets/Base/array.jl Co-authored-by: Michael Abbott <[email protected]> * Style fix * Add an extra test Co-authored-by: Michael Abbott <[email protected]>
1 parent 6df028a commit 7593339

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

src/rulesets/Base/array.jl

+39
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
#####
2+
##### constructors
3+
#####
4+
5+
ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)
6+
7+
function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
8+
project_x = ProjectTo(x)
9+
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
10+
return T(x), Array_pullback
11+
end
12+
13+
#####
14+
##### `vect`
15+
#####
16+
17+
@non_differentiable Base.vect()
18+
19+
# Case of uniform type `T`: the data passes straight through,
20+
# so no projection should be required.
21+
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
22+
vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
23+
return Base.vect(X...), vect_pullback
24+
end
25+
26+
# Numbers and arrays are often promoted, to make a uniform vector.
27+
# ProjectTo here reverses this
28+
function rrule(
29+
::typeof(Base.vect),
30+
X::Vararg{Union{Number,AbstractArray{<:Number}}, N},
31+
) where {N}
32+
projects = map(ProjectTo, X)
33+
function vect_pullback(ȳ)
34+
= ntuple(n -> projects[n](ȳ[n]), N)
35+
return (NoTangent(), X̄...)
36+
end
37+
return Base.vect(X...), vect_pullback
38+
end
39+
140
#####
241
##### `reshape`
342
#####

test/rulesets/Base/array.jl

+37
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,40 @@
1+
@testset "constructors" begin
2+
3+
# We can't use test_rrule here (as it's currently implemented) because the elements of
4+
# the array have arbitrary values. The only thing we can do is ensure that we're getting
5+
# `ZeroTangent`s back, and that the forwards pass produces the correct thing still.
6+
# Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202
7+
@testset "undef" begin
8+
val, pullback = rrule(Array{Float64}, undef, 5)
9+
@test size(val) == (5, )
10+
@test val isa Array{Float64, 1}
11+
@test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent())
12+
end
13+
@testset "from existing array" begin
14+
test_rrule(Array, randn(2, 5))
15+
test_rrule(Array, Diagonal(randn(5)))
16+
test_rrule(Matrix, Diagonal(randn(5)))
17+
test_rrule(Matrix, transpose(randn(4)))
18+
test_rrule(Array{ComplexF64}, randn(3))
19+
end
20+
end
21+
22+
@testset "vect" begin
23+
test_rrule(Base.vect)
24+
@testset "homogeneous type" begin
25+
test_rrule(Base.vect, (5.0, ), (4.0, ))
26+
test_rrule(Base.vect, 5.0, 4.0, 3.0)
27+
test_rrule(Base.vect, randn(2, 2), randn(3, 3))
28+
end
29+
@testset "inhomogeneous type" begin
30+
test_rrule(
31+
Base.vect, 5.0, 3f0;
32+
atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6",
33+
) # tolerance due to Float32.
34+
test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false)
35+
end
36+
end
37+
138
@testset "reshape" begin
239
test_rrule(reshape, rand(4, 5), (2, 10))
340
test_rrule(reshape, rand(4, 5), 2, 10)

0 commit comments

Comments
 (0)