Skip to content

Commit

Permalink
Merge pull request #191 from mcabbott/fix2
Browse files Browse the repository at this point in the history
Allow batched_mul to work through PermutedDimsArray, II
  • Loading branch information
CarloLucibello authored Nov 11, 2020
2 parents d3489ea + 676d166 commit 9780c29
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 45 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.6"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.3"

Expand Down
35 changes: 29 additions & 6 deletions src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearAlgebra

import Base: -

_batched_doc = """
Expand All @@ -8,12 +9,15 @@ _batched_doc = """
Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.
These exist to control how `batched_mul` behaves,
as it operated on such matrix slices of an array with `ndims(A)==3`.
as it operates on such matrix slices of an array with `ndims(A)==3`.
`PermutedDimsArray(A, (2,1,3))` is equivalent to `batched_transpose(A)`,
and is also understood by `batched_mul` (and more widely supported elsewhere).
BatchedTranspose{T, N, S} <: AbstractBatchedMatrix{T, N}
BatchedAdjoint{T, N, S}
BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}
Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose`.
Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc.
"""

@doc _batched_doc
Expand All @@ -36,6 +40,13 @@ end
# Wrapping / unwrapping rules for `batched_adjoint` and `batched_transpose`.

# A plain 3-dimensional array is wrapped lazily; wrapping twice undoes the wrapper.
function batched_adjoint(A::AbstractArray{T, 3}) where {T}
    return BatchedAdjoint(A)
end
function batched_adjoint(A::BatchedAdjoint)
    return parent(A)
end

# For real element types adjoint and transpose coincide, so either wrapper
# (or the equivalent `PermutedDimsArray` with permutation `(2,1,3)`) can be removed.
function batched_adjoint(A::BatchedTranspose{<:Real})
    return parent(A)
end
function batched_transpose(A::BatchedAdjoint{<:Real})
    return parent(A)
end
function batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)})
    return parent(A)
end
function batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)})
    return parent(A)
end

# if you can't unwrap, put BatchedAdjoint outside (for dispatch):
function batched_transpose(A::BatchedAdjoint{<:Complex})
    inner = BatchedTranspose(parent(A))
    return BatchedAdjoint(inner)
end

# Outer constructors: the wrapper's element type is whatever `adjoint` / `transpose`
# produces for the parent's element type, and the parent's type is recorded as-is.
function BatchedAdjoint(A)
    T = Base.promote_op(adjoint, eltype(A))
    return BatchedAdjoint{T, typeof(A)}(A)
end
function BatchedTranspose(A)
    T = Base.promote_op(transpose, eltype(A))
    return BatchedTranspose{T, typeof(A)}(A)
end

Expand Down Expand Up @@ -65,6 +76,18 @@ Base.parent(A::BatchedAdjOrTrans) = A.parent
# Unary minus is applied to the wrapped array and the same wrapper is put back on.
Base.:-(A::BatchedAdjoint) = BatchedAdjoint(-parent(A))
Base.:-(A::BatchedTranspose) = BatchedTranspose(-parent(A))

# `copy` duplicates the wrapped array and re-applies the same wrapper to the copy.
function Base.copy(A::BatchedTranspose)
    return BatchedTranspose(copy(parent(A)))
end
function Base.copy(A::BatchedAdjoint)
    return BatchedAdjoint(copy(parent(A)))
end
# C interface
# Strides of the wrapper are the parent's strides with the first two entries swapped;
# the third (batch) stride is passed through unchanged.
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    s1, s2, s3 = strides(A.parent)
    return (s2, s1, s3)
end

# Single-dimension stride: dimensions 1 and 2 are exchanged before asking the parent,
# every other dimension is forwarded as given.
function Base.stride(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}}, d::Integer)
    if d == 1
        return Base.stride(A.parent, 2)
    elseif d == 2
        return Base.stride(A.parent, 1)
    else
        return Base.stride(A.parent, d)
    end
end

# The pointer of a batched wrapper is the pointer of the array it wraps.
function Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T}
    return Base.unsafe_convert(Ptr{T}, parent(A))
end

225 changes: 196 additions & 29 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,231 @@
# batch-wise matrix multiplication
# wrapper for batched_gemm!

export batched_mul, batched_transpose, batched_adjoint

include("./batchedadjtrans.jl")

using LinearAlgebra: BlasFloat, Transpose, Adjoint

# `_unbatch` returns the array inside a `BatchedAdjOrTrans` wrapper (via `parent`);
# any other argument is handed back untouched.
function _unbatch(A::BatchedAdjOrTrans)
    return parent(A)
end
_unbatch(A) = A

"""
batched_mul(A, B) -> C
Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
To transpose each matrix apply `batched_transpose` to the array,
and similarly `batched_adjoint`. Other permutations are also handled by BLAS,
provided that the batch index `k` is not the first dimension of the underlying array.
Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine.
However `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS,
since `stride(A,3) == 1`. This be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
T = promote_type(T1, T2)
C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3)))
size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
throw(DimensionMismatch("batch size mismatch: A != B"))
_batched_mul(storage_typejoin(A, B), A, B)
end

# Allocate the output and multiply. The first argument is the joined storage type of
# `A` and `B`; it only selects the method.

# Generic storage: promote the element types and multiply the arguments as given.
function _batched_mul(::Type, A, B)
    out_eltype = promote_type(eltype(A), eltype(B))
    nbatch = max(size(A, 3), size(B, 3))
    out = similar(A, out_eltype, (size(A, 1), size(B, 2), nbatch))
    batched_mul!(out, A, B)
    return out
end

# Dense storage with a BLAS element type: each argument first goes through
# `_copy_if_faster`, which may replace it by a copy.
function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat}
    nbatch = max(size(A, 3), size(B, 3))
    out = similar(A, T, (size(A, 1), size(B, 2), nbatch))
    batched_mul!(out, _copy_if_faster(A), _copy_if_faster(B))
    return out
end

# Return `X` itself, or a copy of it when `X` is strided with `stride(X, 3) == 1`
# while `stride(X, 1) != 1`. Every other input is returned unchanged.
function _copy_if_faster(X::AbstractArray{<:Number, 3})
    is_strided(X) || return X
    Base.stride(X, 3) == 1 || return X
    Base.stride(X, 1) == 1 && return X
    @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X)
    return copy(X)
end
# Complex `BatchedAdjoint`: inspect the wrapped array. `X` is copied only when that
# array is strided and its first stride is not 1; otherwise `X` is returned as-is.
function _copy_if_faster(X::BatchedAdjoint{<:Complex})
    inner = _unbatch(X)
    is_strided(inner) || return X
    Base.stride(inner, 1) == 1 && return X
    @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X))
    return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU?
end

"""
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0)
In-place batched matrix multiplication, equivalent to
`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`.
If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead.
In-place batched matrix multiplication,
equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k])` for all `k`.
This will call `batched_gemm!` whenever possible. For real arrays this means that,
for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`; the latter may
be caused by `batched_transpose` or by, for instance, `PermutedDimsArray(::Array, (3,1,2))`.
Unlike `batched_mul` this will never make a copy.
For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen.
In this case the strides accepted by BLAS are more restricted: if `stride(C,1)==1` then
only `stride(AorB::BatchedAdjoint,2) == 1` is accepted.
"""
function batched_mul! end
# Entry point for in-place batched multiplication: choose an implementation from the
# joined storage type of all three arrays, run it, and hand back `C`.
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
                      α::Number=one(T), β::Number=zero(T)) where {T}
    storage = storage_typejoin(C, A, B)
    _batched_mul!(storage, C, A, B, α, β)
    return C
end

# Strip a `BatchedAdjOrTrans` wrapper by returning its `parent` field; any other
# argument is returned unchanged.
# NOTE(review): methods with these same two signatures are also defined near the top
# of this file (there written with `parent(A)`); confirm whether both copies are meant
# to exist, since the later definition replaces the earlier one.
_unbatch(A) = A
_unbatch(A::BatchedAdjOrTrans) = A.parent
# Any storage type without a more specific method goes to the generic loop.
function _batched_mul!(::Type, C, A, B, α::Number, β::Number)
    return batched_mul_generic!(C, A, B, α, β)
end

# batched_gemm!
# Dense storage with a BLAS element type is offered to `_batched_try_gemm!`.
function _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {T<:BlasFloat, DT<:DenseArray{T}}
    return _batched_try_gemm!(DT, C, A, B, α, β)
end

const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32}
function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}

_BATCHED_GEMM_LIST = [
(:(StridedArray{T, 3}), 'N'),
(:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'),
(:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C')
]
alpha, beta = promote(α, β, zero(T))
alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)

for (TA, transA) in _BATCHED_GEMM_LIST, (TB, transB) in _BATCHED_GEMM_LIST
@eval function batched_mul!(C::StridedArray{T, 3}, A::$TA, B::$TB) where {T<:_GemmFloat}
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
C
are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)

if Base.stride(C,1) == 1
elseif Base.stride(C,2) == 1
@debug "transforming C = A * B into C' = B' * A'" size(C) strides(C)
return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β)
else
return batched_mul_generic!(C, A, B, α, β)
end

blasA, transA = if A isa BatchedAdjoint && T <: Complex
Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
parent(A), 'C'
elseif Base.stride(A,1) == 1
A, 'N'
elseif Base.stride(A,2) == 1
batched_transpose(A), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end

blasB, transB = if B isa BatchedAdjoint && T <: Complex
Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
parent(B), 'C'
elseif Base.stride(B,1) == 1
B, 'N'
elseif Base.stride(B,2) == 1
batched_transpose(B), 'T'
else
return batched_mul_generic!(C, A, B, α, β)
end

_batched_gemm!(DT, transA, transB, alpha, blasA, blasB, beta, C)
C
end

# fallback
# For `Array` storage, forward straight to `batched_gemm!`; the leading type argument
# exists only for dispatch and is dropped.
function _batched_gemm!(::Type{<:Array}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
    return batched_gemm!(transA, transB, α, A, B, β, C)
end

_BATCHED_LIST = [
(:(AbstractArray{<:Any, 3}), :identity),
(:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose),
(:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint)
(:BatchedTranspose, :transpose),
(:BatchedAdjoint, :adjoint),
]
for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
@eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB)
axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch"))

@eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB,
α::Number=one(T), β::Number=zero(T)) where {T}

size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
A′, B′ = _unbatch(A), _unbatch(B)
@inbounds for k in axes(C, 3)
@views mul!(C[:,:,k], $fA(A′[:,:,k]), $fB(B′[:,:,k]))

Abase, Bbase = _unbatch(A), _unbatch(B)
sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
sB, oB = size(B,3) == 1 ? (0,1) : (1,0)

@inbounds for k in 1:size(C,3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β)
end
C
end

end

"""
storage_type(A) -> Type
Removes all wrappers to return the `Array` or `CuArray` (or whatever) type within.
```
julia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type
Array{Float64,1}
julia> reshape(sparse(rand(10)), 5,2) |> storage_type
SparseVector{Float64,Int64}
```
"""
function storage_type(A::AbstractArray)
P = parent(A)
typeof(A) === typeof(P) ? typeof(A) : storage_type(P)
end
storage_type(A) = typeof(A)

"""
storage_typejoin(A, B, C, ...) -> Type
Reduces with `Base.promote_typejoin`, in order that this conveys useful information
for dispatching to BLAS. It does not tell you what container to allocate:
```
julia> storage_typejoin(rand(2), rand(Float32, 2))
Array{T,1} where T
julia> eltype(ans) <: LinearAlgebra.BlasFloat
false
julia> storage_typejoin(rand(2), rand(2,3), rand(2,3,4))
Array{Float64,N} where N
```
"""
storage_typejoin(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_typejoin(Bs...))
storage_typejoin(A) = storage_type(A)

"""
is_strided(A::AbstractArray) -> Bool
This generalises `A isa StridedArray` to treat wrappers like `A::PermutedDimsArray`,
for which it returns `is_strided(parent(A))`.
Other wrappers (defined outside Base, LinearAlgebra) are assumed not to break
strided-ness, and hence also return `is_strided(parent(A))`.
This correctly handles things like `NamedDimsArray` wihch don't alter indexing.
However, it's a little pessimistic in that e.g. a `view` of such a container will return
`false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`.
"""
is_strided(A::StridedArray) = true
is_strided(A) = false
function is_strided(A::AbstractArray)
M = parentmodule(typeof(A))
if parent(A) === A # SparseMatrix, StaticArray, etc
false
elseif M === Base || M === Core || M ===LinearAlgebra
# bad reshapes, etc, plus Diagonal, UpperTriangular, etc.
false
else
is_strided(parent(A)) # PermutedDimsArray, NamedDimsArray
end
end

# Transposing wrappers are as strided as what they wrap; adjoint wrappers only
# when the element type is real.
function is_strided(A::BatchedAdjoint)
    eltype(A) <: Real || return false
    return is_strided(parent(A))
end
function is_strided(A::BatchedTranspose)
    return is_strided(parent(A))
end

function is_strided(A::LinearAlgebra.Transpose)
    return is_strided(parent(A))
end
function is_strided(A::LinearAlgebra.Adjoint)
    eltype(A) <: Real || return false
    return is_strided(parent(A))
end

# `true` only when `is_strided` holds for every argument (and for no arguments at all).
# Each argument is checked; there is no short-circuit.
function are_strided(As...)
    verdict = true
    for X in As
        verdict &= is_strided(X)
    end
    return verdict
end
Loading

0 comments on commit 9780c29

Please sign in to comment.