Skip to content

Commit

Permalink
Use extensions for weakdeps (#464)
Browse files Browse the repository at this point in the history
* Use extensions for weakdeps

* LSP added "using Base: get_extension", which is probably a bad idea

* Fix extension syntax

* Fixes

* module

* Fix zygote tests
  • Loading branch information
chriselrod committed Jan 29, 2023
1 parent 21b86bb commit 807675d
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 54 deletions.
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.148"
version = "0.12.149"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"]
SpecialFunctionsExt = "SpecialFunctions"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
79 changes: 68 additions & 11 deletions src/simdfunctionals/vmap_grad_rrule.jl → ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,60 @@
module ForwardDiffExt
import ForwardDiff, ChainRulesCore
using SIMDDualNumbers, LoopVectorization
using LoopVectorization:
AbstractSIMD,
AbstractStridedPointer,
relu,
vmap,
VectorizationBase,
vmapt,
vmapnt,
vmapntt,
MM,
StaticInt,
vadd_nw,
vsub_nsw,
vload,
mask,
vfnmadd_fast,
mul_fast
using VectorizationBase: zero_offsets

import .ChainRulesCore
@generated function init_dual(v::Tuple{Vararg{AbstractSIMD,A}}) where {A}
res = Expr(:tuple)
q = Expr(:block, Expr(:meta, :inline))
for a 1:A
v_a = Symbol(:v_, a)
push!(q.args, Expr(:(=), v_a, Expr(:ref, :v, a)))
partials = Expr(:tuple)
for i 1:A
push!(partials.args, Expr(:call, i == a ? :one : :zero, v_a))
end
push!(res.args, :(ForwardDiff.Dual($v_a, ForwardDiff.Partials($partials))))
end
push!(q.args, res)
q
end
@generated function dual_store!(
∂p::Tuple{Vararg{AbstractStridedPointer,A}},
p::AbstractStridedPointer,
∂v,
im::Vararg{Any,N}
) where {A,N}
quote
$(Expr(:meta, :inline))
v = ∂v.value
= ∂v.partials
Base.Cartesian.@nextract $N im im
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! p v im # store
Base.Cartesian.@nexprs $A a -> begin # for each of `A` partials
∂p_a = ∂p[a]
∂_a = ∂[a]
Base.Cartesian.@ncall $N VectorizationBase.vnoaliasstore! ∂p_a ∂_a im # store
end
nothing
end
end

if isdefined(ChainRulesCore, :ZeroTangent)
const ChainRulesZero = ChainRulesCore.ZeroTangent
Expand Down Expand Up @@ -38,32 +93,33 @@ function ∂vmap_singlethread!(
args::Vararg{DenseArray{<:Base.HWReal},A}
) where {F,T<:Base.HWReal,A}
N = length(y)
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
ptr∂y = VectorizationBase.zero_offsets.(stridedpointer.(∂y))

ptry = zero_offsets(stridedpointer(y))
ptrargs = map(zero_offsets, map(stridedpointer, args))
ptr∂y = map(zero_offsets, map(stridedpointer, ∂y))
i = 0
V = VectorizationBase.pick_vector_width(T)
W = Int(V)
st = VectorizationBase.static_sizeof(T)
zero_index = MM{W}(StaticInt(0), st)
while i < vsub_nsw(N, ((W << 2) - 1))
index = VectorizationBase.Unroll{1,W,4,1,W,zero(UInt)}((i,))
v = f(init_dual(vload.(ptrargs, index))...)
v = f(init_dual(map(Base.Fix2(vload, index), ptrargs))...)
dual_store!(ptr∂y, ptry, v, index)
i = vadd_nw(i, 4W)
end
while i < vsub_nsw(N, (W - 1))
vᵣ = f(init_dual(vload.(ptrargs, ((MM{W}(i),),)))...)
loader = Base.Fix2(vload, (MM{W}(i),))
vᵣ = f(init_dual(map(loader, ptrargs))...)
dual_store!(ptr∂y, ptry, vᵣ, (MM{W}(i),))
i = vadd_nw(i, W)
end
if i < N
m = mask(T, N & (W - 1))
mloader = let i = i, m = m
p -> vload(p, (MM{W}(i),), m)
end
dual_store!(
ptr∂y,
ptry,
f(init_dual(vload.(ptrargs, ((MM{W}(i),),), m))...),
f(init_dual(map(mloader, ptrargs))...),
(MM{W}(i),),
m
)
Expand Down Expand Up @@ -109,6 +165,7 @@ for f in (:vmapt, :vmapnt, :vmapntt)
f::F,
args::Vararg{Any,K}
) where {F,K}
ChainRulesCore.rrule(typeof(vmap), f, args...)
ChainRulesCore.rrule(typeof($vmap), f, args...)
end
end
end
6 changes: 6 additions & 0 deletions ext/SpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module SpecialFunctionsExt
using SpecialFunctions
using LoopVectorization: VectorizationBase
using LoopVectorization: AbstractSIMD
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
end
8 changes: 4 additions & 4 deletions src/LoopVectorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ include("precompile.jl")

# import ChainRulesCore, ForwardDiff
# include("vmap_grad.jl")
using ChainRulesCore, ForwardDiff, SpecialFunctions
include("simdfunctionals/vmap_grad_rrule.jl")
include("simdfunctionals/vmap_grad_forwarddiff.jl")
@inline SpecialFunctions.erf(x::AbstractSIMD) = VectorizationBase.verf(float(x))
if !isdefined(Base, :get_extension)
include("../ext/ForwardDiffExt.jl")
include("../ext/SpecialFunctionsExt.jl")
end

end # module
38 changes: 0 additions & 38 deletions src/simdfunctionals/vmap_grad_forwarddiff.jl

This file was deleted.

2 comments on commit 807675d

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/76621

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.149 -m "<description of version>" 807675dc8ed40783891df1af790527f70102cfbb
git push origin v0.12.149

Please sign in to comment.