diff --git a/ext/StaticArraysChainRulesCoreExt.jl b/ext/StaticArraysChainRulesCoreExt.jl index 5f7904aa..3ae86a5c 100644 --- a/ext/StaticArraysChainRulesCoreExt.jl +++ b/ext/StaticArraysChainRulesCoreExt.jl @@ -15,12 +15,14 @@ end # Project SArray to SArray function ProjectTo(x::SArray{S, T}) where {S, T} - return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x)) + # We have a axes field because it is expected by other ProjectTo's like the one for Transpose + return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x), + size = Size(x)) end @inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) -(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx) +(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx) # Adjoint for SArray constructor function rrule(::Type{T}, x::Tuple) where {T <: SArray} diff --git a/test/chainrules.jl b/test/chainrules.jl index 7dbf8e8a..c586432a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,7 +1,24 @@ -using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test +using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test -@testset "Chain Rules Integration" begin +@testset "ChainRules Integration" begin @testset "Projection" begin + # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) + # implies a check, and reshape will wrap a Vector into a static SizedVector: + pstat = ProjectTo(SA[1, 2, 3]) + @test axes(pstat(rand(3))) === (SOneTo(3),) + + # This recurses into structured arrays: + pst = ProjectTo(transpose(SA[1, 2, 3])) + @test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3)) + @test pst(rand(1,3)) isa Transpose + + # When the argument is an ordinary Array, static gradients are allowed to pass, + # like FillArrays. Collecting to an Array would cost a copy. + pvec3 = ProjectTo([1, 2, 3]) + @test pvec3(SA[1, 2, 3]) isa StaticArray + end + + @testset "Constructor rrules" begin test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0)) test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0)) test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))