Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to batched_mul, including PermutedDimsArray #187

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
# This file is machine-generated - editing it directly is not advised

[[ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "5a57a6158c1d340635a89d19beb34b0f325a4431"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.2.5"

[[BinaryProvider]]
deps = ["Libdl", "SHA"]
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.8"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "51cc2f9bc4eb9c6c0e81ec2f779d1085583cc956"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.7"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.6"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ArrayLayouts = "0.2.5"
BinaryProvider = "0.5"
Requires = "0.5, 1.0"
julia = "1"
Expand Down
41 changes: 35 additions & 6 deletions src/batched/batchedadjtrans.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearAlgebra
using LinearAlgebra, ArrayLayouts

import Base: -

_batched_doc = """
Expand All @@ -10,10 +11,13 @@ Equivalent to applying `transpose` or `adjoint` to each matrix `A[:,:,k]`.
These exist to control how `batched_mul` behaves,
as it operates on such matrix slices of an array with `ndims(A)==3`.

BatchedTranspose{T, N, S} <: AbstractBatchedMatrix{T, N}
BatchedAdjoint{T, N, S}
For arrays of real numbers, `batched_transpose(A) == PermutedDimsArray(A, (2,1,3))`,
which is a more widely-supported wrapper, and also understood by `batched_mul`.

BatchedTranspose{T, S} <: AbstractBatchedMatrix{T, 3}
BatchedAdjoint{T, S}

Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose`.
Lazy wrappers analogous to `Transpose` and `Adjoint`, returned by `batched_transpose` etc.
"""

@doc _batched_doc
Expand All @@ -36,6 +40,13 @@ end
# Wrap a plain 3-dimensional array; applying the adjoint twice returns the original.
function batched_adjoint(A::AbstractArray{T,3}) where {T}
    return BatchedAdjoint(A)
end
function batched_adjoint(A::BatchedAdjoint)
    return parent(A)
end

# For real element types the two wrappers are unwrapped by either function,
# and so is the equivalent PermutedDimsArray with permutation (2,1,3).
function batched_adjoint(A::BatchedTranspose{<:Real})
    return parent(A)
end
function batched_transpose(A::BatchedAdjoint{<:Real})
    return parent(A)
end
function batched_adjoint(A::PermutedDimsArray{<:Real,3,(2,1,3)})
    return parent(A)
end
function batched_transpose(A::PermutedDimsArray{<:Number,3,(2,1,3)})
    return parent(A)
end

# If you can't unwrap, put BatchedAdjoint outside (for dispatch):
function batched_transpose(A::BatchedAdjoint{<:Complex})
    return BatchedAdjoint(BatchedTranspose(parent(A)))
end

# Outer constructors: the wrapper's element type is whatever `adjoint` / `transpose`
# would produce from an element of `A`, and the second parameter is the parent's type.
function BatchedAdjoint(A)
    T = Base.promote_op(adjoint, eltype(A))
    return BatchedAdjoint{T,typeof(A)}(A)
end
function BatchedTranspose(A)
    T = Base.promote_op(transpose, eltype(A))
    return BatchedTranspose{T,typeof(A)}(A)
end

Expand Down Expand Up @@ -65,6 +76,24 @@ Base.parent(A::BatchedAdjOrTrans) = A.parent
# Negation and copying act on the wrapped array, then re-apply the same wrapper.
function (-)(A::BatchedAdjoint)
    return BatchedAdjoint(-parent(A))
end
function (-)(A::BatchedTranspose)
    return BatchedTranspose(-parent(A))
end

function Base.copy(A::BatchedTranspose)
    return BatchedTranspose(copy(parent(A)))
end
function Base.copy(A::BatchedAdjoint)
    return BatchedAdjoint(copy(parent(A)))
end
# C interface
# Strides of the wrapper: those of the parent with the first two dimensions exchanged.
function Base.strides(A::BatchedAdjOrTrans)
    s1, s2, s3 = strides(A.parent)
    return (s2, s1, s3)
end

# Single-dimension stride: dimensions 1 and 2 are looked up swapped in the parent,
# any other dimension is forwarded unchanged.
function Base.stride(A::BatchedAdjOrTrans, d::Integer)
    if d == 1
        return Base.stride(A.parent, 2)
    elseif d == 2
        return Base.stride(A.parent, 1)
    else
        return Base.stride(A.parent, d)
    end
end

# The pointer of a wrapper is the pointer of the array it wraps.
function Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T}
    return Base.unsafe_convert(Ptr{T}, parent(A))
end

# Layout of a batched transpose: the parent's layout with dimensions permuted by (2,1,3).
function ArrayLayouts.MemoryLayout(::Type{BatchedTranspose{T,S}}) where {T,S}
    parent_layout = MemoryLayout(S)
    return ArrayLayouts.permutelayout(parent_layout, Val((2,1,3)))
end

# Layout of a batched adjoint: as above, but applied to the conjugated parent layout.
function ArrayLayouts.MemoryLayout(::Type{BatchedAdjoint{T,S}}) where {T,S}
    conj_parent_layout = ArrayLayouts.conjlayout(T, MemoryLayout(S))
    return ArrayLayouts.permutelayout(conj_parent_layout, Val((2,1,3)))
end

192 changes: 166 additions & 26 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,64 @@
# batch-wise matrix multiplication
# wrapper for batched_gemm!

export batched_mul, batched_transpose, batched_adjoint

using LinearAlgebra: BlasFloat, BlasReal

using Base: promote_typejoin

using ArrayLayouts: MemoryLayout, UnitStride, AbstractColumnMajor, ConjLayout, StridedLayout, UnknownLayout, AbstractStridedLayout

const UnitStrideFirst = Union{UnitStride{1}, AbstractColumnMajor}
const MaybeConjStrided = Union{AbstractStridedLayout, ConjLayout{<:AbstractStridedLayout}}

include("./batchedadjtrans.jl")

"""
batched_mul(A, B) -> C

Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.

Using `batched_transpose(A)` will transpose each `A[:,:,k]`,
and similarly `batched_adjoint(B)` will use `adjoint(B[:,:,k])`.

It will also accept `A` or `B` which are `PermutedDimsArray{T,3}`.
On the CPU, these will still be handled by `BLAS.gemm!` provided `T <: LinearAlgebra.BlasFloat`
and they can be permuted to be column-major. For `T <: Real`, this allows any permutations
so long as `Base.stride(A,3) != 1` and `Base.stride(B,3) != 1`.
(For `T <: Complex`, instead you must have `Base.stride(A,1) == 1 == Base.stride(B,1)`.)

Other cases will fall back to `batched_mul_generic!`, which logs a message via `@debug`.
```
julia> A = PermutedDimsArray(rand(5,4,10), (2,1,3)); size(A)
(4, 5, 10)

julia> strides(A) # this will be absorbed by transposing
(5, 1, 20)

julia> B = PermutedDimsArray(rand(5,10,6), (1,3,2)); size(B)
(5, 6, 10)

julia> strides(B) # this is fine as it is
(1, 50, 5)

julia> ENV["JULIA_DEBUG"] = NNlib;

julia> C = batched_mul(A, B); size(C) # done by batched_gemm!
(4, 6, 10)

julia> A2 = PermutedDimsArray(rand(10,5,4), (3,2,1)); size(A2)
(4, 5, 10)

julia> strides(A2) # this can't be fixed
(50, 10, 1)

julia> C2 = batched_mul(A2, B); size(C2)
┌ Debug: calling fallback method for batched_mul!
│ typeof(A) = PermutedDimsArray{Float64,3,(3, 2, 1),(3, 2, 1),Array{Float64,3}}
│ typeof(B) = PermutedDimsArray{Float64,3,(1, 3, 2),(1, 3, 2),Array{Float64,3}}
│ typeof(C) = Array{Float64,3}
└ @ NNlib ~/.julia/dev/NNlib/src/batched/batchedmul.jl:133
(4, 6, 10)
```
"""
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
Expand All @@ -17,48 +68,137 @@ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1
end

"""
batched_mul!(C, A, B) -> C
batched_mul!(C, A, B, α=1, β=0) -> C
batched_mul_generic!(C, A, B, α=1, β=0)

In-place batched matrix multiplication,
equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k])` for all `k`.
"""
function batched_mul! end

_unbatch(A) = A
_unbatch(A::BatchedAdjOrTrans) = A.parent

# batched_gemm!
equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`.

const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32}
The fallback implementation of this literally calls `mul!`,
and hence can only accept `α!=1` or `β!=0` on Julia >= 1.3.
"""
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
        α::Number=one(T), β::Number=zero(T)) where {T}
    # promote_typejoin (rather than promote_type) is used on the underlying storage
    # types, to ensure Float64 * Int doesn't go to gemm!
    storage_AB = promote_typejoin(storage_type(A), storage_type(B))
    storage = promote_typejoin(storage_type(C), storage_AB)
    # Dispatch on the joined storage type and on the memory layout of each argument.
    _batched_mul!(storage, C, memory_layout(C), A, memory_layout(A), B, memory_layout(B), α, β)
    return C
end

_BATCHED_GEMM_LIST = [
(:(StridedArray{T, 3}), 'N'),
(:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'),
(:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C')
(:UnitStrideFirst, 'N', :identity),
(:(UnitStride{2}), 'T', :batched_transpose),
(:(ConjLayout{UnitStride{2}}), 'C', :batched_adjoint)
]
for (MA, tA, fA) in _BATCHED_GEMM_LIST, (MB, tB, fB) in _BATCHED_GEMM_LIST

for (TA, transA) in _BATCHED_GEMM_LIST, (TB, transB) in _BATCHED_GEMM_LIST
@eval function batched_mul!(C::StridedArray{T, 3}, A::$TA, B::$TB) where {T<:_GemmFloat}
batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C)
C
@eval function _batched_mul!(::Type{<:Array{T}}, C, ::UnitStrideFirst, A, ::$MA, B, ::$MB,
α::Number, β::Number) where {T<:BlasFloat}
batched_gemm!($tA, $tB, convert(T,α), $fA(A), $fB(B), convert(T,β), C)
end

end

# Output `C` has layout `UnitStride{2}`: compute the transposed product instead,
# with the factors transposed and their order exchanged, then dispatch again.
function _batched_mul!(::Type{<:AbstractArray{T}}, C, ::UnitStride{2},
        A, ::MaybeConjStrided, B, ::MaybeConjStrided, α::Number, β::Number) where {T<:BlasFloat}
    Ct = batched_transpose(C)
    Bt = batched_transpose(B)
    At = batched_transpose(A)
    return batched_mul!(Ct, Bt, At, α, β)
end

# fallback
# Any storage type and any combination of layouts not matched by a more specific
# method is handed to the generic implementation.
_batched_mul!(::Type{<:AbstractArray}, C, ::MemoryLayout, A, ::MemoryLayout, B, ::MemoryLayout,
    α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β)

# Fallback: only here do we look directly at types BatchedTranspose etc.

# Strip a BatchedAdjoint / BatchedTranspose wrapper; anything else is returned as-is.
function _unbatch(A)
    return A
end
function _unbatch(A::BatchedAdjOrTrans)
    return A.parent
end

_BATCHED_LIST = [
(:(AbstractArray{<:Any, 3}), :identity),
(:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose),
(:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint)
(:BatchedTranspose, :transpose),
(:BatchedAdjoint, :adjoint),
]
for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
@eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB)

@eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB,
α::Number=one(T), β::Number=zero(T)) where {T}
axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch"))
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
A′, B′ = _unbatch(A), _unbatch(B)
@inbounds for k in axes(C, 3)
@views mul!(C[:,:,k], $fA(A′[:,:,k]), $fB(B′[:,:,k]))
Abase, Bbase = _unbatch(A), _unbatch(B)
if VERSION >= v"1.3"
@inbounds for k in axes(C, 3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]), convert(T,α), convert(T,β))
end
else
α==1 && β==0 || throw(ArgumentError("5-arg batched_mul_generic! does not work on Julia < 1.3"))
@inbounds for k in axes(C, 3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]))
end
end
C
end

end


"""
    storage_type(A)

Return the type of the innermost array found by calling `parent` repeatedly on `A`,
stopping as soon as `parent` gives back something of the same type (such as an `Array`).
An argument which is not an `AbstractArray` simply gives `typeof(A)`.
```
julia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type
Array{Float64,1}

julia> reshape(sparse(rand(10)), 5,2) |> storage_type
SparseVector{Float64,Int64}
```
"""
function storage_type(A::AbstractArray)
    inner = parent(A)
    # `parent` returning the same type means there is no wrapper left to remove.
    if typeof(inner) === typeof(A)
        return typeof(A)
    end
    return storage_type(inner)
end
storage_type(A) = typeof(A)


"""
    memory_layout(A)

Like `ArrayLayouts.MemoryLayout(A)`, and identical to it whenever that already reports
a strided layout (or a conjugated strided one).

Otherwise, if `A` is a wrapper whose `parent(A)` has such a layout, the answer is read
from `strides(A)` at run-time: `UnitStride{1}()`, `UnitStride{2}()`, or `StridedLayout()`,
passed through `ArrayLayouts.conjlayout` when the parent's layout is conjugated.
Everything else gives `UnknownLayout()`.
"""
memory_layout(A) = _memory_layout(A, MemoryLayout(A))

# Layouts which are already (conjugated) strided are returned untouched.
_memory_layout(A, M::AbstractStridedLayout) = M
_memory_layout(A, M::ConjLayout{<:AbstractStridedLayout}) = M

function _memory_layout(A, ::MemoryLayout)
    inner = parent(A)
    # Not a wrapper: there is nothing further to inspect.
    typeof(inner) === typeof(A) && return UnknownLayout()
    # Now it's a wrapper. If it contains something strided,
    # then we go by the strides of A, since those of the parent may be re-ordered.
    inner_layout = MemoryLayout(inner)
    if inner_layout isa AbstractStridedLayout
        @debug "using runtime strides" typeof(A) strides(A)
        return _find_unit_stride(A)
    end
    if inner_layout isa ConjLayout{<:AbstractStridedLayout}
        @debug "using runtime strides, parent is conjugated" typeof(A) strides(A)
        return ArrayLayouts.conjlayout(eltype(A), _find_unit_stride(A))
    end
    return UnknownLayout()
end

# Classify `A` by which of its first two strides equals 1, checking the first one first;
# if neither does, only a generic strided layout can be reported.
function _find_unit_stride(A)
    st = Base.strides(A)
    st[1] == 1 && return UnitStride{1}()
    (ndims(A) >= 2 && st[2] == 1) && return UnitStride{2}()
    return StridedLayout()
end
6 changes: 3 additions & 3 deletions src/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ for (gemm, elt) in gemm_datatype_mappings
ptrB, max(1,Base.stride(B,2)), beta, ptrC,
max(1,Base.stride(C,2)))

ptrA += size(A, 1) * size(A, 2) * sizeof($elt)
ptrB += size(B, 1) * size(B, 2) * sizeof($elt)
ptrC += size(C, 1) * size(C, 2) * sizeof($elt)
ptrA += Base.stride(A, 3) * sizeof($elt)
ptrB += Base.stride(B, 3) * sizeof($elt)
ptrC += Base.stride(C, 3) * sizeof($elt)
end

C
Expand Down
Loading