diff --git a/Project.toml b/Project.toml index 132993f12a..66d2c5beb0 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -31,6 +32,7 @@ AbstractFFTs = "0.4, 0.5" Adapt = "2.2" BFloat16s = "0.1" CEnum = "0.2, 0.3, 0.4" +Compat = "3.9" DataStructures = "0.17, 0.18" ExprTools = "0.1" GPUArrays = "6.1.0" diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index b2682cae8a..d81faf83f1 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -923,15 +923,16 @@ for (fname, elty) in function gemm_strided_batched!(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}, + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}, beta::Number, - C::DenseCuArray{$elty, 3}) + C::AbstractArray{$elty, 3}) m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) - @assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch" + @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C" + @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C" if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2) throw(DimensionMismatch("")) @@ -940,10 +941,10 @@ for (fname, elty) in ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - strideA = stride(A, 3) - strideB = stride(B, 3) + strideA = size(A, 3) == 1 ? 0 : stride(A, 3) + strideB = size(B, 3) == 1 ? 0 : stride(B, 3) strideC = stride(C, 3) - batchCount = size(A, 3) + batchCount = size(C, 3) $fname(handle(), transA, transB, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount) C @@ -951,15 +952,15 @@ for (fname, elty) in function gemm_strided_batched(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) - C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3))) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) + C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C ) end function gemm_strided_batched(transA::Char, transB::Char, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) gemm_strided_batched(transA, transB, one($elty), A, B) end end diff --git a/src/nnlib.jl b/src/nnlib.jl index 9ffdba88ee..01d0b3a95b 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -23,16 +23,7 @@ end # Batched matrix multiplication +# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 -const batched_gemm_args = [ - (:(CuArray{T, 3}), 'N'), - (:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'), - (:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C') -] - -for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args - @eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat} - CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C) - C - end -end + NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = + CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) diff --git a/test/nnlib.jl b/test/nnlib.jl index 930b0a1b7d..df810b4361 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -16,6 +16,25 @@ using NNlib @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B))) end +@testset "NNlib storage_type etc." begin + using LinearAlgebra + using NNlib: is_strided, are_strided, storage_type + + M = cu(ones(10,10)) + + @test is_strided(M) + @test is_strided(view(M, 1:2:5,:)) + @test is_strided(PermutedDimsArray(M, (2,1))) + + @test !is_strided(reshape(view(M, 1:2:10,:), 10,:)) + @test !is_strided((M .+ im)') + @test !is_strided(Diagonal(cu(ones(3)))) + + @test storage_type(M) == CuArray{Float32,2,Nothing} + @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing} + +end + @testset "Broadcast Fix" begin if CUDA.has_cudnn() @test testf(x -> logσ.(x), rand(5))