From 5c7e228d9b82269c34555bc7020e156df41ebe14 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 3 Apr 2020 21:19:59 +0200 Subject: [PATCH 01/15] more unwrapping, plus strides & pointers --- src/batched/batchedadjtrans.jl | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl index 9684b7c17..8f71f7fa1 100644 --- a/src/batched/batchedadjtrans.jl +++ b/src/batched/batchedadjtrans.jl @@ -1,4 +1,5 @@ using LinearAlgebra + import Base: - _batched_doc = """ @@ -10,10 +11,13 @@ Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`. These exist to control how `batched_mul` behaves, as it operated on such matrix slices of an array with `ndims(A)==3`. - BatchedTranspose{T, N, S} <: AbstractBatchedMatrix{T, N} - BatchedAdjoint{T, N, S} +For arrays of real numbers, `batched_transpose(A) == PermutedDimsArray(A, (2,1,3))`, +which is a more widely-supported wrapper, and also understood by `batched_mul`. + + BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3} + BatchedAdjoint{T, S} -Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose`. +Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc. """ @doc _batched_doc @@ -36,6 +40,13 @@ end batched_adjoint(A::AbstractArray{T, 3}) where T = BatchedAdjoint(A) batched_adjoint(A::BatchedAdjoint) = A.parent +batched_adjoint(A::BatchedTranspose{<:Real}) = A.parent +batched_transpose(A::BatchedAdjoint{<:Real}) = A.parent +batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)}) = A.parent +batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)}) = A.parent +# if you can't unwrap, put BatchedAdjoint outside (for dispatch): +batched_transpose(A::BatchedAdjoint{<:Complex}) = BatchedAdjoint(BatchedTranspose(A.parent)) + BatchedAdjoint(A) = BatchedAdjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A) BatchedTranspose(A) = BatchedTranspose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A) @@ -65,6 +76,18 @@ Base.parent(A::BatchedAdjOrTrans) = A.parent (-)(A::BatchedAdjoint) = BatchedAdjoint( -A.parent) (-)(A::BatchedTranspose) = BatchedTranspose(-A.parent) -Base.copy(A::BatchedTranspose) = BatchedTranspose(copy(A.parent)) -Base.copy(A::BatchedAdjoint) = BatchedAdjoint(copy(A.parent)) +# C interface +function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}) + sp = strides(A.parent) + (sp[2], sp[1], sp[3]) +end + +function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer) + d == 1 && return Base.stride(A.parent, 2) + d == 2 && return Base.stride(A.parent, 1) + Base.stride(A.parent, d) +end + +Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} = + Base.unsafe_convert(Ptr{T}, parent(A)) From 44faf05649d9f321ab3f3e743253eb8f3018a3dc Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 3 Apr 2020 21:21:15 +0200 Subject: [PATCH 02/15] alternative batched_gemm! setup, using strides & storage_type --- src/batched/batchedmul.jl | 186 +++++++++++++++++++++++++++++++++----- test/batchedmul.jl | 42 +++++++++ 2 files changed, 203 insertions(+), 25 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 46814dde7..11a015e70 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -1,13 +1,23 @@ -# batch-wise matrix multiplication -# wrapper for batched_gemm! + export batched_mul, batched_transpose, batched_adjoint include("./batchedadjtrans.jl") +using LinearAlgebra: BlasFloat, Transpose, Adjoint + +_unbatch(A) = A +_unbatch(A::BatchedAdjOrTrans) = parent(A) + """ batched_mul(A, B) -> C Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`. + +To transpose each matrix apply `batched_transpose` to the array, +and similarly `batched_adjoint`. Other permutations are also handled efficiently, +provided that the batch index `k` is not the first dimension of the underlying array. +Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine, +but `PermutedDimsArray(::Array, (3,2,1))` must use the fallback `batched_mul_generic!`. """ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch")) @@ -18,47 +28,173 @@ end """ batched_mul!(C, A, B) -> C + batched_mul!(C, A, B, α=1, β=0) + +In-place batched matrix multiplication, equivalent to +`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`. -In-place batched matrix multiplication, -equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k])` for all `k`. +This will call `batched_gemm!` whenever possible. For real arrays this means that, +for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter may +be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. + +For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, +and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. + +The fallback method calls 5-argument `mul!` on Julia 1.3 and later, +on earlier verions it will thrown an error if `α!=1` or `β!=0`. """ -function batched_mul! end +function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, + α::Number=one(T), β::Number=zero(T)) where {T} + _batched_mul!(storage_type(C,A,B), C, A, B, α, β) + C +end -_unbatch(A) = A -_unbatch(A::BatchedAdjOrTrans) = A.parent +_batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -# batched_gemm! +function _batched_mul!(::Array{T}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} -const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32} + is_strided(C) && is_strided(_unbatch(A)) && is_strided(_unbatch(B)) || + return batched_mul_generic!(C, A, B, α, β) -_BATCHED_GEMM_LIST = [ - (:(StridedArray{T, 3}), 'N'), - (:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'), - (:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C') -] + if Base.stride(C,1) == 1 + elseif Base.stride(C,2) == 1 + return batched_mul!(batched_transpose(C), batched_transpose(B), batched_transpose(A), α, β) + else + return batched_mul_generic!(C, A, B, α, β) + end -for (TA, transA) in _BATCHED_GEMM_LIST, (TB, transB) in _BATCHED_GEMM_LIST - @eval function batched_mul!(C::StridedArray{T, 3}, A::$TA, B::$TB) where {T<:_GemmFloat} - batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C) - C + blasA, transA = if A isa BatchedAdjoint + Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β) + parent(A), 'C' + elseif Base.stride(A,1) == 1 + A, 'N' + elseif Base.stride(A,2) == 1 + batched_transpose(A), 'T' + else + return batched_mul_generic!(C, A, B, α, β) + end + + blasB, transB = if B isa BatchedAdjoint + Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β) + parent(B), 'C' + elseif Base.stride(B,1) == 1 + B, 'N' + elseif Base.stride(B,2) == 1 + batched_transpose(B), 'T' + else + return batched_mul_generic!(C, A, B, α, β) end + + _batched_gemm!(transA, transB, convert(T,α), blasA, blasB, convert(T,β), C) + C end -# fallback +_batched_gemm!(::Type{<:Array}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = + batched_gemm!(transA, transB, α, A, B, β, C) _BATCHED_LIST = [ (:(AbstractArray{<:Any, 3}), :identity), - (:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose), - (:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint) + (:BatchedTranspose, :transpose), + (:BatchedAdjoint, :adjoint), ] for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST - @eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB) + + @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB, + α::Number=one(T), β::Number=zero(T)) where {T} axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch")) @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C) - A′, B′ = _unbatch(A), _unbatch(B) - @inbounds for k in axes(C, 3) - @views mul!(C[:,:,k], $fA(A′[:,:,k]), $fB(B′[:,:,k])) + Abase, Bbase = _unbatch(A), _unbatch(B) + + if VERSION >= v"1.3" + @inbounds for k in axes(C, 3) + @views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]), convert(T,α), convert(T,β)) + end + else + α==1 && β==0 || throw(ArgumentError("5-arg batched_mul_generic! does not work on Julia < 1.3")) + @inbounds for k in axes(C, 3) + @views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k])) + end end + C end + +end + +""" + storage_type(A) -> Type + +Removes all wrappers to return the `Array` or `CuArray` (or whatever) type within. +``` +julia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type +Array{Float64,1} + +julia> reshape(sparse(rand(10)), 5,2) |> storage_type +SparseVector{Float64,Int64} +``` +""" +function storage_type(A::AbstractArray) + P = parent(A) + typeof(A) === typeof(P) ? typeof(A) : storage_type(P) +end +storage_type(A) = typeof(A) + +""" + storage_type(A, B, C, ...) -> Type + +Reduces with `Base.promote_typejoin`, in order that this conveys useful information +for dispatching to BLAS, rather than information about the storage to allocate: +``` +julia> storage_type(rand(2), rand(Float32, 2)) +Array{T,1} where T + +julia> eltype(ans) <: LinearAlgebra.BlasFloat +false + +julia> storage_type(rand(2), rand(2,3), rand(2,3,4)) +Array{Float64,N} where N +``` +""" +storage_type(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_type(Bs...)) + + +""" + is_strided(A::AbstractArray) -> Bool + +This generalises `A isa StridedArray` to treat wrappers like `A::PermutedDimsArray`, +for which it returns `is_strided(parent(A))`. + +Other wrappers (defined outside Base, LinearAlgebra) are assumed not to break +strided-ness, and hence also return `is_strided(parent(A))`. +This correctly handles things like `NamedDimsArray` wihch don't alter indexing. +However, it's a little pessimistic in that e.g. a `view` of such a container will return +`false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`. + +`A::Transpose` doesn't currently define `strides`, until that's fixed this returns `false`. +The PR to fix that only defines `strides(::Adjoint{T})` for `T<:Real`, so this will follow. +""" +is_strided(A::StridedArray) = true +is_strided(A) = false +function is_strided(A::AbstractArray) + M = parentmodule(typeof(A)) + if parent(A) === A # SparseMatrix, StaticArray, etc + false + elseif M === Base || M === Core || M ===LinearAlgebra + # bad reshapes, etc, plus Diagonal, UpperTriangular, etc. + false + else + is_strided(parent(A)) # PermutedDimsArray, NamedDimsArray + end +end + +is_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A)) +is_strided(A::BatchedTranspose) = is_strided(parent(A)) + +if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose}) + # https://github.com/JuliaLang/julia/pull/29135 + is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A)) + is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) +else + is_strided(A::LinearAlgebra.Transpose) = false + is_strided(A::LinearAlgebra.Adjoint) = false end diff --git a/test/batchedmul.jl b/test/batchedmul.jl index e85d36c2e..fd3e3ea68 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -1,3 +1,6 @@ +using NNlib, Test, LinearAlgebra +using NNlib: storage_type, is_strided + function bmm_test(a,b; transA = false, transB = false) bs = size(a,3) transA && (a = permutedims(a, [2,1,3])) @@ -121,3 +124,42 @@ end @test _X != _copyX end end + +@testset "storage_type" + + @test storage_type(transpose(reshape(view(rand(10), 2:9),4,:))) == Vector{Float64} + @test storage_type(transpose(reshape(view(1:10, 2:9),4,:))) == UnitRange{Int} + + @test storage_type(rand(2), rand(Float32, 2)) == Vector{<:Any} + @test storage_type(rand(2), rand(2,3)', rand(2,3,4)) == Array{Float64} + @test storage_type([1,2,3], 4:5) == AbstractVector{Int} + +end + +@testset "is_strided" begin + + M = ones(10,10) + + @test is_strided(M) + @test is_strided(view(M, 1:2:5,:)) + @test is_strided(PermutedDimsArray(M, (2,1))) + + @test !is_strided(reshape(view(M, 1:2:10,:), 10,:)) + @test !is_strided((M.+im)') + @test !is_strided(Diagonal(ones(3))) + + A = ones(2,2,2) + + @test is_strided(batched_adjoint(A)) + @test is_strided(batched_transpose(A)) + @test !is_strided(batched_adjoint(A .+ im)) + @test is_strided(batched_transpose(A .+ im)) + + #= + using SparseArrays + @test !is_strided(sparse(M)) + using NamedDims + @test is_strided(NamedDimsArray(M,(:a, :b))) # and 0.029 ns, 0 allocations + =# + +end From 915660cb59cd5fc96bcff7e4cdb42bd28d1ddd23 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 3 Apr 2020 21:49:27 +0200 Subject: [PATCH 03/15] fix strides in batched_gemm --- src/gemm.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gemm.jl b/src/gemm.jl index 3a66b3651..440bb31aa 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -94,9 +94,9 @@ for (gemm, elt) in gemm_datatype_mappings ptrB, max(1,Base.stride(B,2)), beta, ptrC, max(1,Base.stride(C,2))) - ptrA += size(A, 1) * size(A, 2) * sizeof($elt) - ptrB += size(B, 1) * size(B, 2) * sizeof($elt) - ptrC += size(C, 1) * size(C, 2) * sizeof($elt) + ptrA += Base.stride(A, 3) * sizeof($elt) + ptrB += Base.stride(B, 3) * sizeof($elt) + ptrC += Base.stride(C, 3) * sizeof($elt) end C From f60dc8f860c47571976513141fd052f51452bd61 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 3 Apr 2020 21:50:00 +0200 Subject: [PATCH 04/15] fixup --- src/batched/batchedmul.jl | 18 +++++++++++------- test/batchedmul.jl | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 11a015e70..732643a3e 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -17,7 +17,10 @@ To transpose each matrix apply `batched_transpose` to the array, and similarly `batched_adjoint`. Other permutations are also handled efficiently, provided that the batch index `k` is not the first dimension of the underlying array. Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine, -but `PermutedDimsArray(::Array, (3,2,1))` must use the fallback `batched_mul_generic!`. +but `PermutedDimsArray(::Array, (3,2,1))` will use the fallback `batched_mul_generic!`. + +There is an `@debug` message produced by `batched_mul_generic!`, +setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display this. """ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch")) @@ -40,8 +43,8 @@ be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. -The fallback method calls 5-argument `mul!` on Julia 1.3 and later, -on earlier verions it will thrown an error if `α!=1` or `β!=0`. +The fallback method calls 5-argument `mul!` on Julia 1.3 and later. +On earlier verions it will thrown an error if `α!=1` or `β!=0`. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} @@ -51,10 +54,9 @@ end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -function _batched_mul!(::Array{T}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} +function _batched_mul!(CT::Type{<:Array{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} - is_strided(C) && is_strided(_unbatch(A)) && is_strided(_unbatch(B)) || - return batched_mul_generic!(C, A, B, α, β) + are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) if Base.stride(C,1) == 1 elseif Base.stride(C,2) == 1 @@ -85,7 +87,7 @@ function _batched_mul!(::Array{T}, C, A, B, α::Number, β::Number) where {T<:Bl return batched_mul_generic!(C, A, B, α, β) end - _batched_gemm!(transA, transB, convert(T,α), blasA, blasB, convert(T,β), C) + _batched_gemm!(CT, transA, transB, convert(T,α), blasA, blasB, convert(T,β), C) C end @@ -198,3 +200,5 @@ else is_strided(A::LinearAlgebra.Transpose) = false is_strided(A::LinearAlgebra.Adjoint) = false end + +are_strided(As...) = mapfoldl(is_strided, &, As; init=true) diff --git a/test/batchedmul.jl b/test/batchedmul.jl index fd3e3ea68..c07fd8a65 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -1,5 +1,5 @@ using NNlib, Test, LinearAlgebra -using NNlib: storage_type, is_strided +using NNlib: storage_type, is_strided, batched_mul! function bmm_test(a,b; transA = false, transB = false) bs = size(a,3) @@ -125,7 +125,7 @@ end end end -@testset "storage_type" +@testset "storage_type" begin @test storage_type(transpose(reshape(view(rand(10), 2:9),4,:))) == Vector{Float64} @test storage_type(transpose(reshape(view(1:10, 2:9),4,:))) == UnitRange{Int} From 8775abc0c115e2d48c50e4f53e258c45a2b5ec1a Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 00:04:12 +0200 Subject: [PATCH 05/15] Array -> DenseArray --- src/batched/batchedmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 732643a3e..cfa4719f9 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -54,7 +54,7 @@ end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -function _batched_mul!(CT::Type{<:Array{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} +function _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) From db83d76f905b59f71eddcf2595ce5cc76b544f51 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 09:52:20 +0200 Subject: [PATCH 06/15] allow size(A,3)==1 with size(B,3)==size(C,3) --- src/batched/batchedmul.jl | 21 ++++++++++++++------- src/gemm.jl | 14 ++++++++++---- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index cfa4719f9..fe61692f2 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -12,6 +12,7 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A) batched_mul(A, B) -> C Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`. +If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. To transpose each matrix apply `batched_transpose` to the array, and similarly `batched_adjoint`. Other permutations are also handled efficiently, @@ -23,9 +24,10 @@ There is an `@debug` message produced by `batched_mul_generic!`, setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display this. """ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} - axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch")) + size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 || + throw(DimensionMismatch("batch size mismatch: A != B")) T = promote_type(T1, T2) - C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3))) + C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)))) batched_mul!(C, A, B) end @@ -35,6 +37,7 @@ end In-place batched matrix multiplication, equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`. +If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead. This will call `batched_gemm!` whenever possible. For real arrays this means that, for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter may @@ -60,6 +63,7 @@ function _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Numbe if Base.stride(C,1) == 1 elseif Base.stride(C,2) == 1 + @debug "transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" size(C) strides(C) return batched_mul!(batched_transpose(C), batched_transpose(B), batched_transpose(A), α, β) else return batched_mul_generic!(C, A, B, α, β) @@ -103,18 +107,21 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB, α::Number=one(T), β::Number=zero(T)) where {T} - axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch")) + size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C")) + size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C")) @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C) Abase, Bbase = _unbatch(A), _unbatch(B) + sA, oA = size(A,3) == 1 ? (0,1) : (1,0) + sB, oB = size(B,3) == 1 ? (0,1) : (1,0) if VERSION >= v"1.3" - @inbounds for k in axes(C, 3) - @views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]), convert(T,α), convert(T,β)) + @inbounds for k in 1:size(C,3) + @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), convert(T,α), convert(T,β)) end else α==1 && β==0 || throw(ArgumentError("5-arg batched_mul_generic! does not work on Julia < 1.3")) - @inbounds for k in axes(C, 3) - @views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k])) + @inbounds for k in 1:size(C,3) + @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB])) end end diff --git a/src/gemm.jl b/src/gemm.jl index 440bb31aa..91c7ce984 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -67,13 +67,15 @@ for (gemm, elt) in gemm_datatype_mappings beta::($elt), C::AbstractArray{$elt, 3}) @assert !Base.has_offset_axes(A, B, C) - @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch" + @assert size(A, 3) == 1 || size(A, 3) == size(C, 3) "batch size mismatch: A != C" + @assert size(B, 3) == 1 || size(B, 3) == size(C, 3) "batch size mismatch: B != C" + m = size(A, transA == 'N' ? 1 : 2) ka = size(A, transA == 'N' ? 2 : 1) kb = size(B, transB == 'N' ? 1 : 2) n = size(B, transB == 'N' ? 2 : 1) if ka != kb || m != size(C,1) || n != size(C,2) - throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))")) + throw(DimensionMismatch("A1 has size ($m,$ka), B1 has size ($kb,$n), C1 has size $(size(C)[1:2])")) end LinearAlgebra.BLAS.chkstride1(A) LinearAlgebra.BLAS.chkstride1(B) @@ -83,6 +85,10 @@ for (gemm, elt) in gemm_datatype_mappings ptrB = Base.unsafe_convert(Ptr{$elt}, B) ptrC = Base.unsafe_convert(Ptr{$elt}, C) + strA = size(A, 3) == 1 ? 0 : Base.stride(A, 3) + strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3) + strC = Base.stride(C, 3) + for k in 1:size(A, 3) ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, @@ -94,8 +100,8 @@ for (gemm, elt) in gemm_datatype_mappings ptrB, max(1,Base.stride(B,2)), beta, ptrC, max(1,Base.stride(C,2))) - ptrA += Base.stride(A, 3) * sizeof($elt) - ptrB += Base.stride(B, 3) * sizeof($elt) + ptrA += strA * sizeof($elt) + ptrB += strB * sizeof($elt) ptrC += Base.stride(C, 3) * sizeof($elt) end From b264c97dce22c9ccbcce49a5aa2c8f85b7ff1096 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 09:58:38 +0200 Subject: [PATCH 07/15] tests for permutations, 5-arg mul, trivial batches --- test/batchedmul.jl | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/test/batchedmul.jl b/test/batchedmul.jl index c07fd8a65..5690bf3c1 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -25,9 +25,18 @@ function bmm_adjtest(a,b; adjA = false, adjB = false) cat(c...; dims = 3) end +function half_batched_mul(x,y) + @assert size(y,3) == 1 + d = size(x,2) + x_mat = reshape(permutedims(x, (1,3,2)),:,d) + y_mat = reshape(y,d,:) + z_mat = x_mat * y_mat + permutedims(reshape(z_mat, size(x,1), size(x,3), :), (1,3,2)) +end @testset "batched_mul: Float64 * $TB" for TB in [Float64, Float32] + # Real A = randn(7,5,3) B = randn(TB, 5,7,3) C = randn(7,6,3) @@ -37,7 +46,7 @@ end @test batched_mul(batched_transpose(A), C) ≈ bmm_test(A, C; transA = true) @test batched_mul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB = true) - + # Complex cA = randn(Complex{Float64}, 7,5,3) cB = randn(Complex{TB}, 5,7,3) cC = randn(Complex{Float64}, 7,6,3) @@ -47,9 +56,13 @@ end @test batched_mul(batched_adjoint(cA), cC) ≈ bmm_adjtest(cA, cC; adjA = true) @test batched_mul(cA, batched_adjoint(cA)) ≈ bmm_adjtest(cA, cA; adjB = true) + # Wrappers which cancel @test batched_transpose(batched_transpose(A)) === A + @test batched_transpose(PermutedDimsArray(A, (2,1,3))) === A @test batched_adjoint(batched_adjoint(cA)) === cA + @test batched_transpose(batched_adjoint(cA)) isa NNlib.BatchedAdjoint + # Integers TBi = TB==Float64 ? Int64 : Int32 iA = rand(1:99, 7,5,3) iB = TB.(rand(1:99, 5,7,3)) @@ -57,10 +70,37 @@ end @test batched_mul(iA, iB) == bmm_adjtest(iA, iB) @test batched_mul(cA, iB) ≈ bmm_adjtest(cA, iB) + # Errors @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10)) @test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2)) @test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2)) + # PermutedDimsArrays + for perm in [(1,3,2), (2,1,3)], fun in [identity, batched_adjoint], ty in [identity, complex] + A = randn(ty(Float64), 4,4,4) + B = randn(ty(TB), 4,4,4) + @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ batched_mul(fun(A), permutedims(B, perm)) + @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ batched_mul(fun(permutedims(A, perm)), B) + # when TB=Float64, only the case (2,1,3) batched_adjoint complex goes to fallback + end + + # PermutedDimsArray output + A′ = randn(4,3,2) + B′ = batched_adjoint(randn(TB, 5,3,2)) + C1 = batched_mul(A′, B′) # size 4,5,2 + C2 = PermutedDimsArray(zeros(5,2,4), (3,1,2)) # size 4,5,2 + @test C1 ≈ batched_mul!(C2, A′, B′) # Float64: "Debug: transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" + @test C1 ≈ C2 + + # 5-arg mul! + @test 10 .* C1 ≈ batched_mul!(C2, A′, B′, 10) + C2 .= 10 + @test C1 .+ 100 ≈ batched_mul!(C2, A′, B′, 1, 10) + + # Trivial batches for B + D′ = randn(TB, 3,5,1) + @test size(batched_mul(A′,D′)) == (4,5,2) + @test batched_mul(A′,D′) ≈ half_batched_mul(A′, D′) end @testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32] From 71fc51e43959ed5a8e1299b57142f16c4d39160b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 11:18:42 +0200 Subject: [PATCH 08/15] fixes & tweaks --- src/batched/batchedadjtrans.jl | 6 +++--- src/batched/batchedmul.jl | 20 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl index 8f71f7fa1..fc7b5fdef 100644 --- a/src/batched/batchedadjtrans.jl +++ b/src/batched/batchedadjtrans.jl @@ -9,10 +9,10 @@ _batched_doc = """ Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`. These exist to control how `batched_mul` behaves, -as it operated on such matrix slices of an array with `ndims(A)==3`. +as it operates on such matrix slices of an array with `ndims(A)==3`. -For arrays of real numbers, `batched_transpose(A) == PermutedDimsArray(A, (2,1,3))`, -which is a more widely-supported wrapper, and also understood by `batched_mul`. +`PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`, +and is also understood by `batched_mul` (and more widely supported elsewhere). BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3} BatchedAdjoint{T, S} diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index fe61692f2..10c0618b5 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -45,9 +45,6 @@ be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. - -The fallback method calls 5-argument `mul!` on Julia 1.3 and later. -On earlier verions it will thrown an error if `α!=1` or `β!=0`. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} @@ -57,7 +54,10 @@ end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -function _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} +_batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} = + _batched_try_gemm!(CT, C, A, B, α, β) + +function _batched_try_gemm!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) @@ -69,7 +69,7 @@ function _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Numbe return batched_mul_generic!(C, A, B, α, β) end - blasA, transA = if A isa BatchedAdjoint + blasA, transA = if A isa BatchedAdjoint && T <: Complex Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β) parent(A), 'C' elseif Base.stride(A,1) == 1 @@ -80,7 +80,7 @@ function _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Numbe return batched_mul_generic!(C, A, B, α, β) end - blasB, transB = if B isa BatchedAdjoint + blasB, transB = if B isa BatchedAdjoint && T <: Complex Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β) parent(B), 'C' elseif Base.stride(B,1) == 1 @@ -118,11 +118,15 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST @inbounds for k in 1:size(C,3) @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), convert(T,α), convert(T,β)) end - else - α==1 && β==0 || throw(ArgumentError("5-arg batched_mul_generic! does not work on Julia < 1.3")) + elseif α==1 && β==0 @inbounds for k in 1:size(C,3) @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB])) end + else + @debug "since there is no 5-arg mul!, calling C1 .= α .* (A1 * B1) .+ β .* C" α β + @inbounds for k in 1:size(C,3) + @views C[:,:,k] .= α .* $fA(Abase[:,:,k*sA+oA]) * $fB(Bbase[:,:,k*sB+oB]) .+ β .* C[:,:,k] + end end C From ffedecc16f263383ec433117f9610405e2cbe6a4 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 11:31:17 +0200 Subject: [PATCH 09/15] better promotion of alpha, beta --- src/batched/batchedmul.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 10c0618b5..581b5cfaf 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -59,6 +59,9 @@ _batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where function _batched_try_gemm!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} + alpha, beta = promote(α, β, zero(T)) # trick from https://github.com/JuliaLang/julia/pull/33229 + alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β) + are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) if Base.stride(C,1) == 1 @@ -91,7 +94,7 @@ function _batched_try_gemm!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β:: return batched_mul_generic!(C, A, B, α, β) end - _batched_gemm!(CT, transA, transB, convert(T,α), blasA, blasB, convert(T,β), C) + _batched_gemm!(CT, transA, transB, alpha, blasA, blasB, beta, C) C end @@ -116,7 +119,7 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST if VERSION >= v"1.3" @inbounds for k in 1:size(C,3) - @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), convert(T,α), convert(T,β)) + @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β) end elseif α==1 && β==0 @inbounds for k in 1:size(C,3) From ab998e59ba0ee7dc690830191133d3e4c9505c51 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 4 Apr 2020 19:59:04 +0200 Subject: [PATCH 10/15] make copies to avoid generic --- src/batched/batchedmul.jl | 61 +++++++++++++++++++++++++++++---------- test/batchedmul.jl | 14 +++++---- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 581b5cfaf..406f86996 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -15,20 +15,50 @@ Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. To transpose each matrix apply `batched_transpose` to the array, -and similarly `batched_adjoint`. Other permutations are also handled efficiently, +and similarly `batched_adjoint`. Other permutations are also handled by BLAS, provided that the batch index `k` is not the first dimension of the underlying array. -Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine, -but `PermutedDimsArray(::Array, (3,2,1))` will use the fallback `batched_mul_generic!`. +Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine. -There is an `@debug` message produced by `batched_mul_generic!`, -setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display this. +However `PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, +and will thus be copied as this is faster than the fallback method `batched_mul_generic!`. + +Both this `copy` and `batched_mul_generic!` produce `@debug` messages, +and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. """ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != B")) - T = promote_type(T1, T2) + _batched_mul(storage_typejoin(A, B), A, B) +end + +function _batched_mul(::Type, A, B) + T = promote_type(eltype(A), eltype(B)) C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)))) batched_mul!(C, A, B) + C +end +function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat} + C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)))) + batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B)) + C +end + +function _copy_if_faster(X::AbstractArray{<:Number, 3}) + is_strided(X) || return X + if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1 + @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X) + return copy(X) + end + X +end +function _copy_if_faster(X::BatchedAdjoint{<:Complex}) + Xbase = _unbatch(X) + is_strided(Xbase) || return X + if Base.stride(Xbase, 1) != 1 + @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X)) + return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU? + end + X end """ @@ -42,22 +72,23 @@ If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead. This will call `batched_gemm!` whenever possible. For real arrays this means that, for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter may be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. +Unlike `batched_mul` this will never make a copy. For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} - _batched_mul!(storage_type(C,A,B), C, A, B, α, β) + _batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β) C end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -_batched_mul!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} = - _batched_try_gemm!(CT, C, A, B, α, β) +_batched_mul!(DT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} = + _batched_try_gemm!(DT, C, A, B, α, β) -function _batched_try_gemm!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} +function _batched_try_gemm!(DT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} alpha, beta = promote(α, β, zero(T)) # trick from https://github.com/JuliaLang/julia/pull/33229 alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β) @@ -94,7 +125,7 @@ function _batched_try_gemm!(CT::Type{<:DenseArray{T}}, C, A, B, α::Number, β:: return batched_mul_generic!(C, A, B, α, β) end - _batched_gemm!(CT, transA, transB, alpha, blasA, blasB, beta, C) + _batched_gemm!(DT, transA, transB, alpha, blasA, blasB, beta, C) C end @@ -156,10 +187,10 @@ end storage_type(A) = typeof(A) """ - storage_type(A, B, C, ...) -> Type + storage_typejoin(A, B, C, ...) -> Type Reduces with `Base.promote_typejoin`, in order that this conveys useful information -for dispatching to BLAS, rather than information about the storage to allocate: +for dispatching to BLAS. It does not tell you what container to allocate: ``` julia> storage_type(rand(2), rand(Float32, 2)) Array{T,1} where T @@ -171,8 +202,8 @@ julia> storage_type(rand(2), rand(2,3), rand(2,3,4)) Array{Float64,N} where N ``` """ -storage_type(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_type(Bs...)) - +storage_typejoin(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_typejoin(Bs...)) +storage_typejoin(A) = storage_type(A) """ is_strided(A::AbstractArray) -> Bool diff --git a/test/batchedmul.jl b/test/batchedmul.jl index 5690bf3c1..8afd8d47d 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -1,5 +1,6 @@ using NNlib, Test, LinearAlgebra -using NNlib: storage_type, is_strided, batched_mul! +using NNlib: storage_type, storage_typejoin, is_strided, + batched_mul!, _unbatch, _copy_if_faster, BatchedAdjoint function bmm_test(a,b; transA = false, transB = false) bs = size(a,3) @@ -76,12 +77,13 @@ end @test_throws Exception batched_mul!(zeros(2,2,10), rand(2,2,2), rand(TB, 2,2,2)) # PermutedDimsArrays - for perm in [(1,3,2), (2,1,3)], fun in [identity, batched_adjoint], ty in [identity, complex] + for perm in [(1,3,2), (2,1,3), (3,2,1)], fun in [identity, batched_adjoint], ty in [identity, complex] A = randn(ty(Float64), 4,4,4) B = randn(ty(TB), 4,4,4) @test batched_mul(fun(A), PermutedDimsArray(B, perm)) ≈ batched_mul(fun(A), permutedims(B, perm)) @test batched_mul(fun(PermutedDimsArray(A, perm)), B) ≈ batched_mul(fun(permutedims(A, perm)), B) - # when TB=Float64, only the case (2,1,3) batched_adjoint complex goes to fallback + # when TB=Float64, only the case perm=(2,1,3); fun=batched_adjoint; ty=complex; goes to fallback + # but all the perm=(3,2,1); cases copy their inputs. end # PermutedDimsArray output @@ -170,9 +172,9 @@ end @test storage_type(transpose(reshape(view(rand(10), 2:9),4,:))) == Vector{Float64} @test storage_type(transpose(reshape(view(1:10, 2:9),4,:))) == UnitRange{Int} - @test storage_type(rand(2), rand(Float32, 2)) == Vector{<:Any} - @test storage_type(rand(2), rand(2,3)', rand(2,3,4)) == Array{Float64} - @test storage_type([1,2,3], 4:5) == AbstractVector{Int} + @test storage_typejoin(rand(2), rand(Float32, 2)) == Vector{<:Any} + @test storage_typejoin(rand(2), rand(2,3)', rand(2,3,4)) == Array{Float64} + @test storage_typejoin([1,2,3], 4:5) == AbstractVector{Int} end From d8c176198116f5e4e32b5aebaeece6c2ea5f0c8c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 29 Apr 2020 10:42:06 +0200 Subject: [PATCH 11/15] doc typos --- src/batched/batchedmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 406f86996..e159d77be 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -192,13 +192,13 @@ storage_type(A) = typeof(A) Reduces with `Base.promote_typejoin`, in order that this conveys useful information for dispatching to BLAS. It does not tell you what container to allocate: ``` -julia> storage_type(rand(2), rand(Float32, 2)) +julia> storage_typejoin(rand(2), rand(Float32, 2)) Array{T,1} where T julia> eltype(ans) <: LinearAlgebra.BlasFloat false -julia> storage_type(rand(2), rand(2,3), rand(2,3,4)) +julia> storage_typejoin(rand(2), rand(2,3), rand(2,3,4)) Array{Float64,N} where N ``` """ From 18deacd5c034f49d38ebdba2cb3cfe4ddf22e90b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 2 Jul 2020 14:21:28 +0200 Subject: [PATCH 12/15] multi-thread loop + single-thread BLAS --- Project.toml | 2 ++ src/gemm.jl | 22 +++++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index c5b01f74c..3162b01cd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.5" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -10,6 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +Compat = "3.13" Requires = "0.5, 1.0" julia = "1.3" diff --git a/src/gemm.jl b/src/gemm.jl index 91c7ce984..440aaae37 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -4,6 +4,8 @@ using LinearAlgebra using LinearAlgebra.BLAS: libblas, BlasInt, @blasfunc +using Compat: get_num_threads, set_num_threads + """ gemm!() @@ -89,22 +91,28 @@ for (gemm, elt) in gemm_datatype_mappings strB = size(B, 3) == 1 ? 0 : Base.stride(B, 3) strC = Base.stride(C, 3) - for k in 1:size(A, 3) + old_threads = get_num_threads() + set_num_threads(1) + + Threads.@threads for k in 1:size(C, 3) + + ptrAk = ptrA + (k-1) * strA * sizeof($elt) + ptrBk = ptrB + (k-1) * strB * sizeof($elt) + ptrCk = ptrC + (k-1) * strC * sizeof($elt) + ccall((@blasfunc($(gemm)), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}, Ptr{$elt}, Ref{BlasInt}, Ref{$elt}, Ptr{$elt}, Ref{BlasInt}), transA, transB, m, n, - ka, alpha, ptrA, max(1,Base.stride(A,2)), - ptrB, max(1,Base.stride(B,2)), beta, ptrC, + ka, alpha, ptrAk, max(1,Base.stride(A,2)), + ptrBk, max(1,Base.stride(B,2)), beta, ptrCk, max(1,Base.stride(C,2))) - - ptrA += strA * sizeof($elt) - ptrB += strB * sizeof($elt) - ptrC += Base.stride(C, 3) * sizeof($elt) end + set_num_threads(old_threads) + C end end From b51cc498a6d9388381cbb53379f7e91527de4e6e Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 24 Oct 2020 13:18:04 +0200 Subject: [PATCH 13/15] simplify thanks to Compat.jl, and julia 1.3+ --- Project.toml | 2 +- src/batched/batchedmul.jl | 31 ++++++------------------------- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 3162b01cd..247a47089 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -Compat = "3.13" +Compat = "3.14" Requires = "0.5, 1.0" julia = "1.3" diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index e159d77be..fa147d63d 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -141,28 +141,18 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB, α::Number=one(T), β::Number=zero(T)) where {T} + size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C")) size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C")) @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C) + Abase, Bbase = _unbatch(A), _unbatch(B) sA, oA = size(A,3) == 1 ? (0,1) : (1,0) sB, oB = size(B,3) == 1 ? (0,1) : (1,0) - if VERSION >= v"1.3" - @inbounds for k in 1:size(C,3) - @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β) - end - elseif α==1 && β==0 - @inbounds for k in 1:size(C,3) - @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB])) - end - else - @debug "since there is no 5-arg mul!, calling C1 .= α .* (A1 * B1) .+ β .* C" α β - @inbounds for k in 1:size(C,3) - @views C[:,:,k] .= α .* $fA(Abase[:,:,k*sA+oA]) * $fB(Bbase[:,:,k*sB+oB]) .+ β .* C[:,:,k] - end + @inbounds for k in 1:size(C,3) + @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β) end - C end @@ -216,9 +206,6 @@ strided-ness, and hence also return `is_strided(parent(A))`. This correctly handles things like `NamedDimsArray` wihch don't alter indexing. However, it's a little pessimistic in that e.g. a `view` of such a container will return `false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`. - -`A::Transpose` doesn't currently define `strides`, until that's fixed this returns `false`. -The PR to fix that only defines `strides(::Adjoint{T})` for `T<:Real`, so this will follow. """ is_strided(A::StridedArray) = true is_strided(A) = false @@ -237,13 +224,7 @@ end is_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A)) is_strided(A::BatchedTranspose) = is_strided(parent(A)) -if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose}) - # https://github.com/JuliaLang/julia/pull/29135 - is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A)) - is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) -else - is_strided(A::LinearAlgebra.Transpose) = false - is_strided(A::LinearAlgebra.Adjoint) = false -end +is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A)) +is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) are_strided(As...) = mapfoldl(is_strided, &, As; init=true) From c8a1fee21d037b03716ba3745d4d03b01b626b75 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 24 Oct 2020 13:19:55 +0200 Subject: [PATCH 14/15] type parameter tweaks --- src/batched/batchedmul.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index fa147d63d..641b57421 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -85,12 +85,12 @@ end _batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) -_batched_mul!(DT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} = +_batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} = _batched_try_gemm!(DT, C, A, B, α, β) -function _batched_try_gemm!(DT::Type{<:DenseArray{T}}, C, A, B, α::Number, β::Number) where {T<:BlasFloat} +function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} - alpha, beta = promote(α, β, zero(T)) # trick from https://github.com/JuliaLang/julia/pull/33229 + alpha, beta = promote(α, β, zero(T)) alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β) are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) From 676d166c95b21132442544b99401e445b894acc4 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 24 Oct 2020 13:52:19 +0200 Subject: [PATCH 15/15] use adjoint not transpose, + doc tweaks --- src/batched/batchedmul.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 641b57421..c6c1bf2de 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -19,8 +19,8 @@ and similarly `batched_adjoint`. Other permutations are also handled by BLAS, provided that the batch index `k` is not the first dimension of the underlying array. Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine. -However `PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, -and will thus be copied as this is faster than the fallback method `batched_mul_generic!`. +However `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, +since `stride(A,3) == 1`. This be copied, as doing so is faster than `batched_mul_generic!`. Both this `copy` and `batched_mul_generic!` produce `@debug` messages, and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. @@ -74,8 +74,9 @@ for `X ∈ [A,B,C]`, either `strides(X,1)==1` or `strides(X,2)==1`, the latter m be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. Unlike `batched_mul` this will never make a copy. -For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen, -and in this case `stride(A::BatchedAdjoint,2) == 1` is not optional. +For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen. +In this case the strided accepted by BLAS are more restricted, if `stride(C,1)==1` then +only `stride(AorB::BatchedAdjoint,2) == 1` is accepted. """ function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, α::Number=one(T), β::Number=zero(T)) where {T} @@ -97,8 +98,8 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where { if Base.stride(C,1) == 1 elseif Base.stride(C,2) == 1 - @debug "transposing C = A * B into Cᵀ = Bᵀ * Aᵀ" size(C) strides(C) - return batched_mul!(batched_transpose(C), batched_transpose(B), batched_transpose(A), α, β) + @debug "transforming C = A * B into C' = B' * A'" size(C) strides(C) + return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β) else return batched_mul_generic!(C, A, B, α, β) end