Chunk size not inferred in forwarddiff handling with static arrays #1533

Closed
ChrisRackauckas opened this issue Dec 7, 2021 · 1 comment

using OrdinaryDiffEq, StaticArrays, BenchmarkTools
function rober2(u,p,t)
  y₁,y₂,y₃ = u
  k₁,k₂,k₃ = p
  du1 = -k₁*y₁+k₃*y₂*y₃
  du2 =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
  du3 =  k₂*y₂^2
  SA[du1,du2,du3]
end
prob2 = ODEProblem{false}(rober2,SA[1.0,0.0,0.0],(0.0,1e5),SA[0.04,3e7,1e4])
@btime sol = solve($prob2,Rosenbrock23(),save_everystep=false)

This falls apart in SparseDiffTools, in forwarddiff_color_jacobian:

@inline function forwarddiff_color_jacobian(f,
                x::AbstractArray{<:Number};
                colorvec = 1:length(x),
                sparsity = nothing,
                jac_prototype = nothing,
                chunksize = nothing,
                dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy(x)) #if dx is nothing, we will estimate dx at the cost of a function call

    @show typeof(chunksize)
    if sparsity === nothing && jac_prototype === nothing
        cfg = if chunksize === nothing
            if typeof(x) <: StaticArrays.StaticArray
                ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk{StaticArrays.Size(vec(x))[1]}())
            else
                ForwardDiff.JacobianConfig(f, x)
            end
        else
            ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
        end
        return ForwardDiff.jacobian(f, x, cfg)
    end
    if dx isa Nothing
        dx = f(x)
    end
    return forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
end

because chunksize is now an Int. @chriselrod, did you make this Val-based?
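To illustrate why a chunk size carried as an Int can hurt inference, here is a minimal sketch that deliberately avoids ForwardDiff and SparseDiffTools. `MyChunk`, `chunk_from_int`, and `chunk_from_val` are hypothetical names invented for this example; `MyChunk{N}` only plays the role of a chunk type parameterized by its size.

```julia
# Hypothetical stand-in for a chunk type parameterized by its size.
struct MyChunk{N} end

# The size arrives as a runtime Int, so the type parameter depends on a value.
chunk_from_int(n::Int) = MyChunk{n}()

# The size arrives in the type domain via Val, so the compiler can see it.
chunk_from_val(::Val{N}) where {N} = MyChunk{N}()

# Ask inference what each constructor returns, given only the argument types.
int_ret = first(Base.return_types(chunk_from_int, (Int,)))
val_ret = first(Base.return_types(chunk_from_val, (Val{3},)))

println("Int-based return type is concrete: ", isconcretetype(int_ret))
println("Val-based return type: ", val_ret)
```

The Int-based version should infer to a non-concrete type, while the Val-based one infers to `MyChunk{3}`. Anything built from the non-concrete result, such as a Jacobian configuration, would then have to be dispatched dynamically.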

function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{0,AD,FDT},
                        OrdinaryDiffEqImplicitAlgorithm{0,AD,FDT},
                        DAEAlgorithm{0,AD,FDT}},u0::AbstractArray,p,prob) where {AD,FDT}
    # If chunksize is zero, pick chunksize right at the start of solve and
    # then do function barrier to infer the full solve
    x = if prob.f.colorvec === nothing
      length(u0)
    else
      maximum(prob.f.colorvec)
    end

    if typeof(alg) <: OrdinaryDiffEqImplicitExtrapolationAlgorithm
      return alg # remake fails, should get fixed
    else
      remake(alg,chunk_size=Val{ForwardDiff.pickchunksize(x)}())
    end
end

fails
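The comment in prepare_alg describes picking the chunk size once and then using a function barrier. A minimal sketch of that pattern, under stated assumptions, looks like this. `pick_size`, `kernel`, and `outer` are hypothetical names, and the cap of 12 is an arbitrary choice for the example, not a claim about what ForwardDiff.pickchunksize does.

```julia
# Hypothetical stand-in for a chunk-size heuristic; the cap is arbitrary.
pick_size(n::Int) = min(n, 12)

# Inner kernel: N is part of the argument type, so the tuple length is
# known to the compiler when this method is specialized.
kernel(::Val{N}, u) where {N} = sum(ntuple(i -> u[i]^2, Val(N)))

# Outer function: Val(n) is built from a runtime value, so one dynamic
# dispatch happens here, at the barrier, and nowhere inside kernel.
function outer(u)
    n = pick_size(length(u))
    return kernel(Val(n), u)
end

result = outer([1.0, 2.0, 3.0])
println(result)
```

The design point is that the cost of the unknown size is paid once at the call to `kernel`. If the value is converted back to an Int further down the call chain, that benefit is lost again.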

chriselrod commented Dec 7, 2021

Regarding "fails": I tried it locally, and it works.
Before:

julia> @btime sol = solve($prob2,Rosenbrock23(),save_everystep=false)
  43.962 μs (509 allocations: 34.34 KiB)
retcode: Success
Interpolation: 1st order linear
t: 2-element Vector{Float64}:
      0.0
 100000.0
u: 2-element Vector{SVector{3, Float64}}:
 [1.0, 0.0, 0.0]
 [0.017827893848396476, 7.258919980845545e-8, 0.9821720335623989]

After:

julia> @btime sol = solve($prob2,Rosenbrock23(),save_everystep=false)
  43.883 μs (505 allocations: 33.75 KiB)
retcode: Success
Interpolation: 1st order linear
t: 2-element Vector{Float64}:
      0.0
 100000.0
u: 2-element Vector{SVector{3, Float64}}:
 [1.0, 0.0, 0.0]
 [0.017827893848396476, 7.258919980845545e-8, 0.9821720335623989]
