diff --git a/Project.toml b/Project.toml
index af474927..e9c3df61 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -28,6 +29,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 AbstractFFTs = "0.4, 0.5"
 Adapt = "1.0"
 CEnum = "0.2"
+Compat = "3.9"
 CUDAapi = "3.0, 4.0"
 CUDAdrv = "6.0.1"
 CUDAnative = "3.0"
diff --git a/src/blas/wrappers.jl b/src/blas/wrappers.jl
index 5862655b..765e40cf 100644
--- a/src/blas/wrappers.jl
+++ b/src/blas/wrappers.jl
@@ -937,15 +937,16 @@ for (fname, elty) in
     function gemm_strided_batched!(transA::Char,
                                    transB::Char,
                                    alpha::($elty),
-                                   A::CuArray{$elty, 3},
-                                   B::CuArray{$elty, 3},
+                                   A::AbstractArray{$elty, 3},
+                                   B::AbstractArray{$elty, 3},
                                    beta::($elty),
-                                   C::CuArray{$elty, 3})
+                                   C::AbstractArray{$elty, 3})
        m = size(A, transA == 'N' ? 1 : 2)
        k = size(A, transA == 'N' ? 2 : 1)
        n = size(B, transB == 'N' ? 2 : 1)
 
-       @assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch"
+       @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
+       @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"
 
        if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
            throw(DimensionMismatch(""))
@@ -956,10 +957,10 @@ for (fname, elty) in
        ldb = max(1,stride(B,2))
        ldc = max(1,stride(C,2))
 
-       strideA = stride(A, 3)
-       strideB = stride(B, 3)
+       strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
+       strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
        strideC = stride(C, 3)
-       batchCount = size(A, 3)
+       batchCount = size(C, 3)
        $fname(handle(), cutransA,cutransB, m, n, k, [alpha], A, lda, strideA, B,
              ldb, strideB, [beta], C, ldc, strideC, batchCount)
        C
@@ -967,15 +968,15 @@ for (fname, elty) in
    function gemm_strided_batched(transA::Char,
                                  transB::Char,
                                  alpha::($elty),
-                                 A::CuArray{$elty, 3},
-                                 B::CuArray{$elty, 3})
-       C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3)))
+                                 A::AbstractArray{$elty, 3},
+                                 B::AbstractArray{$elty, 3})
+       C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
        gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
    end
    function gemm_strided_batched(transA::Char,
                                  transB::Char,
-                                 A::CuArray{$elty, 3},
-                                 B::CuArray{$elty, 3})
+                                 A::AbstractArray{$elty, 3},
+                                 B::AbstractArray{$elty, 3})
        gemm_strided_batched(transA, transB, one($elty), A, B)
    end
 end
diff --git a/src/nnlib.jl b/src/nnlib.jl
index 56063b58..842ac175 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -32,16 +32,8 @@ end
 
 # Batched matrix multiplication
+# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191
+
+NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
+    CuArrays.CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
 
-const batched_gemm_args = [
-    (:(CuArray{T, 3}), 'N'),
-    (:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'),
-    (:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C')
-]
-
-for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args
-    @eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
-        CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
-        C
-    end
-end
diff --git a/test/nnlib.jl b/test/nnlib.jl
index e8e1462d..12ec19d7 100644
--- a/test/nnlib.jl
+++ b/test/nnlib.jl
@@ -16,4 +16,23 @@
     @test cu(Ca) ≈ batched_mul(cu(A), batched_adjoint(cu(B)))
 end
 
+using NNlib: is_strided, are_strided, storage_type
+using LinearAlgebra
+@testset "NNlib storage_type etc." begin
+
+    M = cu(ones(10,10))
+
+    @test is_strided(M)
+    @test is_strided(view(M, 1:2:5,:))
+    @test is_strided(PermutedDimsArray(M, (2,1)))
+
+    @test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
+    @test !is_strided((M.+im)')
+    @test !is_strided(Diagonal(cu(ones(3))))
+
+    @test storage_type(M) == CuArray{Float32,2,Nothing}
+    @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing}
+
+end
+
 end
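
A minimal usage sketch of the new broadcasting behaviour in `gemm_strided_batched` (assuming this branch is installed; the sizes and variable names below are illustrative only): a batch dimension of 1 on one operand is now expanded against the other operand's batch, rather than raising an error.

```julia
using CuArrays

# A has batch size 1, B has batch size 7; before this change the
# batch-size assert rejected this combination.
A = cu(rand(Float32, 3, 4, 1))
B = cu(rand(Float32, 4, 5, 7))

# Internally strideA == 0, so A's single slice is re-used against every
# slice of B, and batchCount is taken from C (here 7).
C = CuArrays.CUBLAS.gemm_strided_batched('N', 'N', A, B)
@assert size(C) == (3, 5, 7)
```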
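And a sketch of how `batched_mul` now reaches CUBLAS through the `storage_type` dispatch from FluxML/NNlib.jl#191, replacing the old per-wrapper `batched_mul!` methods (both packages at these branches are assumed; the `storage_type` result shown in the comment is an expectation, mirroring the new tests above):

```julia
using CuArrays, NNlib
using NNlib: batched_mul, batched_adjoint, storage_type

M = cu(rand(Float32, 4, 4, 7))

# storage_type looks through wrappers such as batched_adjoint, so this
# call should dispatch to NNlib._batched_gemm!(::Type{<:CuArray}, ...)
# and hence to gemm_strided_batched! — no per-wrapper methods needed.
storage_type(batched_adjoint(M))   # expected: CuArray{Float32,3,Nothing}
batched_mul(M, batched_adjoint(M))
```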