Skip to content

Commit

Permalink
Test cleanup
Browse files Browse the repository at this point in the history
define NNlib.leakyrelu(::AbstractSIMD) for Julia 1.6 test, and don't test `check_order` for westmere like architectures (i.e., simd width of 16 bytes, 16 registers) because we don't have a set of validated reasonable results to compare with.
  • Loading branch information
chriselrod committed May 3, 2024
1 parent 51ee029 commit eeaa0b2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
15 changes: 15 additions & 0 deletions test/forwarddiffext.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Base: Forward

using NNlib, LoopVectorization, VectorizationBase, ForwardDiff, Test
randnvec() = Vec(ntuple(_ -> randn(), pick_vector_width(Float64))...)
Expand All @@ -15,6 +16,20 @@ function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N}
return ret
end

if LoopVectorization.ifelse !== Base.ifelse
@inline function NNlib.leakyrelu(
x::LoopVectorization.AbstractSIMD,
a = NNlib.oftf(x, NNlib.leakyrelu_a),
)
LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower
end
@inline function NNlib.leakyrelu(
x::ForwardDiff.Dual{<:Any,<:LoopVectorization.AbstractSIMD},
a = NNlib.oftf(x, NNlib.leakyrelu_a),
)
LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower
end
end

vx0 = randnvec()
vx1 = randnvec()
Expand Down
40 changes: 29 additions & 11 deletions test/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
Unum, Tnum = LoopVectorization.register_count() == 16 ? (2, 6) : (4, 6)
end
Unumt, Tnumt = LoopVectorization.register_count() == 16 ? (2, 6) : (5, 5)
if LoopVectorization.register_count() != 8
if (LoopVectorization.register_count() != 8) && (
(LoopVectorization.pick_vector_width(Float64) != 2) ||
(LoopVectorization.register_count() != 16)
)
@test @inferred(LoopVectorization.matmul_params()) == (Unum, Tnum)
end

Expand All @@ -30,7 +33,10 @@
end
)
lsAmulBt1 = LoopVectorization.loopset(AmulBtq1)
if LoopVectorization.register_count() != 8
if (LoopVectorization.register_count() != 8) && (
(LoopVectorization.pick_vector_width(Float64) != 2) ||
(LoopVectorization.register_count() != 16)
)
@test LoopVectorization.choose_order(lsAmulBt1) ==
(Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum)
end
Expand All @@ -43,7 +49,10 @@
end
)
lsAmulB1 = LoopVectorization.loopset(AmulBq1)
if LoopVectorization.register_count() != 8
if (LoopVectorization.register_count() != 8) && (
(LoopVectorization.pick_vector_width(Float64) != 2) ||
(LoopVectorization.register_count() != 16)
)
@test LoopVectorization.choose_order(lsAmulB1) ==
(Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum)
end
Expand All @@ -56,7 +65,10 @@
end
)
lsAmulB2 = LoopVectorization.loopset(AmulBq2)
if LoopVectorization.register_count() != 8
if (LoopVectorization.register_count() != 8) && (
(LoopVectorization.pick_vector_width(Float64) != 2) ||
(LoopVectorization.register_count() != 16)
)
@test LoopVectorization.choose_order(lsAmulB2) ==
(Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum)
end
Expand All @@ -70,11 +82,12 @@
end
)
lsAmulB3 = LoopVectorization.loopset(AmulBq3)
if LoopVectorization.register_count() != 8
if (LoopVectorization.register_count() != 8) && (
(LoopVectorization.pick_vector_width(Float64) != 2) ||
(LoopVectorization.register_count() != 16)
)
@test LoopVectorization.choose_order(lsAmulB3) ==
(Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum)
end
if LoopVectorization.register_count() != 8
for (fA, fB, v, Un, Tn) [
(identity, identity, :m, Unum, Tnum),
(adjoint, identity, :k, Unumt, Tnumt),
Expand Down Expand Up @@ -177,7 +190,8 @@
end
)
lsAmuladd = LoopVectorization.loopset(Amuladdq)
if LoopVectorization.register_count() != 8
if LoopVectorization.register_count() != 8 &&
LoopVectorization.pick_vector_width(Float64) != 2
@test LoopVectorization.choose_order(lsAmuladd) ==
(Symbol[:n, :m, :k], :m, :n, :m, Unum, Tnum)
end
Expand Down Expand Up @@ -410,9 +424,13 @@
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 3, 7)
end
elseif LoopVectorization.register_count() == 16
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6)
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 2, 4)
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 3, 3)
if LoopVectorization.pick_vector_width(Float64) == 4
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6)
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 2, 4)
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 3, 3)
elseif LoopVectorization.pick_vector_width(Float64) == 2
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 3, 3)
end
end
function rank2AmulBavx!(C, Aₘ, Aₖ, B)
@turbo for m axes(C, 1), n axes(C, 2)
Expand Down

0 comments on commit eeaa0b2

Please sign in to comment.