Skip to content

Commit

Permalink
Fix and test leakyrelu (#505)
Browse files Browse the repository at this point in the history
* ignore oftype

* broadcast fixes

* delete problematic line that accidentally wasn't removed

* no print

* fix order

* ifelse unroll

* Test leakyrelu
  • Loading branch information
chriselrod committed Jul 20, 2023
1 parent a21d6f8 commit 73fb725
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.163"
version = "0.12.164"


[deps]
Expand Down
52 changes: 26 additions & 26 deletions benchmark/looptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ function jgemm!(𝐂, 𝐀ᵀ::Adjoint, 𝐁ᵀ::Adjoint)
end
end
gemmavx!(𝐂, 𝐀, 𝐁) = @turbo for m indices((𝐀, 𝐂), 1), n indices((𝐁, 𝐂), 2)
𝐂ₘₙ = zero(eltype(𝐂))
for k indices((𝐀, 𝐁), (2, 1))
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
end
𝐂[m, n] = 𝐂ₘₙ
𝐂ₘₙ = zero(eltype(𝐂))
for k indices((𝐀, 𝐁), (2, 1))
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
end
𝐂[m, n] = 𝐂ₘₙ
end
function gemmavx!(
Cc::AbstractMatrix{Complex{T}},
Ac::AbstractMatrix{Complex{T}},
Expand All @@ -102,12 +102,12 @@ function gemmavx!(
end
end
gemmavxt!(𝐂, 𝐀, 𝐁) = @tturbo for m indices((𝐀, 𝐂), 1), n indices((𝐁, 𝐂), 2)
𝐂ₘₙ = zero(eltype(𝐂))
for k indices((𝐀, 𝐁), (2, 1))
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
end
𝐂[m, n] = 𝐂ₘₙ
𝐂ₘₙ = zero(eltype(𝐂))
for k indices((𝐀, 𝐁), (2, 1))
𝐂ₘₙ += 𝐀[m, k] * 𝐁[k, n]
end
𝐂[m, n] = 𝐂ₘₙ
end
function gemmavxt!(
Cc::AbstractMatrix{Complex{T}},
Ac::AbstractMatrix{Complex{T}},
Expand Down Expand Up @@ -204,11 +204,11 @@ function jdot3avx(x, A, y)
s
end
jvexp!(b, a) = @inbounds for i eachindex(a)
b[i] = exp(a[i])
end
b[i] = exp(a[i])
end
jvexpavx!(b, a) = @turbo for i eachindex(a)
b[i] = exp(a[i])
end
b[i] = exp(a[i])
end
function jsvexp(a)
s = zero(eltype(a))
@inbounds for i eachindex(a)
Expand Down Expand Up @@ -242,12 +242,12 @@ function jgemv!(𝐲, 𝐀ᵀ::Adjoint, 𝐱)
end
end
jgemvavx!(𝐲, 𝐀, 𝐱) = @turbo for i eachindex(𝐲)
𝐲ᵢ = zero(eltype(𝐲))
for j eachindex(𝐱)
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
end
𝐲[i] = 𝐲ᵢ
𝐲ᵢ = zero(eltype(𝐲))
for j eachindex(𝐱)
𝐲ᵢ += 𝐀[i, j] * 𝐱[j]
end
𝐲[i] = 𝐲ᵢ
end
function jvar!(𝐬², 𝐀, x̄)
@.= zero(eltype(𝐬²))
@inbounds @fastmath for i 1:size(𝐀, 2)
Expand All @@ -258,14 +258,14 @@ function jvar!(𝐬², 𝐀, x̄)
end
end
jvaravx!(𝐬², 𝐀, x̄) = @turbo for j eachindex(𝐬²)
𝐬²ⱼ = zero(eltype(𝐬²))
x̄ⱼ = x̄[j]
for i 1:size(𝐀, 2)
δ = 𝐀[j, i] - x̄ⱼ
𝐬²ⱼ += δ * δ
end
𝐬²[j] = 𝐬²ⱼ
𝐬²ⱼ = zero(eltype(𝐬²))
x̄ⱼ = x̄[j]
for i 1:size(𝐀, 2)
δ = 𝐀[j, i] - x̄ⱼ
𝐬²ⱼ += δ * δ
end
𝐬²[j] = 𝐬²ⱼ
end
japlucBc!(D, a, B, c) = @. D = a + B * c';
japlucBcavx!(D, a, B, c) = @turbo @. D = a + B * c';

Expand Down
3 changes: 2 additions & 1 deletion benchmark/plotbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ else
# const COLOR_MAP = Dict{String,RGB{Float64}}()
# const COLOR_MAP = Dict{String,RGB{Colors.N0f8}}()
const COLOR_MAP64 = Dict{String,RGB{Float64}}()
getcolor(s::String) = get!(COLOR_MAP64, s) do
getcolor(s::String) =
get!(COLOR_MAP64, s) do
COLORS[length(COLOR_MAP64)+1]
end
replace_and(str) = replace(str, '&' => "with")
Expand Down
35 changes: 29 additions & 6 deletions ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ end
end
end

@generated function ifelse(
m::AbstractMask,
@generated function _ifelse(
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
x::ForwardDiff.Dual{TAG,V,P},
y::ForwardDiff.Dual{TAG,V,P}
) where {TAG,V,P}
Expand All @@ -171,8 +171,8 @@ end
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end
@generated function ifelse(
m::AbstractMask,
@generated function _ifelse(
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
x::Number,
y::ForwardDiff.Dual{TAG,V,P}
) where {TAG,V,P}
Expand All @@ -184,8 +184,8 @@ end
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end
@generated function ifelse(
m::AbstractMask,
@generated function _ifelse(
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
x::ForwardDiff.Dual{TAG,V,P},
y::Number
) where {TAG,V,P}
Expand All @@ -197,6 +197,29 @@ end
ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
end
end
# `ifelse` entry points for a single mask where at least one branch is a
# `ForwardDiff.Dual`. Each one forwards, argument order unchanged, to the
# shared `_ifelse` implementation.
@inline function ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number)
    return _ifelse(m, x, y)
end
@inline function ifelse(
    m::AbstractMask,
    x::ForwardDiff.Dual,
    y::ForwardDiff.Dual
)
    return _ifelse(m, x, y)
end
@inline function ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual)
    return _ifelse(m, y, x)
end

# `ifelse` entry points for an unrolled mask (`VecUnroll` of `AbstractMask`s)
# where at least one branch is a `ForwardDiff.Dual`. These mirror the
# single-mask methods and forward to the same `_ifelse` implementation.
@inline function ifelse(
    m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
    x::ForwardDiff.Dual,
    y::Number
)
    return _ifelse(m, x, y)
end
@inline function ifelse(
    m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
    x::ForwardDiff.Dual,
    y::ForwardDiff.Dual
)
    return _ifelse(m, x, y)
end
@inline function ifelse(
    m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask},
    y::Number,
    x::ForwardDiff.Dual
)
    return _ifelse(m, y, x)
end

@inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG}
val = ForwardDiff.value(x)
expx = exp(val)
Expand Down
3 changes: 2 additions & 1 deletion src/LoopVectorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ using VectorizationBase:
contract_or,
collapse_or,
max_mask,
maybestaticsize,zero_mask
maybestaticsize,
zero_mask

using HostCPUFeatures:
pick_vector_width,
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/split_loops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ function add_operation!(
opnew
end

append_if_included!(vnew, vold, included) = for (i, v) vold
append_if_included!(vnew, vold, included) =
for (i, v) vold
id = included[i]
iszero(id) || push!(vnew, (id, v))
end
Expand Down
12 changes: 8 additions & 4 deletions src/modeling/costs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ struct Instruction
end
# lower(instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
# Base.convert(::Type{Expr}, instr::Instruction) = Expr(:(.), instr.mod, QuoteNode(instr.instr))
callexpr(instr::Instruction) = if instr.mod === :LoopVectorization
callexpr(instr::Instruction) =
if instr.mod === :LoopVectorization
Expr(:call, lv(instr.instr))
else#if instr.mod === :Main
Expr(:call, instr.instr)
Expand Down Expand Up @@ -563,7 +564,8 @@ function reduction_to_single_vector(x::Float64)
throw("Reduction not found.")
end
end
reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
reduce_to_onevecunroll(x::Float64) =
if x == ADDITIVE_IN_REDUCTIONS
:+
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
:*
Expand All @@ -578,7 +580,8 @@ reduce_to_onevecunroll(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
else
throw("Reduction not found.")
end
reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
reduce_number_of_vectors(x::Float64) =
if x == ADDITIVE_IN_REDUCTIONS
:contract_add
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
:contract_mul
Expand All @@ -593,7 +596,8 @@ reduce_number_of_vectors(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
else
throw("Reduction not found.")
end
reduction_to_scalar(x::Float64) = if x == ADDITIVE_IN_REDUCTIONS
reduction_to_scalar(x::Float64) =
if x == ADDITIVE_IN_REDUCTIONS
:vsum
elseif x == MULTIPLICATIVE_IN_REDUCTIONS
:vprod
Expand Down
2 changes: 1 addition & 1 deletion src/predicates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ isscopedname(:(Base.Checked.checked_add), (:Base, :Checked), :checked_add)
function isscopedname(ex, modpath, name::Symbol)
isexpr(ex, :(.), 2) &&
(a = ex.args[2]; isa(a, QuoteNode) && a.value === name) &&
hasscope(ex.args[1], modpath)
hasscope(ex.args[1], modpath)
end
hasscope(modex, mod::Symbol) = modex === mod
hasscope(modex, mod::Tuple{Symbol}) = hasscope(modex, mod[1])
Expand Down
2 changes: 1 addition & 1 deletion src/reconstruct_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Base.promote_rule(
::Type{UpperBoundedInteger{N,T}},
::Type{T}
) where {N,T<:Base.BitInteger} = T
Base.convert(::Type{T}, i::UpperBoundedInteger) where {T<:Number} =
Base.convert(::Type{T}, i::UpperBoundedInteger) where {T<:Integer} =
convert(T, i.i)
Base.convert(
::Type{UpperBoundedInteger{N,T}},
Expand Down
3 changes: 2 additions & 1 deletion src/simdfunctionals/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ end
Vectorized version of `sum`. Providing a function as the first argument
will apply the function to each element of `A` before summing.
"""
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} = vmapreduce(f, +, A)
@inline vsum(f::F, A::AbstractArray{T}) where {F,T<:NativeTypes} =
vmapreduce(f, +, A)
@inline vsum(A::AbstractArray{T}) where {T<:NativeTypes} = vsum(identity, A)

length_one_axis(::Base.OneTo) = Base.OneTo(1)
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -12,4 +14,5 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
37 changes: 37 additions & 0 deletions test/forwarddiffext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

using NNlib, LoopVectorization, VectorizationBase, ForwardDiff, Test
# Build a `Vec` holding `pick_vector_width(Float64)` independent `randn()` draws.
function randnvec()
    width = pick_vector_width(Float64)
    lanes = ntuple(_ -> randn(), width)
    return Vec(lanes...)
end

# Copy the lanes of a `Vec` into a plain `Vector{T}`.
function tovec(x::Vec{W,T}) where {W,T}
    lanes = Tuple(x)
    return T[lanes...]
end

# Flatten a `VecUnroll` by flattening each member returned by
# `VectorizationBase.data` and concatenating the results in order.
function tovec(x::VecUnroll)
    members = VectorizationBase.data(x)
    return mapreduce(tovec, vcat, members)
end
# Flatten a `ForwardDiff.Dual` whose value and partials are SIMD containers
# into a `Vector` of scalar duals.
#
# Element `i` of the result holds entry `i` of the flattened value together
# with entry `i` of every flattened partial. The tag `T` and partial count `N`
# of the input are preserved in the element type.
function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N}
    v = tovec(ForwardDiff.value(x))
    dv = map(tovec, Tuple(ForwardDiff.partials(x)))
    D = ForwardDiff.Dual{T,eltype(v),N}
    ret = Vector{D}(undef, length(v))
    for i in eachindex(v)
        # Construct with the input's tag so the element already has type `D`.
        # The untagged `ForwardDiff.Dual(...)` constructor yields a `Nothing`
        # tag, which only matches `D` when `T === Nothing`.
        ret[i] = ForwardDiff.Dual{T}(v[i], map(Base.Fix2(Base.getindex, i), dv)...)
    end
    return ret
end


# Six independent random SIMD vectors, drawn in order vx0 through vx5.
vx0, vx1, vx2, vx3, vx4, vx5 = ntuple(_ -> randnvec(), 6)

# Dual over plain vectors: `vx0` is the value, `vx1`..`vx5` are the partials.
vd0 = ForwardDiff.Dual(vx0, vx1, vx2, vx3, vx4, vx5)

# Pair the same vectors up into unrolled vectors.
vu0 = VecUnroll((vx0, vx1))
vu1 = VecUnroll((vx2, vx3))
vu2 = VecUnroll((vx4, vx5))

# Dual over unrolled vectors: `vu0` is the value, `vu1` and `vu2` the partials.
vud = ForwardDiff.Dual(vu0, vu1, vu2)

# Applying `NNlib.leakyrelu` to a SIMD-valued dual and then flattening must
# agree with flattening first and broadcasting the scalar function.
# `reinterpret(Float64, ...)` exposes the dual vectors as plain floats so
# that `≈` compares values and partials numerically.
@test reinterpret(Float64, tovec(NNlib.leakyrelu(vd0))) ≈
      reinterpret(Float64, NNlib.leakyrelu.(tovec(vd0)))
@test reinterpret(Float64, tovec(NNlib.leakyrelu(vud))) ≈
      reinterpret(Float64, NNlib.leakyrelu.(tovec(vud)))
1 change: 1 addition & 0 deletions test/grouptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const START_TIME = time()
Pkg.activate(joinpath(precompiledir, "LVUser"))
@time include(joinpath(precompiledir, "precompile.jl"))
Pkg.activate(cproj)
@time include("forwarddiffext.jl")
end

end
Expand Down

2 comments on commit 73fb725

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87874

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.164 -m "<description of version>" 73fb72539543e4a770f900a8992dddcaedd1d631
git push origin v0.12.164

Please sign in to comment.