Skip to content

Commit

Permalink
multi-thread loop + single-thread BLAS
Browse files (browse the repository at this point in the history)
  • Loading branch information
Michael Abbott authored and committed on Oct 24, 2020
1 parent d8c1761 commit 18deacd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.5"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Compat = "3.13"
Requires = "0.5, 1.0"
julia = "1.3"

Expand Down
22 changes: 15 additions & 7 deletions src/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using LinearAlgebra
using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc

using Compat: get_num_threads, set_num_threads

"""
gemm!()
Expand Down Expand Up @@ -89,22 +91,28 @@ for (gemm, elt) in gemm_datatype_mappings
strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3)
strC = Base.stride(C, 3)

for k in 1:size(A, 3)
old_threads = get_num_threads()
set_num_threads(1)

Threads.@threads for k in 1:size(C, 3)

ptrAk = ptrA + (k-1) * strA * sizeof($elt)
ptrBk = ptrB + (k-1) * strB * sizeof($elt)
ptrCk = ptrC + (k-1) * strC * sizeof($elt)

ccall((@blasfunc($(gemm)), libblas), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt},
Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt},
Ref{BlasInt}),
transA, transB, m, n,
ka, alpha, ptrA, max(1,Base.stride(A,2)),
ptrB, max(1,Base.stride(B,2)), beta, ptrC,
ka, alpha, ptrAk, max(1,Base.stride(A,2)),
ptrBk, max(1,Base.stride(B,2)), beta, ptrCk,
max(1,Base.stride(C,2)))

ptrA += strA * sizeof($elt)
ptrB += strB * sizeof($elt)
ptrC += Base.stride(C, 3) * sizeof($elt)
end

set_num_threads(old_threads)

C
end
end
Expand Down

0 comments on commit 18deacd

Please sign in to comment.