|
| 1 | +@testset "constructors" begin |
| 2 | + |
| 3 | + # We can't use test_rrule here (as it's currently implemented) because the elements of |
| 4 | + # the array have arbitrary values. The only thing we can do is ensure that we're getting |
| 5 | + # `ZeroTangent`s back, and that the forwards pass produces the correct thing still. |
| 6 | + # Issue: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/202 |
| 7 | + @testset "undef" begin |
| 8 | + val, pullback = rrule(Array{Float64}, undef, 5) |
| 9 | + @test size(val) == (5, ) |
| 10 | + @test val isa Array{Float64, 1} |
| 11 | + @test pullback(randn(5)) == (NoTangent(), NoTangent(), NoTangent()) |
| 12 | + end |
| 13 | + @testset "from existing array" begin |
| 14 | + test_rrule(Array, randn(2, 5)) |
| 15 | + test_rrule(Array, Diagonal(randn(5))) |
| 16 | + test_rrule(Matrix, Diagonal(randn(5))) |
| 17 | + test_rrule(Matrix, transpose(randn(4))) |
| 18 | + test_rrule(Array{ComplexF64}, randn(3)) |
| 19 | + end |
| 20 | +end |
| 21 | + |
| 22 | +@testset "vect" begin |
| 23 | + test_rrule(Base.vect) |
| 24 | + @testset "homogeneous type" begin |
| 25 | + test_rrule(Base.vect, (5.0, ), (4.0, )) |
| 26 | + test_rrule(Base.vect, 5.0, 4.0, 3.0) |
| 27 | + test_rrule(Base.vect, randn(2, 2), randn(3, 3)) |
| 28 | + end |
| 29 | + @testset "inhomogeneous type" begin |
| 30 | + test_rrule( |
| 31 | + Base.vect, 5.0, 3f0; |
| 32 | + atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6", |
| 33 | + ) # tolerance due to Float32. |
| 34 | + test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) |
| 35 | + end |
| 36 | +end |
| 37 | + |
1 | 38 | @testset "reshape" begin
|
2 | 39 | test_rrule(reshape, rand(4, 5), (2, 10))
|
3 | 40 | test_rrule(reshape, rand(4, 5), 2, 10)
|
|
0 commit comments