Skip to content

Commit

Permalink
now memory_layout() checks strides if it must
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Apr 3, 2020
1 parent ae50fce commit fa103ee
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 87 deletions.
99 changes: 26 additions & 73 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ julia> strides(A2) # this can't be fixed
(50, 10, 1)
julia> C2 = batched_mul(A2, B); size(C2)
┌ Debug: couldn't re-arrange strides for batched_gemm!
│ strides(A) = (50, 10, 1)
│ strides(B) = (1, 50, 5)
│ strides(C) = (1, 4, 24)
└ @ NNlib ~/.julia/dev/NNlib/src/batched/batchedmul.jl:112
┌ Debug: calling fallback method for batched_mul!
│ typeof(A) = PermutedDimsArray{Float64,3,(3, 2, 1),(3, 2, 1),Array{Float64,3}}
│ typeof(B) = PermutedDimsArray{Float64,3,(1, 3, 2),(1, 3, 2),Array{Float64,3}}
Expand Down Expand Up @@ -82,85 +77,32 @@ and hence can only accept `α!=1` or `β!=0` on Julia >= 1.3.
"""
function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3},
        α::Number=one(T), β::Number=zero(T)) where {T}
    # Combine the storage types with promote_typejoin, so that a mixed case such as
    # Float64 * Int does not end up selecting a gemm! method.
    storage_C = storage_type(C)
    storage_AB = promote_typejoin(storage_type(A), storage_type(B))
    storage = promote_typejoin(storage_C, storage_AB)

    # The internal method dispatches on the storage type and on one layout trait per array.
    layout_C, layout_A, layout_B = memory_layout(C), memory_layout(A), memory_layout(B)
    _batched_mul!(storage, C, layout_C, A, layout_A, B, layout_B, α, β)
    return C
end

# Dispatch on storage type: CuArrays can define _batched_mul!(::CuArray, ...)
# Dispatch on ArrayLayouts traits: decide where you need 'T' etc.

# _BATCHED_GEMM_LIST = [
# (:UnitStrideFirst, 'N', :identity, :(UnitStride{2})),
# (:(UnitStride{2}), 'T', :batched_transpose, :UnitStrideFirst),
# (:(ConjLayout{UnitStride{2}}), 'C', :batched_adjoint, :Nothing)
# ]
# for (TA, tA, fA, revTA) in _BATCHED_GEMM_LIST, (TB, tB, fB, revTB) in _BATCHED_GEMM_LIST
# Table driving the generated `_batched_mul!` methods. Each entry is a tuple of
#   (layout trait used in the method signature,
#    transpose flag character handed to `batched_gemm!`,
#    name of the function applied to the array before it is passed on).
# The loop that follows unpacks one entry for A and one for B.
_BATCHED_GEMM_LIST = [
(:UnitStrideFirst, 'N', :identity),
(:(UnitStride{2}), 'T', :batched_transpose),
(:(ConjLayout{UnitStride{2}}), 'C', :batched_adjoint)
]
for (MA, tA, fA) in _BATCHED_GEMM_LIST, (MB, tB, fB) in _BATCHED_GEMM_LIST

# Path 1, e.g. C isa Array, batched_transpose(A) or PermutedDimsArray(B, (3,1,2)) both need 'T'
@eval function _batched_mul!(::Type{<:Array{T}},
C, ::UnitStrideFirst, A, ::$MA, B, ::$MB,
@eval function _batched_mul!(::Type{<:Array{T}}, C, ::UnitStrideFirst, A, ::$MA, B, ::$MB,
α::Number, β::Number) where {T<:BlasFloat}

batched_gemm!($tA, $tB, convert(T,α), $fA(A), $fB(B), convert(T,β), C)
end

# # Path 2, C = batched_transpose(Array), so transpose the entire equation
# if tA != 'C' && tB != 'C'
# not_tA = tA == 'T' ? 'N' : 'T'
# not_tB = tB == 'T' ? 'N' : 'T'
# @eval function _batched_mul!(::Type{<:Array{T}}, C, ::UnitStride{2}, A, ::$MA, B, ::$MB, α::Number, β::Number) where {T<:BlasFloat}
# @warn "this is broken!"
# @debug "transposing C, and thus A, B to compensate..." size(A) size(B) size(C) strides(A) strides(B) strides(C)
# batched_gemm!($not_tB, $not_tA, convert(T,α), $fB(B), $fA(A), convert(T,β), batched_transpose(C))
# end
# end
end

# Path 3, use runtime strides. Does not catch ConjLayout{StridedLayout}()
# function _batched_mul!(TC::Type{<:AbstractArray{T}}, C, ::AbstractStridedLayout, A, ::AbstractStridedLayout, B, ::AbstractStridedLayout, B, α::Number, β::Number) where {T<:BlasFloat}
function _batched_mul!(TC::Type{<:AbstractArray{T}},
        C, MC::UnitStrideFirst, A, ::AbstractStridedLayout, B, ::AbstractStridedLayout,
        α::Number, β::Number) where {T<:BlasFloat}
    # The traits of A and B only say "strided" here, so inspect the actual strides,
    # choose a unit-stride trait for each, and dispatch again with those.
    @debug "using runtime strides" strides(A) strides(B) strides(C)

    layout_A = if Base.stride(A, 1) == 1
        UnitStride{1}()
    elseif Base.stride(A, 2) == 1
        UnitStride{2}()
    else
        # Neither of the first two strides of A is 1: hand over to the generic fallback.
        return batched_mul_generic!(C, A, B, α, β)
    end

    layout_B = if Base.stride(B, 1) == 1
        UnitStride{1}()
    elseif Base.stride(B, 2) == 1
        UnitStride{2}()
    else
        # Same for B.
        return batched_mul_generic!(C, A, B, α, β)
    end

    # C keeps the trait it was dispatched with (MC); its strides are not re-examined.
    # Conjugated layouts are not adjusted here either.
    return _batched_mul!(TC, C, MC, A, layout_A, B, layout_B, α, β)
end

# Path 4, anything else goes directly to the fallback
function _batched_mul!(::Type{<:AbstractArray},
C, ::MemoryLayout, A, ::MemoryLayout, B, ::MemoryLayout, α::Number, β::Number)

function _batched_mul!(::Type{<:AbstractArray}, C, ::MemoryLayout, A, ::MemoryLayout, B, ::MemoryLayout,
α::Number, β::Number)
batched_mul_generic!(C, A, B, α, β)
end

# Fallback: only here do we look directly at BatchedTranspose etc.
# Fallback: only here do we look directly at types BatchedTranspose etc.

# Remove a batched adjoint/transpose wrapper by returning its `parent` field;
# every other argument is returned unchanged.
function _unbatch(A::BatchedAdjOrTrans)
    return A.parent
end

function _unbatch(A)
    return A
end
Expand All @@ -174,11 +116,9 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST

@eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB,
α::Number=one(T), β::Number=zero(T)) where {T}

axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch"))
@debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C)
Abase, Bbase = _unbatch(A), _unbatch(B)

if VERSION >= v"1.3"
@inbounds for k in axes(C, 3)
@views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]), convert(T,α), convert(T,β))
Expand All @@ -189,7 +129,6 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
@views mul!(C[:,:,k], $fA(Abase[:,:,k]), $fB(Bbase[:,:,k]))
end
end

C
end

Expand Down Expand Up @@ -221,8 +160,9 @@ storage_type(A) = typeof(A)
This is usually `ArrayLayouts.MemoryLayout(A)`.
The exception is that, for wrapper types which that package does not know about,
and for which `parent(A)` has any `AbstractStridedLayout`, it returns `StridedLayout()`.
(And if parent(A) is conjugated, then `ConjLayout{StridedLayout}()`.)
and for which `parent(A)` has any `AbstractStridedLayout`,
it will use `strides(A)` to return `UnitStride{1}()`, `UnitStride{2}()`, or `StridedLayout()`.
(And if `parent(A)` is conjugated, then `ConjLayout{UnitStride{1}}()` etc.)
"""
function memory_layout(A)
    # Start from the ArrayLayouts trait, then let `_memory_layout` refine it.
    trait = MemoryLayout(A)
    return _memory_layout(A, trait)
end

Expand All @@ -231,14 +171,27 @@ _memory_layout(A, M::ConjLayout{<:AbstractStridedLayout}) = M

function _memory_layout(A, ::MemoryLayout)
P = parent(A)
if typeof(A) === typeof(P)
UnknownLayout()
elseif MemoryLayout(P) isa AbstractStridedLayout
StridedLayout()
typeof(A) === typeof(P) && return UnknownLayout()
# Now it's a wrapper. If it contains something strided,
# then we go by the strides of A, since those of P may be re-ordered.
if MemoryLayout(P) isa AbstractStridedLayout
@debug "using runtime strides" typeof(A) strides(A)
return _find_unit_stride(A)
elseif MemoryLayout(P) isa ConjLayout{<:AbstractStridedLayout}
ConjLayout{StridedLayout}()
@debug "using runtime strides, parent is conjugated" typeof(A) strides(A)
return ArrayLayouts.conjlayout(eltype(A), _find_unit_stride(A))
else
UnknownLayout()
return UnknownLayout()
end
end

"""
    _find_unit_stride(A)

Classify `A` by looking at `strides(A)` at run time.

Returns `UnitStride{1}()` if the first stride is 1, `UnitStride{2}()` if the second
stride is 1, and `StridedLayout()` in every other case, including when `A` has too
few dimensions for either check.
"""
function _find_unit_stride(A)
    s = Base.strides(A)
    # Guard each lookup by the number of strides available, so that zero- and
    # one-dimensional arrays fall through to `StridedLayout()` rather than
    # raising a `BoundsError` on `s[1]` or `s[2]`.
    if length(s) >= 1 && s[1] == 1
        return UnitStride{1}()
    elseif length(s) >= 2 && s[2] == 1
        return UnitStride{2}()
    else
        return StridedLayout()
    end
end
17 changes: 3 additions & 14 deletions test/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::TestWrap{T}) where {T} =
@test memory_layout(PermutedDimsArray(A, (2,1,3))) == UnitStride{2}()
@test memory_layout(PermutedDimsArray(A, (2,3,1))) == UnitStride{3}()

@test memory_layout(TestWrap(A)) == StridedLayout()
@test memory_layout(TestWrap(batched_transpose(A))) == StridedLayout()
@test memory_layout(TestWrap(batched_adjoint(A))) == ConjLayout{StridedLayout}()
@test memory_layout(TestWrap(A)) == UnitStride{1}()
@test memory_layout(TestWrap(batched_transpose(A))) == UnitStride{2}()
@test memory_layout(TestWrap(batched_adjoint(A))) == ConjLayout{UnitStride{2}}()
@test stride(TestWrap(A),3) == stride(A,3)

@test storage_type(TestWrap(A)) == typeof(A)
Expand Down Expand Up @@ -182,17 +182,6 @@ end
end

end
# @testset "batched_mul! with permuted output" begin # this is broken!

# A = rand(3,3,3)
# B = rand(3,3,3)
# C = PermutedDimsArray(zeros(3,3,3), (2,1,3))
# @test_broken batched_mul(A, B) ≈ batched_mul!(C, B, A)

# B = batched_adjoint(rand(3,3,3))
# C = PermutedDimsArray(zeros(3,3,3), (3,1,2))
# @test_broken batched_mul(A, B) ≈ batched_mul!(C, B, A)

# end
end
end

0 comments on commit fa103ee

Please sign in to comment.