Commit: choose method using is_strided_cu(A)

Michael Abbott committed Apr 1, 2020
1 parent 20851f2, commit 9348904
Showing 4 changed files with 274 additions and 14 deletions.
91 changes: 91 additions & 0 deletions src/CuArrays.jl.save
@@ -0,0 +1,91 @@
module CuArrays

using CUDAapi, CUDAdrv, CUDAnative

using GPUArrays

export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
export CUBLAS, CUSPARSE, CUSOLVER, CUFFT, CURAND, CUDNN, CUTENSOR

import LinearAlgebra

using Adapt

using Libdl

using Requires


## source code includes

include("bindeps.jl")

# core array functionality
include("memory.jl")
include("array.jl")
include("gpuarrays.jl")
include("subarray.jl")
include("utils.jl")

# integrations and specialized functionality
include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("accumulate.jl")
include("linalg.jl")

# vendor libraries
include("blas/CUBLAS.jl")
include("sparse/CUSPARSE.jl")
include("solver/CUSOLVER.jl")
include("fft/CUFFT.jl")
include("rand/CURAND.jl")
include("dnn/CUDNN.jl")
include("tensor/CUTENSOR.jl")

include("deprecated.jl")

include("nnlib.jl")

## initialization

const __initialized__ = Ref(false)
functional() = __initialized__[]

export has_cudnn, has_cutensor
has_cudnn() = Libdl.dlopen_e(CUDNN.libcudnn[]) !== C_NULL
has_cutensor() = Libdl.dlopen_e(CUTENSOR.libcutensor[]) !== C_NULL

function __init__()
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
silent = parse(Bool, get(ENV, "JULIA_CUDA_SILENT", "false")) || precompiling
verbose = parse(Bool, get(ENV, "JULIA_CUDA_VERBOSE", "false"))

# if any dependent GPU package failed, expect it to have logged an error and bail out
if !CUDAdrv.functional() || !CUDAnative.functional()
verbose && @warn "CuArrays.jl did not initialize because CUDAdrv.jl or CUDAnative.jl failed to"
return
end

try
__init_bindeps__(silent=silent, verbose=verbose)

# package integrations
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")

__init_memory__()

__initialized__[] = true
catch ex
# don't actually fail to keep the package loadable
if !silent
if verbose
@error "CuArrays.jl failed to initialize" exception=(ex, catch_backtrace())
else
@info "CuArrays.jl failed to initialize and will be unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)"
end
end
end
end

end # module
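The `__init__` above decides how loudly to report a failed initialization from two environment variables. A minimal usage sketch follows (illustrative, not part of this commit; only the names read by the code above are assumed):

ENV["JULIA_CUDA_VERBOSE"] = "true"   # ask __init__ for the full error plus backtrace
# ENV["JULIA_CUDA_SILENT"] = "true"  # or suppress the failure message entirely

using CuArrays                        # __init__ runs at load time, so set ENV first
CuArrays.functional()                 # true only once __init__ completed
has_cudnn(), has_cutensor()           # optional libraries, probed via Libdl.dlopen_e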
91 changes: 91 additions & 0 deletions src/CuArrays.jl.save.1
@@ -0,0 +1,91 @@
module CuArrays

using CUDAapi, CUDAdrv, CUDAnative

using GPUArrays

export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
export CUBLAS, CUSPARSE, CUSOLVER, CUFFT, CURAND, CUDNN, CUTENSOR

import LinearAlgebra

using Adapt

using Libdl

using Requires


## source code includes

include("bindeps.jl")

# core array functionality
include("memory.jl")
include("array.jl")
include("gpuarrays.jl")
include("subarray.jl")
include("utils.jl")

# integrations and specialized functionality
include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("accumulate.jl")
include("linalg.jl")

# vendor libraries
include("blas/CUBLAS.jl")
include("sparse/CUSPARSE.jl")
include("solver/CUSOLVER.jl")
include("fft/CUFFT.jl")
include("rand/CURAND.jl")
include("dnn/CUDNN.jl")
include("tensor/CUTENSOR.jl")

include("deprecated.jl")

include("nnlib.jl")

## initialization

const __initialized__ = Ref(false)
functional() = __initialized__[]

export has_cudnn, has_cutensor
has_cudnn() = Libdl.dlopen_e(CUDNN.libcudnn[]) !== C_NULL
has_cutensor() = Libdl.dlopen_e(CUTENSOR.libcutensor[]) !== C_NULL

function __init__()
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
silent = parse(Bool, get(ENV, "JULIA_CUDA_SILENT", "false")) || precompiling
verbose = parse(Bool, get(ENV, "JULIA_CUDA_VERBOSE", "false"))

# if any dependent GPU package failed, expect it to have logged an error and bail out
if !CUDAdrv.functional() || !CUDAnative.functional()
verbose && @warn "CuArrays.jl did not initialize because CUDAdrv.jl or CUDAnative.jl failed to"
return
end

try
__init_bindeps__(silent=silent, verbose=verbose)

# package integrations
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")

__init_memory__()

__initialized__[] = true
catch ex
# don't actually fail to keep the package loadable
if !silent
if verbose
@error "CuArrays.jl failed to initialize" exception=(ex, catch_backtrace())
else
@info "CuArrays.jl failed to initialize and will be unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)"
end
end
end
end

end # module
84 changes: 70 additions & 14 deletions src/nnlib.jl
@@ -33,43 +33,62 @@

# Batched matrix multiplication

# This method has a slightly tighter signature than the one in NNlib: all three arrays share the same eltype.
function NNlib.batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T<:CUBLAS.CublasFloat}
if is_strided_cu(A) && is_strided_cu(B) && is_strided_cu(C)
# Data is on GPU, and it's safe to call strides(A). gemm_strided_batched may be legal.
batched_try_gemm!(C, A, B)

elseif is_strided_cu(A) || is_strided_cu(B) || is_strided_cu(C)
# This is hopeless, but best option is the fallback
@debug "weird mix of CPU + GPU?"
NNlib.batched_mul_generic!(C, A, B)

else
# All cases for CPU gemm! will come through here, is_strided_cu(A) compiles away:
NNlib.batched_mul_cpu!(C, A, B)
end
end

const batched_gemm_args = [
-    (:(AbstractArray{T, 3}), 'N', :identity),
-    (:(NNlib.BatchedTranspose{T, <:AbstractArray{T, 3}}), 'T', :batched_transpose),
-    (:(NNlib.BatchedAdjoint{T, <:AbstractArray{T, 3}}), 'C', :batched_adjoint)
+    (:(AbstractArray{T, 3}), 'N', :identity),
+    (:(NNlib.BatchedTranspose{T}), 'T', :batched_transpose),
+    (:(NNlib.BatchedAdjoint{T}), 'C', :batched_adjoint)
]

using NNlib: batched_mul!, BatchedTranspose, BatchedAdjoint, batched_transpose, batched_adjoint
using NNlib: _unbatch, _perm12

for (TA, transA, fA) in batched_gemm_args, (TB, transB, fB) in batched_gemm_args
- @eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
+ @eval function batched_try_gemm!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}

- Abase, Bbase = NNlib._unbatch(A), NNlib._unbatch(B)
+ Abase, Bbase = _unbatch(A), _unbatch(B)

# Best case, we can call batched_gemm! immediately:
if Base.stride(Abase,1) == Base.stride(Bbase,1) == Base.stride(C,1) == 1
- CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
+ CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), Abase, Bbase, zero(T), C)

# Second-best, can we fix it by Perm.ing the base, and adjusting the 'T' label?
# But only if we won't produce BatchedTranspose(BatchedAdjoint(complex array)).
- elseif Base.stride(Abase,2) == 1 && !(T<:Complex && $TA<:NNlib.BatchedAdjoint)
-     newAbase = NNlib.batched_transpose(PermutedDimsArray(Abase, (2,1,3)))
-     return NNlib.batched_mul!(C, $fA(newAbase), B)
- elseif Base.stride(Bbase,2) == 1 && !(T<:Complex && $TB<:NNlib.BatchedAdjoint)
-     newBbase = NNlib.batched_transpose(PermutedDimsArray(Bbase, (2,1,3)))
-     return NNlib.batched_mul!(C, A, $fB(newBbase))
+ elseif Base.stride(Abase,2) == 1 && !(T<:Complex && $TA<:BatchedAdjoint)
+     newAbase = batched_transpose(_perm12(Abase))
+     return batched_try_gemm!(C, $fA(newAbase), B)
+
+ elseif Base.stride(Bbase,2) == 1 && !(T<:Complex && $TB<:BatchedAdjoint)
+     newBbase = batched_transpose(_perm12(Bbase))
+     return batched_try_gemm!(C, A, $fB(newBbase))

# Fallback, e.g. when Base.stride(A,3)==1
else
@debug "couldn't re-arrange strides for CUBLAS.gemm_strided_batched!" strides(A) strides(B) strides(C)
NNlib.batched_mul_generic!(C, A, B)
end
C
end
end


# This is obviously the wrong place for this! Not sure where it should go.
- # Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::PermutedDimsArray) where {T} =
- #     Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A))
# Recursive version, will handle e.g. NamedDimsArray
function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T}
if A === parent(A)
@@ -82,3 +101,40 @@

# This is https://github.com/JuliaLang/julia/pull/35304, here just for testing now:
Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims)


# Also the wrong place for this, surely.
"""
is_strided_cu(A)
This should return `true` for `A::CuArray`, and also for:
* Any `view(::CuArray)` or `reshape(::CuArray)` etc. which remains a `StridedArray`
* Any other wrapper for which `is_strided_cu(parent(A))` is true
* Except that `Adjoint(A)` is only unwrapped for real numbers.
Such wrappers include `PermutedDimsArray(::CuArray, ...)`,
but also those defined elsewhere (such as `NamedDimsArray`s)
which are assumed not to break strided-ness.
`Transpose` and `Adjoint` don't currently define `strides`, so for now they return `false`.
"""
is_strided_cu(A::CuArray) = true
is_strided_cu(A) = false
function is_strided_cu(A::AbstractArray)
M = parentmodule(typeof(A))
if parent(A) === A # Array, SparseMatrix, StaticArray
false
elseif M === Base || M === Core || M === LinearAlgebra
A isa StridedArray && is_strided_cu(parent(A))
else
is_strided_cu(parent(A)) # PermutedDimsArray, NamedDimsArray
end
end

if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose})
is_strided_cu(A::LinearAlgebra.Transpose) = is_strided_cu(parent(A))
is_strided_cu(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided_cu(parent(A))
else
is_strided_cu(A::LinearAlgebra.Transpose) = false
is_strided_cu(A::LinearAlgebra.Adjoint) = false
end
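To illustrate the unwrapping rule in `is_strided_cu`, here is a sketch using a hypothetical wrapper type (not part of this commit): a wrapper defined outside Base/Core/LinearAlgebra that forwards `parent` is unwrapped recursively, so it counts as strided exactly when its parent does.

using CuArrays
using CuArrays: is_strided_cu

struct Labelled{T, P<:AbstractMatrix{T}} <: AbstractMatrix{T}   # hypothetical wrapper, illustration only
    data::P
end
Base.parent(A::Labelled) = A.data
Base.size(A::Labelled) = size(A.data)
Base.getindex(A::Labelled, i::Int, j::Int) = A.data[i, j]

M = cu(ones(Float32, 4, 4))
is_strided_cu(Labelled(M))           # true: parentmodule isn't Base/Core/LinearAlgebra, so recurse into parent(A)
is_strided_cu(Labelled(ones(4, 4)))  # false: the parent is a plain CPU Array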
22 changes: 22 additions & 0 deletions test/nnlib.jl
@@ -16,4 +16,26 @@
@test cu(Ca) ≈ batched_mul(cu(A), batched_adjoint(cu(B)))
end

using CuArrays: is_strided_cu
using LinearAlgebra
@testset "is_strided_cu" begin

M = cu(ones(10,10))

@test is_strided_cu(M)
@test is_strided_cu(view(M, 1:2:5,:))
@test is_strided_cu(PermutedDimsArray(M, (2,1)))

@test !is_strided_cu(reshape(view(M, 1:2:10,:), 10,:))
@test !is_strided_cu((M.+im)')
@test !is_strided_cu(ones(10,10))
@test !is_strided_cu(Diagonal(ones(3)))

#=
using NamedDims
@test is_strided(NamedDimsArray(M,(:a, :b))) # and 0.029 ns, 0 allocations
=#

end

end
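A usage sketch of the path these tests exercise, assuming Float32 CuArrays and the NNlib version this commit targets (shapes and variable names are illustrative):

using CuArrays, NNlib

A = cu(rand(Float32, 3, 4, 5))      # 5 batches of 3×4
B = cu(rand(Float32, 4, 6, 5))      # 5 batches of 4×6
C = batched_mul(A, B)               # all strided CuArrays: dispatches through batched_try_gemm!
                                    # to CUBLAS.gemm_strided_batched!

# A PermutedDimsArray of a CuArray is still strided (stride 1 in dim 2), so
# batched_try_gemm! re-wraps it with batched_transpose rather than falling back
# to batched_mul_generic!:
P = PermutedDimsArray(cu(rand(Float32, 4, 3, 5)), (2, 1, 3))    # a 3×4×5 view
C2 = similar(B, 3, 6, 5)
NNlib.batched_mul!(C2, P, B)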
