Skip to content

Commit

Permalink
use _batched_gemm and storage_type from FluxML/NNlib.jl#191
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Apr 3, 2020
1 parent 8bc792b commit 0c6f5c6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 114 deletions.
105 changes: 4 additions & 101 deletions src/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,109 +32,12 @@ end


# Batched matrix multiplication
# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191

# This method has a slightly tighter signature than the one in NNlib, all same eltype.
function NNlib.batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T<:CUBLAS.CublasFloat}
if is_strided_cu(A) && is_strided_cu(B) && is_strided_cu(C)
# Data is on GPU, and it's safe to call strides(A). gemm_strided_batched may be legal.
batched_try_gemm!(C, A, B)

elseif is_strided_cu(A) || is_strided_cu(B) || is_strided_cu(C)
# This is hopeless, but best option is the fallback
@debug "weird mix of CPU + GPU?"
NNlib.batched_mul_generic!(C, A, B)

else
# All cases for CPU gemm! will come through here, is_strided_cu(A) compiles away:
NNlib.batched_mul_cpu!(C, A, B)
end
end

const batched_gemm_args = [
(:(AbstractArray{T, 3}), 'N', :identity),
(:(NNlib.BatchedTranspose{T}), 'T', :batched_transpose),
(:(NNlib.BatchedAdjoint{T}), 'C', :batched_adjoint)
]

using NNlib: batched_mul!, BatchedTranspose, BatchedAdjoint, batched_transpose, batched_adjoint
using NNlib: _unbatch, _perm12

for (TA, transA, fA) in batched_gemm_args, (TB, transB, fB) in batched_gemm_args
@eval function batched_try_gemm!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}

Abase, Bbase = _unbatch(A), _unbatch(B)

# Best case, we can call batched_gemm! immediately:
if Base.stride(Abase,1) == Base.stride(Bbase,1) == Base.stride(C,1) == 1
CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), Abase, Bbase, zero(T), C)

# Second-best, can we fix it by Perm.ing the base, and adjusing 'T' label?
# But only if we won't produce BatchedTranspose(BatchedAdjoint(complex array)).
elseif Base.stride(Abase,2) == 1 && !(T<:Complex && $TA<:BatchedAdjoint)
newAbase = batched_transpose(_perm12(Abase))
return batched_try_gemm!(C, $fA(newAbase), B)

elseif Base.stride(Bbase,2) == 1 && !(T<:Complex && $TB<:BatchedAdjoint)
newBbase = batched_transpose(_perm12(Bbase))
return batched_try_gemm!(C, A, $fB(newBbase))

# Fallback, e.g when Base.stride(A,3)==1
else
@debug "couldn't re-arrange strides for CUBLAS.gemm_strided_batched!" strides(A) strides(B) strides(C)
NNlib.batched_mul_generic!(C, A, B)
end
C
end
end


# This is obviously the wrong place for this! Not sure where it should go.
# Recursive version, will handle e.g. NamedDimsArray
function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T}
if A === parent(A)
throw(MethodError(Base.unsafe_convert, Tuple{CUDAdrv.CuPtr{T}, typeof(A)}))
else
return Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A))
end
end

NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
CuArrays.CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)

# This is https://github.com/JuliaLang/julia/pull/35304, here just for testing now:
Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims)
# @which Base.similar(PermutedDimsArray(rand(2,2), (2,1)), Int, Base.Dims{2}((3,3)))


# Also the wong place for this, surely.
"""
is_strided_cu(A)
This should return `true` for `A::CuArray`, and also for:
* Any `view(::CuArray)` or `reshape(::CuArray)` etc. which remains a `StridedArray`
* Any other wrapper for which `is_strided_cu(parent(A))`
* Except that `Adjoint(A)` is only unwrapped for real numbers.
Such wrappers include `PermutedDimsArray(::CuArray, ...)`,
but also those defined elsewhere (such as `NamedDimsArray`s)
which are assumed not to break strided-ness.
`Transpose` and `Adjoint` don't currently define `strides`, so for now they return `false`.
"""
is_strided_cu(A::CuArray) = true
is_strided_cu(A) = false
function is_strided_cu(A::AbstractArray)
M = parentmodule(typeof(A))
if parent(A) === A # Array, SparseMatrix, StaticArray
false
elseif M === Base || M === Core || M ===LinearAlgebra
A isa StridedArray && is_strided_cu(parent(A))
else
is_strided_cu(parent(A)) # PermutedDimsArray, NamedDimsArray
end
end

if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose})
is_strided_cu(A::LinearAlgebra.Transpose) = is_strided(parent(A))
is_strided_cu(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A))
else
is_strided_cu(A::LinearAlgebra.Transpose) = false
is_strided_cu(A::LinearAlgebra.Adjoint) = false
end
23 changes: 10 additions & 13 deletions test/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,22 @@
@test cu(Ca) batched_mul(cu(A), batched_adjoint(cu(B)))
end

using CuArrays: is_strided_cu
using NNlib: is_strided, are_strided, storage_type
using LinearAlgebra
@testset "is_strided_cu" begin
@testset "NNlib storage_type etc." begin

M = cu(ones(10,10))

@test is_strided_cu(M)
@test is_strided_cu(view(M, 1:2:5,:))
@test is_strided_cu(PermutedDimsArray(M, (2,1)))
@test is_strided(M)
@test is_strided(view(M, 1:2:5,:))
@test is_strided(PermutedDimsArray(M, (2,1)))

@test !is_strided_cu(reshape(view(M, 1:2:10,:), 10,:))
@test !is_strided_cu((M.+im)')
@test !is_strided_cu(ones(10,10))
@test !is_strided_cu(Diagonal(ones(3)))
@test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
@test !is_strided((M.+im)')
@test !is_strided(Diagonal(cu(ones(3))))

#=
using NamedDims
@test is_strided(NamedDimsArray(M,(:a, :b))) # and 0.029 ns, 0 allocations
=#
@test storage_type(M) == CuArray{Float32,2,Nothing}
@test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing}

end

Expand Down

0 comments on commit 0c6f5c6

Please sign in to comment.