Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.0"
version = "1.9.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
5 changes: 3 additions & 2 deletions ext/StaticArraysChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ end

# Project SArray to SArray
function ProjectTo(x::SArray{S, T}) where {S, T}
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x))
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x),
size = Size(x))
end

@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)

(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx)
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx)

# Adjoint for SArray constructor
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
Expand Down
21 changes: 19 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test

@testset "Chain Rules Integration" begin
@testset "ChainRules Integration" begin
@testset "Projection" begin
# There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
# implies a check, and reshape will wrap a Vector into a static SizedVector:
pstat = ProjectTo(SA[1, 2, 3])
@test axes(pstat(rand(3))) === (SOneTo(3),)

# This recurses into structured arrays:
pst = ProjectTo(transpose(SA[1, 2, 3]))
@test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3))
@test pst(rand(1,3)) isa Transpose

# When the argument is an ordinary Array, static gradients are allowed to pass,
# like FillArrays. Collecting to an Array would cost a copy.
pvec3 = ProjectTo([1, 2, 3])
@test pvec3(SA[1, 2, 3]) isa StaticArray
end

@testset "Constructor rrules" begin
test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0))
test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))
Expand Down