Skip to content

Commit

Permalink
Drop StaticArrays by using StaticArraysCore
Browse files Browse the repository at this point in the history
Requires JuliaArrays/ArrayInterface.jl#325

```julia
@time_imports using SciMLBase
    10.4 ms    ┌ MacroTools
     19.0 ms  ┌ ZygoteRules
      3.8 ms  ┌ Compat
      1.5 ms    ┌ Adapt
      3.7 ms    ┌ ArrayInterfaceCore
      2.0 ms    ┌ StaticArraysCore
      9.7 ms  ┌ ArrayInterfaceStaticArraysCore
    123.1 ms  ┌ FillArrays
      5.0 ms  ┌ DocStringExtensions
     18.2 ms  ┌ RecipesBase
     51.3 ms  ┌ ChainRulesCore
      4.0 ms  ┌ GPUArraysCore
    292.7 ms  RecursiveArrayTools
```

Compare that to #213
  • Loading branch information
ChrisRackauckas committed Jun 25, 2022
1 parent 945a9d1 commit 3624839
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ version = "2.30.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
ArrayInterfaceStaticArraysCore = "dd5226c6-a4d4-4bc7-8575-46859f9c95b9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Adapt = "3"
ArrayInterfaceCore = "0.1.1"
ArrayInterfaceStaticArrays = "0.1"
ArrayInterfaceStaticArraysCore = "0.1"
ChainRulesCore = "0.10.7, 1"
DocStringExtensions = "0.8, 0.9"
FillArrays = "0.11, 0.12, 0.13"
GPUArraysCore = "0.1"
RecipesBase = "0.7, 0.8, 1.0"
StaticArrays = "0.12, 1.0"
StaticArraysCore = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand All @@ -36,10 +36,11 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"]
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
4 changes: 2 additions & 2 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ $(DocStringExtensions.README)
module RecursiveArrayTools

using DocStringExtensions
using RecipesBase, StaticArrays, Statistics,
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterfaceCore, LinearAlgebra

import ChainRulesCore
Expand All @@ -15,7 +15,7 @@ import ZygoteRules, Adapt
# Required for the downstream_events.jl test
# Since `ismutable` on an ArrayPartition needs
# to know static arrays are not mutable
import ArrayInterfaceStaticArrays
import ArrayInterfaceStaticArraysCore

using FillArrays

Expand Down
15 changes: 8 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ like `copy` on arrays of scalars.
function recursivecopy(a)
deepcopy(a)
end
recursivecopy(a::Union{SVector,SMatrix,SArray,Number}) = copy(a)
recursivecopy(a::Union{StaticArraysCore.SVector,StaticArraysCore.SMatrix,
StaticArraysCore.SArray,Number}) = copy(a)
function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N}
copy(a)
end
Expand All @@ -33,7 +34,7 @@ like `copy!` on arrays of scalars.
"""
function recursivecopy! end

function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArray,T2<:StaticArray,N}
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
Expand Down Expand Up @@ -68,13 +69,13 @@ A recursive `fill!` function.
"""
function recursivefill! end

function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArray,T2<:StaticArray,N}
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
@inbounds for i in eachindex(b)
b[i] = copy(a)
end
end

function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:SArray,T2<:Union{Number,Bool},N}
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.SArray,T2<:Union{Number,Bool},N}
@inbounds for i in eachindex(b)
b[i] = fill(a, typeof(b[i]))
end
Expand All @@ -88,7 +89,7 @@ function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:Union{Number,Bool
fill!(b, a)
end

function recursivefill!(b::AbstractArray{T,N},a) where {T<:MArray,N}
function recursivefill!(b::AbstractArray{T,N},a) where {T<:StaticArraysCore.MArray,N}
@inbounds for i in eachindex(b)
if isassigned(b,i)
recursivefill!(b[i],a)
Expand Down Expand Up @@ -151,7 +152,7 @@ If `i<length(x)`, it's simply a `recursivecopy!` to the `i`th element. Otherwise
function copyat_or_push!(a::AbstractVector{T},i::Int,x,nc::Type{Val{perform_copy}}=Val{true}) where {T,perform_copy}
@inbounds if length(a) >= i
if !ArrayInterfaceCore.ismutable(T) || !perform_copy
# TODO: Check for `setindex!`` if T <: StaticArray and use `copy!(b[i],a[i])`
# TODO: Check for `setindex!`` if T <: StaticArraysCore.StaticArray and use `copy!(b[i],a[i])`
# or `b[i] = a[i]`, see https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19
a[i] = x
else
Expand Down Expand Up @@ -208,7 +209,7 @@ ones has a `Array{Array{Float64,N},N}`, this will return `Array{Float64,N}`.
"""
recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a))
recursive_unitless_eltype(a::Type{Any}) = Any
recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))
recursive_unitless_eltype(a::Type{T}) where {T<:StaticArraysCore.StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))
recursive_unitless_eltype(a::Type{T}) where {T<:Array} = Array{recursive_unitless_eltype(eltype(a)),ndims(a)}
recursive_unitless_eltype(a::Type{T}) where {T<:Number} = typeof(one(eltype(a)))
recursive_unitless_eltype(::Type{<:Enum{T}}) where T = T
Expand Down

0 comments on commit 3624839

Please sign in to comment.