diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 94a3a09419..520c3fcc38 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -413,10 +413,22 @@ function ode_interpolant(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idx # typeof(y₀) can be these if saveidxs gives a single value _ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T) elseif typeof(idxs) <: Nothing - out = oneunit(Θ) .* y₁ + if y₁ isa Array{<:Number} + out = similar(y₁,eltype(first(y₁)*oneunit(Θ))) + copyto!(out,y₁) + else + out = oneunit(Θ) .* y₁ + end _ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T) else - out = oneunit(Θ) .* y₁[idxs] + if y₁ isa Array{<:Number} + out = similar(y₁,eltype(first(y₁)*oneunit(Θ)),axes(idxs)) + for i in eachindex(idxs) + out[i] = y₁[idxs[i]] + end + else + out = oneunit(Θ) .* y₁[idxs] + end _ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T) end end diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index a7912a91e4..c790678044 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -31,7 +31,9 @@ function calc_tderivative!(integrator::ODEIntegrator{algType,IIP,<:Array}, cache f.tgrad(dT, uprev, p, t) else tf.uprev = uprev - tf.p = p + if !(p isa DiffEqBase.NullParameters) + tf.p = p + end derivative!(dT, tf, t, du2, integrator, cache.grad_config) end end @@ -133,7 +135,9 @@ function calc_J!(J, integrator, cache) uf.f = nlsolve_f(f, alg) uf.t = t - uf.p = p + if !(p isa DiffEqBase.NullParameters) + uf.p = p + end jacobian!(J, uf, uprev, du1, integrator, jac_config) end @@ -594,7 +598,7 @@ function update_W!(nlsolver::AbstractNLSolver, integrator, cache, dtgamma, repea nothing end -function build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{IIP}) where IIP +function build_J_W(alg,u,uprev,p,t,dt,f::F,uEltypeNoUnits,::Val{IIP}) where {IIP,F} islin, isode = islinearfunction(f, alg) if f.jac_prototype isa DiffEqBase.AbstractDiffEqLinearOperator W = WOperator{IIP}(f, u, dt) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index a84a612efb..a3173465c2 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -169,11 +169,16 @@ function resize_grad_config!(grad_config::FiniteDiff.GradientCache, i) grad_config end -function build_grad_config(alg,f,tf::F,du1,t) where {F} +function build_grad_config(alg,f::F1,tf::F2,du1,t) where {F1,F2} if !DiffEqBase.has_tgrad(f) if alg_autodiff(alg) dualt = Dual{typeof(ForwardDiff.Tag(tf,eltype(t)))}(t, t) - grad_config = ArrayInterface.restructure(du1,du1 .* dualt) + if du1 isa Array + grad_config = similar(du1,eltype(first(du1)*dualt)) + fill!(grad_config,false) + else + grad_config = ArrayInterface.restructure(du1,du1 .* dualt) + end else grad_config = FiniteDiff.GradientCache(du1,t,alg_difftype(alg)) end diff --git a/src/initdt.jl b/src/initdt.jl index 81f165c8b9..40f6025c3f 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -23,7 +23,14 @@ @.. sk = abstol+internalnorm(u0,t)*reltol end else - sk = @.. abstol+internalnorm(u0,t)*reltol + if u0 isa Array && abstol isa Number && reltol isa Number + sk = similar(u0,typeof(internalnorm(first(u0),t)*reltol)) + @inbounds @simd ivdep for i in eachindex(u0) + sk[i] = abstol+internalnorm(u0[i],t)*reltol + end + else + sk = @.. abstol+internalnorm(u0,t)*reltol + end end if get_current_isfsal(integrator.alg, integrator.cache) && typeof(integrator) <: ODEIntegrator @@ -32,7 +39,13 @@ f(f₀,u0,p,t) else # TODO: use more caches - f₀ = zero.(u0./t) + if u0 isa Array + T = eltype(first(u0)/t) + f₀ = similar(u0,T) + fill!(f₀,zero(T)) + else + f₀ = zero.(u0./t) + end f(f₀,u0,p,t) end