Skip to content

Commit

Permalink
Merge pull request #1469 from SciML/inference3
Browse files Browse the repository at this point in the history
More compile time fixes
  • Loading branch information
ChrisRackauckas authored Aug 12, 2021
2 parents 2dc3bd7 + f329b8e commit b03c75d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
16 changes: 14 additions & 2 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,22 @@ function ode_interpolant(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idx
# typeof(y₀) can be these if saveidxs gives a single value
_ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
elseif typeof(idxs) <: Nothing
out = oneunit(Θ) .* y₁
if y₁ isa Array{<:Number}
out = similar(y₁,eltype(first(y₁)*oneunit(Θ)))
copyto!(out,y₁)
else
out = oneunit(Θ) .* y₁
end
_ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
else
out = oneunit(Θ) .* y₁[idxs]
if y₁ isa Array{<:Number}
out = similar(y₁,eltype(first(y₁)*oneunit(Θ)),axes(idxs))
for i in eachindex(idxs)
out[i] = y₁[idxs[i]]
end
else
out = oneunit(Θ) .* y₁[idxs]
end
_ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
end
end
Expand Down
10 changes: 7 additions & 3 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ function calc_tderivative!(integrator::ODEIntegrator{algType,IIP,<:Array}, cache
f.tgrad(dT, uprev, p, t)
else
tf.uprev = uprev
tf.p = p
if !(p isa DiffEqBase.NullParameters)
tf.p = p
end
derivative!(dT, tf, t, du2, integrator, cache.grad_config)
end
end
Expand Down Expand Up @@ -133,7 +135,9 @@ function calc_J!(J, integrator, cache)

uf.f = nlsolve_f(f, alg)
uf.t = t
uf.p = p
if !(p isa DiffEqBase.NullParameters)
uf.p = p
end

jacobian!(J, uf, uprev, du1, integrator, jac_config)
end
Expand Down Expand Up @@ -594,7 +598,7 @@ function update_W!(nlsolver::AbstractNLSolver, integrator, cache, dtgamma, repea
nothing
end

function build_J_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits,::Val{IIP}) where IIP
function build_J_W(alg,u,uprev,p,t,dt,f::F,uEltypeNoUnits,::Val{IIP}) where {IIP,F}
islin, isode = islinearfunction(f, alg)
if f.jac_prototype isa DiffEqBase.AbstractDiffEqLinearOperator
W = WOperator{IIP}(f, u, dt)
Expand Down
9 changes: 7 additions & 2 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,16 @@ function resize_grad_config!(grad_config::FiniteDiff.GradientCache, i)
grad_config
end

function build_grad_config(alg,f,tf::F,du1,t) where {F}
function build_grad_config(alg,f::F1,tf::F2,du1,t) where {F1,F2}
if !DiffEqBase.has_tgrad(f)
if alg_autodiff(alg)
dualt = Dual{typeof(ForwardDiff.Tag(tf,eltype(t)))}(t, t)
grad_config = ArrayInterface.restructure(du1,du1 .* dualt)
if du1 isa Array
grad_config = similar(du1,eltype(first(du1)*dualt))
fill!(grad_config,false)
else
grad_config = ArrayInterface.restructure(du1,du1 .* dualt)
end
else
grad_config = FiniteDiff.GradientCache(du1,t,alg_difftype(alg))
end
Expand Down
17 changes: 15 additions & 2 deletions src/initdt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
@.. sk = abstol+internalnorm(u0,t)*reltol
end
else
sk = @.. abstol+internalnorm(u0,t)*reltol
if u0 isa Array && abstol isa Number && reltol isa Number
sk = similar(u0,typeof(internalnorm(first(u0),t)*reltol))
@inbounds @simd ivdep for i in eachindex(u0)
sk[i] = abstol+internalnorm(u0[i],t)*reltol
end
else
sk = @.. abstol+internalnorm(u0,t)*reltol
end
end

if get_current_isfsal(integrator.alg, integrator.cache) && typeof(integrator) <: ODEIntegrator
Expand All @@ -32,7 +39,13 @@
f(f₀,u0,p,t)
else
# TODO: use more caches
f₀ = zero.(u0./t)
if u0 isa Array
T = eltype(first(u0)/t)
f₀ = similar(u0,T)
fill!(f₀,zero(T))
else
f₀ = zero.(u0./t)
end
f(f₀,u0,p,t)
end

Expand Down

0 comments on commit b03c75d

Please sign in to comment.