From 18deacd5c034f49d38ebdba2cb3cfe4ddf22e90b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 2 Jul 2020 14:21:28 +0200 Subject: [PATCH] multi-thread loop + single-thread BLAS --- Project.toml | 2 ++ src/gemm.jl | 22 +++++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index c5b01f74c..3162b01cd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.5" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -10,6 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +Compat = "3.13" Requires = "0.5, 1.0" julia = "1.3" diff --git a/src/gemm.jl b/src/gemm.jl index 91c7ce984..440aaae37 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -4,6 +4,8 @@ using LinearAlgebra using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc +using Compat: get_num_threads, set_num_threads + """ gemm!() @@ -89,22 +91,28 @@ for (gemm, elt) in gemm_datatype_mappings strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3) strC = Base.stride(C, 3) - for k in 1:size(A, 3) + old_threads = get_num_threads() + set_num_threads(1) + + Threads.@threads for k in 1:size(C, 3) + + ptrAk = ptrA + (k-1) * strA * sizeof($elt) + ptrBk = ptrB + (k-1) * strB * sizeof($elt) + ptrCk = ptrC + (k-1) * strC * sizeof($elt) + ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, - ka, alpha, ptrA, max(1,Base.stride(A,2)), - ptrB, max(1,Base.stride(B,2)), beta, ptrC, + ka, alpha, ptrAk, max(1,Base.stride(A,2)), + ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, max(1,Base.stride(C,2))) - - ptrA += strA * sizeof($elt) - ptrB += strB * sizeof($elt) - ptrC += Base.stride(C, 3) * sizeof($elt) end + set_num_threads(old_threads) + C end end