diff --git a/Project.toml b/Project.toml index d46130e3..b430a5fa 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] @@ -24,6 +25,7 @@ SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions" SparseConnectivityTracerNNlibExt = "NNlib" SparseConnectivityTracerNaNMathExt = "NaNMath" SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" +SparseConnectivityTracerRecursiveArrayToolsExt = "RecursiveArrayTools" [compat] ADTypes = "1" @@ -35,6 +37,7 @@ LogExpFunctions = "0.3.28" NNlib = "0.8, 0.9" NaNMath = "1" Random = "<0.0.1, 1" +RecursiveArrayTools = "3.31.2" SparseArrays = "<0.0.1, 1" SpecialFunctions = "2.4" julia = "1.10" diff --git a/ext/SparseConnectivityTracerRecursiveArrayToolsExt.jl b/ext/SparseConnectivityTracerRecursiveArrayToolsExt.jl new file mode 100644 index 00000000..b11200d3 --- /dev/null +++ b/ext/SparseConnectivityTracerRecursiveArrayToolsExt.jl @@ -0,0 +1,29 @@ +module SparseConnectivityTracerRecursiveArrayToolsExt + +import SparseConnectivityTracer as SCT +using RecursiveArrayTools: ArrayPartition, NamedArrayPartition + +function SCT.trace_input( + ::Type{T}, xs::ArrayPartition, i +) where {T<:Union{SCT.AbstractTracer,SCT.Dual}} + ts = SCT.create_tracers(T, xs, eachindex(xs)) + lengths = map(length, xs.x) + length_sums = (0, cumsum(lengths)...) + return ArrayPartition( + Tuple( + reshape(view(ts, (1 + length_sums[j]):(length_sums[j + 1])), size(xs.x[j])) for + j in eachindex(xs.x) + ), + ) +end + +function SCT.trace_input( + ::Type{T}, xs::NamedArrayPartition, i +) where {T<:Union{SCT.AbstractTracer,SCT.Dual}} + return NamedArrayPartition( + SCT.trace_input(T, getfield(xs, :array_partition), i), + getfield(xs, :names_to_indices), + ) +end + +end # module SparseConnectivityTracerRecursiveArrayToolsExt \ No newline at end of file diff --git a/test/recursive_array_tools.jl b/test/recursive_array_tools.jl new file mode 100644 index 00000000..575f67a5 --- /dev/null +++ b/test/recursive_array_tools.jl @@ -0,0 +1,12 @@ +using RecursiveArrayTools + +function f!(du, u) + du.foo[2] = u.foo[1] + du.foo[1] = u.foo[2] + return nothing +end + +u = NamedArrayPartition(; foo=rand(5)) +du = copy(u) + +jacobian_sparsity(f!, du, u, TracerSparsityDetector()) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 981d49e8..c1d20075 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -99,6 +99,9 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") include("flux.jl") end end + @testset "RecursiveArrayTools.jl" begin + include("recursive_array_tools.jl") + end end end