Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce compile times by unrolling fastbroadcast on array #1546

Merged
merged 4 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 72 additions & 28 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ end
[ode_interpolant!(val,ϕ,integrator,idxs,deriv) for ϕ in Θ]
end

# Evaluate the in-place interpolant at every time point stored in the array `t`.
# Each time is first rescaled to Θ = (t - integrator.tprev) / integrator.dt in a
# plain loop; `ode_interpolant!` is then called once per Θ and the results are
# collected into an array.
@inline function current_interpolant!(val,t::Array,integrator::DiffEqBase.DEIntegrator,idxs,deriv)
  thetas = similar(t)
  @inbounds @simd ivdep for n in eachindex(t)
    thetas[n] = (t[n]-integrator.tprev)/integrator.dt
  end
  return map(θ -> ode_interpolant!(val,θ,integrator,idxs,deriv), thetas)
end

@inline function current_extrapolant(t::Number,integrator::DiffEqBase.DEIntegrator,idxs=nothing,deriv=Val{0})
Θ = (t-integrator.tprev)/(integrator.t-integrator.tprev)
ode_extrapolant(Θ,integrator,idxs,deriv)
Expand Down Expand Up @@ -475,25 +483,38 @@ end
@inbounds @.. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
end

# Out-of-place Hermite interpolation over all components (`idxs === nothing`,
# `T == Val{0}`), specialised for `Array` states.
#
# Arguments as used below:
#   Θ        - interpolation parameter (scalar)
#   dt       - step size multiplying the slope terms
#   y₀, y₁   - states at the two ends of the step, indexed element-wise
#   k        - container whose first two entries `k[1]`, `k[2]` are the slopes
#
# Returns a freshly allocated array shaped like `y₀` holding the interpolated
# values. The loop is written out by hand instead of using a broadcast.
@muladd function hermite_interpolant(Θ,dt,y₀::Array,y₁,k,::Type{Val{true}},idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
  out = similar(y₀)
  @inbounds @simd ivdep for i in eachindex(y₀)
    out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
  end
  # A `for` loop evaluates to `nothing`, so the filled buffer must be returned
  # explicitly; otherwise callers receive `nothing` instead of the result.
  return out
end

# Out-of-place Hermite interpolation (`T == Val{0}`) restricted to the
# components selected by `idxs`.
# Every operand (`y₀`, `y₁`, `k[1]`, `k[2]`) is indexed with `idxs` before the
# arithmetic is applied. The active line is the plain (non-`@..`) form; the
# broadcast variant is kept commented out directly above it.
@muladd function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
# return @.. (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
end

# In-place Hermite interpolation (`T == Val{0}`) over all components
# (`idxs === nothing`) for a generic `out`.
# The result is written into `out` with a single `@..` broadcast assignment
# combining `y₀`, `y₁` and the slopes `k[1]`, `k[2]`.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
@inbounds @.. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
# Element-wise loop form of the broadcast above, kept for reference only:
#@inbounds for i in eachindex(out)
# out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

# In-place Hermite interpolation (`T == Val{0}`) over all components
# (`idxs === nothing`), specialised for `Array` outputs.
# The formula is evaluated element by element in an explicit SIMD loop rather
# than through a broadcast; `out` is filled and returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{0}}) # Default interpolant is Hermite
  @inbounds @simd ivdep for n in eachindex(out)
    out[n] = (1-Θ)*y₀[n]+Θ*y₁[n]+Θ*(Θ-1)*((1-2Θ)*(y₁[n]-y₀[n])+(Θ-1)*dt*k[1][n] + Θ*dt*k[2][n])
  end
  return out
end

# In-place Hermite interpolation (`T == Val{0}`) restricted to the components
# selected by `idxs`, for a generic `out`.
# `@views` makes the `[idxs]` indexing expressions views, and the `@..`
# broadcast assignment writes the result into `out`.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
@views @.. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
# Element-wise loop form of the broadcast above, kept for reference only
# (`j` is the position in `out`, `i` the source index taken from `idxs`):
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*(Θ-1)*((1-2Θ)*(y₁[i]-y₀[i])+(Θ-1)*dt*k[1][i] + Θ*dt*k[2][i])
#end
#out
end

# In-place Hermite interpolation (`T == Val{0}`) restricted to the components
# listed in `idxs`, specialised for `Array` outputs.
# Position `dst` of `out` receives the value computed from source index `src`,
# where `src` is the `dst`-th entry of `idxs`. `out` is returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
  @inbounds for (dst,src) in enumerate(idxs)
    out[dst] = (1-Θ)*y₀[src]+Θ*y₁[src]+Θ*(Θ-1)*((1-2Θ)*(y₁[src]-y₀[src])+(Θ-1)*dt*k[1][src] + Θ*dt*k[2][src])
  end
  return out
end

"""
Expand All @@ -515,18 +536,24 @@ end

# In-place Hermite interpolant for `T == Val{1}` over all components
# (`idxs === nothing`), for a generic `out`.
# Presumably the first derivative of the interpolant; confirm against the
# docstring that precedes this group of methods.
# The result is written into `out` with a single `@..` broadcast assignment.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
@inbounds @.. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
# Element-wise loop form of the broadcast above, kept for reference only:
#@inbounds for i in eachindex(out)
# out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

# In-place Hermite interpolant for `T == Val{1}` over all components
# (`idxs === nothing`), specialised for `Array` outputs.
# The formula is evaluated element by element in an explicit SIMD loop and
# `out` is returned.
# NOTE(review): the non-code lines between the loop body and its `end` are a
# pasted review comment, not valid Julia, and must be removed before this
# compiles. The comment suggests computing `inv(dt)` and multiplying by it
# instead of dividing by `dt` on every iteration.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{1}}) # Default interpolant is Hermite
@inbounds @simd ivdep for i in eachindex(out)
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps inv(dt) and then multiply by that?

In theory, @fastmath would do that for you, if it weren't for Julia's ordering the LLVM passes in a way that moves the hoisted division back into the loop.

end
out
end

# In-place Hermite interpolant for `T == Val{1}` restricted to the components
# selected by `idxs`, for a generic `out`.
# `@views` makes the `[idxs]` indexing expressions views, and the `@..`
# broadcast assignment writes the result into `out`.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
@views @.. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
# Element-wise loop form of the broadcast above, kept for reference only
# (`j` is the position in `out`, `i` the source index taken from `idxs`):
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
#end
#out
end

# In-place Hermite interpolant for `T == Val{1}` restricted to the components
# listed in `idxs`, specialised for `Array` outputs.
# Position `dst` of `out` receives the value computed from source index `src`,
# where `src` is the `dst`-th entry of `idxs`. `out` is returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
  @inbounds for (dst,src) in enumerate(idxs)
    out[dst] = k[1][src] + Θ*(-4*dt*k[1][src] - 2*dt*k[2][src] - 6*y₀[src] + Θ*(3*dt*k[1][src] + 3*dt*k[2][src] + 6*y₀[src] - 6*y₁[src]) + 6*y₁[src])/dt
  end
  return out
end

"""
Expand All @@ -551,18 +578,25 @@ end

# In-place Hermite interpolant for `T == Val{2}` over all components
# (`idxs === nothing`), for a generic `out`.
# Presumably the second derivative of the interpolant (note the division by
# `dt*dt`); confirm against the docstring that precedes this group of methods.
# The result is written into `out` with a single `@..` broadcast assignment.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
@inbounds @.. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
# Element-wise loop form of the broadcast above, kept for reference only:
#@inbounds for i in eachindex(out)
# out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

# In-place Hermite interpolant for `T == Val{2}` over all components
# (`idxs === nothing`), specialised for `Array` outputs.
# The formula is evaluated element by element in an explicit SIMD loop rather
# than through a broadcast; `out` is filled and returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{2}}) # Default interpolant is Hermite
  @inbounds @simd ivdep for n in eachindex(out)
    out[n] = (-4*dt*k[1][n] - 2*dt*k[2][n] - 6*y₀[n] + Θ*(6*dt*k[1][n] + 6*dt*k[2][n] + 12*y₀[n] - 12*y₁[n]) + 6*y₁[n])/(dt*dt)
  end
  return out
end

# In-place Hermite interpolant for `T == Val{2}` restricted to the components
# selected by `idxs`, for a generic `out`.
# `@views` makes the `[idxs]` indexing expressions views, and the `@..`
# broadcast assignment writes the result into `out`.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
@views @.. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
# Element-wise loop form of the broadcast above, kept for reference only
# (`j` is the position in `out`, `i` the source index taken from `idxs`):
#@inbounds for (j,i) in enumerate(idxs)
# out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
#end
#out
end

# In-place Hermite interpolant for `T == Val{2}` restricted to the components
# listed in `idxs`, specialised for `Array` outputs.
#
# Arguments as used below:
#   out      - destination array; entry `j` receives the value for `idxs[j]`
#   Θ        - interpolation parameter (scalar)
#   dt       - step size; the result is divided by `dt*dt`
#   y₀, y₁   - states at the two ends of the step
#   k        - container whose first two entries `k[1]`, `k[2]` are the slopes
#   idxs     - iterable of source indices
#
# Returns `out`.
#
# Only the explicit loop is used here. A broadcast that computed the same
# values immediately before the loop was redundant, because the loop
# overwrites every entry, and the sibling `Array` methods for the other
# derivative orders have no such broadcast.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
  @inbounds for (j,i) in enumerate(idxs)
    out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
  end
  return out
end

"""
Expand Down Expand Up @@ -593,12 +627,22 @@ end
#out
end

# In-place Hermite interpolant for `T == Val{3}` over all components
# (`idxs === nothing`), specialised for `Array` outputs.
# The expression does not involve Θ; each entry is built from the slopes and
# end states and divided by `dt*dt*dt`. `out` is filled in an explicit SIMD
# loop and returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs::Nothing,T::Type{Val{3}}) # Default interpolant is Hermite
  @inbounds @simd ivdep for n in eachindex(out)
    out[n] = (6*dt*k[1][n] + 6*dt*k[2][n] + 12*y₀[n] - 12*y₁[n])/(dt*dt*dt)
  end
  return out
end

# In-place Hermite interpolant for `T == Val{3}` restricted to the components
# selected by `idxs`, for a generic `out`.
# The expression does not involve Θ. `@views` makes the `[idxs]` indexing
# expressions views, and the `@..` broadcast assignment writes into `out`.
@muladd function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
@views @.. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
# Element-wise loop form of the broadcast above, kept for reference only
# (`j` is the position in `out`, `i` the source index taken from `idxs`):
#for (j,i) in enumerate(idxs)
# out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
#end
#out
end

# In-place Hermite interpolant for `T == Val{3}` restricted to the components
# listed in `idxs`, specialised for `Array` outputs.
# Position `dst` of `out` receives the value computed from source index `src`,
# where `src` is the `dst`-th entry of `idxs`. `out` is returned.
@muladd function hermite_interpolant!(out::Array,Θ,dt,y₀,y₁,k,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
  @inbounds for (dst,src) in enumerate(idxs)
    out[dst] = (6*dt*k[1][src] + 6*dt*k[2][src] + 12*y₀[src] - 12*y₁[src])/(dt*dt*dt)
  end
  return out
end

######################## Linear Interpolants
Expand Down
40 changes: 39 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::Abs
idxs = diagind(W)
λ = -mass_matrix.λ
if ArrayInterface.fast_scalar_indexing(J) && ArrayInterface.fast_scalar_indexing(W)
for i in 1:size(J,1)
@inbounds for i in 1:size(J,1)
W[i,i] = muladd(λ, invdtgamma, J[i,i])
end
else
Expand All @@ -451,6 +451,44 @@ function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::Abs
return nothing
end

# Fill `W` in-place from the Jacobian `J`, specialised for dense `Matrix`
# arguments so that every update is a plain indexed loop.
#
# Arguments:
#   W            - destination matrix; must have the same axes as `J`
#   mass_matrix  - either a `UniformScaling` (only its `λ` field is read) or a
#                  matrix with the same axes as `W`, indexed linearly
#   dtgamma      - scalar factor
#   J            - Jacobian matrix
#   W_transform  - selects which of the two forms is written:
#                    true  -> W = J - mass_matrix / dtgamma
#                    false -> W = dtgamma * J - mass_matrix
#
# Returns `nothing`. Axis mismatches are reported through `_throwWJerror` /
# `_throwWMerror`.
function jacobian2W!(W::Matrix, mass_matrix::MT, dtgamma::Number, J::Matrix, W_transform::Bool)::Nothing where MT
  # check size and dimension
  iijj = axes(W)
  @boundscheck (iijj == axes(J) && length(iijj) == 2) || _throwWJerror(W, J)
  mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) == axes(W) || _throwWMerror(W, mass_matrix)
  # The outer `@inbounds` covers every loop below; indices come either from
  # `eachindex(W)` or `diagind(W)`, and `J` / `mass_matrix` share W's axes.
  @inbounds if W_transform
    invdtgamma = inv(dtgamma)
    if MT <: UniformScaling
      copyto!(W, J)
      λ = -mass_matrix.λ
      # Walk the linear indices of the diagonal. `1:size(J,1)` with `W[i,i]`
      # would run past the last column when there are more rows than columns.
      for i in diagind(W)
        W[i] = muladd(λ, invdtgamma, J[i])
      end
    else
      for i in eachindex(W)
        W[i] = muladd(-mass_matrix[i], invdtgamma, J[i])
      end
    end
  else
    if MT <: UniformScaling
      idxs = diagind(W)
      @simd ivdep for i in eachindex(W)
        W[i] = dtgamma*J[i]
      end
      λ = -mass_matrix.λ
      for i in idxs
        W[i] = W[i] + λ
      end
    else
      @simd ivdep for i in eachindex(W)
        W[i] = muladd(dtgamma, J[i], -mass_matrix[i])
      end
    end
  end
  return nothing
end

function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
# check size and dimension
mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) == axes(J) || _throwJMerror(J, mass_matrix)
Expand Down