From bb07a0499307e4bc94dab680a14035ee8c636537 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Apr 2024 11:55:37 +0530 Subject: [PATCH] refactor: add DiffEqArray constructor --- Project.toml | 3 ++- src/LabelledArrays.jl | 1 + src/diffeqarray.jl | 7 +++++++ test/recursivearraytools.jl | 6 ++++++ test/runtests.jl | 32 +++++++++++++++++++++----------- 5 files changed, 37 insertions(+), 12 deletions(-) create mode 100644 src/diffeqarray.jl create mode 100644 test/recursivearraytools.jl diff --git a/Project.toml b/Project.toml index fe186e4..3c47fb0 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ ChainRulesCore = "1" ForwardDiff = "0.10.3" MacroTools = "0.5" PreallocationTools = "0.4" -RecursiveArrayTools = "2,3" +RecursiveArrayTools = "3" StaticArrays = "1.0" julia = "1.6" @@ -27,6 +27,7 @@ julia = "1.6" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/src/LabelledArrays.jl b/src/LabelledArrays.jl index 98fdc58..00d1de2 100644 --- a/src/LabelledArrays.jl +++ b/src/LabelledArrays.jl @@ -6,6 +6,7 @@ import RecursiveArrayTools, PreallocationTools, ForwardDiff include("slarray.jl") include("larray.jl") include("chainrules.jl") +include("diffeqarray.jl") # Common @generated function __getindex(x::Union{LArray, SLArray}, ::Val{s}) where {s} diff --git a/src/diffeqarray.jl b/src/diffeqarray.jl new file mode 100644 index 0000000..11f59c6 --- /dev/null +++ b/src/diffeqarray.jl @@ -0,0 +1,7 @@ +for LArrayType in [LArray, SLArray] + @eval function RecursiveArrayTools.DiffEqArray(vec::AbstractVector{<:$LArrayType}, + ts::AbstractVector, + p = nothing) + RecursiveArrayTools.DiffEqArray(vec, ts, p; variables = collect(symbols(vec[1]))) + end +end diff --git a/test/recursivearraytools.jl b/test/recursivearraytools.jl new file mode 100644 index 0000000..3193d55 --- /dev/null +++ b/test/recursivearraytools.jl @@ -0,0 +1,6 @@ +using RecursiveArrayTools, LabelledArrays, Test + +ABC = @SLVector (:a, :b, :c); +A = ABC(1, 2, 3); +B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]); +@test getindex(B, :a) == [1, 1] diff --git a/test/runtests.jl b/test/runtests.jl index abe2662..4d2e0a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,17 +4,27 @@ using StaticArrays using InteractiveUtils using ChainRulesTestUtils -@time begin - @time @testset "SLArrays" begin - include("slarrays.jl") - end - @time @testset "LArrays" begin - include("larrays.jl") - end - @time @testset "DiffEq" begin - include("diffeq.jl") +const GROUP = get(ENV, "GROUP", "All") + +if GROUP == "All" + @time begin + @time @testset "SLArrays" begin + include("slarrays.jl") + end + @time @testset "LArrays" begin + include("larrays.jl") + end + @time @testset "DiffEq" begin + include("diffeq.jl") + end + @time @testset "ChainRules" begin + include("chainrules.jl") + end end - @time @testset "ChainRules" begin - include("chainrules.jl") +end + +if GROUP == "All" || GROUP == "RecursiveArrayTools" + @time @testset "RecursiveArrayTools" begin + include("recursivearraytools.jl") end end