diff --git a/Manifest.toml b/Manifest.toml
index 84230b42f5..a6aa43b692 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -105,10 +105,10 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
 uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
-deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "a8180fd1445e31c0b1add98dae8da694ac2c23fd"
+deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
+git-tree-sha1 = "1ae42464fea5258fd2ff49f1c4a40fc41cba3860"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.6"
+version = "0.7.7"
 
 [[OrderedCollections]]
 git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
diff --git a/Project.toml b/Project.toml
index 132993f12a..cf7ae821cd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -37,7 +37,7 @@ GPUArrays = "6.1.0"
 GPUCompiler = "0.8.1"
 LLVM = "3"
 MacroTools = "0.5"
-NNlib = "0.6.5, 0.7"
+NNlib = "0.7.7"
 Reexport = "0.2"
 Requires = "0.5, 1.0"
 TimerOutputs = "0.5"
diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index b2682cae8a..68e8a2fd88 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -923,15 +923,16 @@ for (fname, elty) in
         function gemm_strided_batched!(transA::Char,
                                        transB::Char,
                                        alpha::Number,
-                                       A::DenseCuArray{$elty, 3},
-                                       B::DenseCuArray{$elty, 3},
+                                       A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
+                                       B::AbstractArray{$elty, 3},
                                        beta::Number,
-                                       C::DenseCuArray{$elty, 3})
+                                       C::AbstractArray{$elty, 3})
            m = size(A, transA == 'N' ? 1 : 2)
            k = size(A, transA == 'N' ? 2 : 1)
            n = size(B, transB == 'N' ? 2 : 1)
 
-           @assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch"
+           @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
+           @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"
 
            if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
                throw(DimensionMismatch(""))
@@ -940,10 +941,10 @@ for (fname, elty) in
            ldb = max(1,stride(B,2))
            ldc = max(1,stride(C,2))
 
-           strideA = stride(A, 3)
-           strideB = stride(B, 3)
+           strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
+           strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
            strideC = stride(C, 3)
-           batchCount = size(A, 3)
+           batchCount = size(C, 3)
            $fname(handle(), transA, transB, m, n, k, alpha, A, lda, strideA, B,
                   ldb, strideB, beta, C, ldc, strideC, batchCount)
            C
@@ -951,15 +952,15 @@ for (fname, elty) in
         function gemm_strided_batched(transA::Char,
                                       transB::Char,
                                       alpha::Number,
-                                      A::DenseCuArray{$elty, 3},
-                                      B::DenseCuArray{$elty, 3})
-            C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3)))
+                                      A::AbstractArray{$elty, 3},
+                                      B::AbstractArray{$elty, 3})
+            C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
             gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
         end
         function gemm_strided_batched(transA::Char,
                                       transB::Char,
-                                      A::DenseCuArray{$elty, 3},
-                                      B::DenseCuArray{$elty, 3})
+                                      A::AbstractArray{$elty, 3},
+                                      B::AbstractArray{$elty, 3})
             gemm_strided_batched(transA, transB, one($elty), A, B)
         end
     end
diff --git a/src/array.jl b/src/array.jl
index d63daf3c3f..f790545774 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -429,6 +429,12 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
 end
 
 
+## PermutedDimsArray
+
+Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} =
+    Base.unsafe_convert(CuPtr{T}, parent(A))
+
+
 ## reshape
 
 # optimize reshape to return a CuArray
diff --git a/src/nnlib.jl b/src/nnlib.jl
index 9ffdba88ee..1cf65bdd26 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -23,16 +23,9 @@ end
 
 # Batched matrix multiplication
 
+# 1st argument is produced by NNlib.storage_type(A)
+NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
+    CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
-const batched_gemm_args = [
-    (:(CuArray{T, 3}), 'N'),
-    (:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'),
-    (:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C')
-]
-
-for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args
-    @eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
-        CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
-        C
-    end
-end
+Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
+    Base.unsafe_convert(CuPtr{T}, parent(A))
diff --git a/test/nnlib.jl b/test/nnlib.jl
index 930b0a1b7d..8f57e27f90 100644
--- a/test/nnlib.jl
+++ b/test/nnlib.jl
@@ -1,7 +1,7 @@
 using NNlib
 
 @testset "batched_mul" begin
-    using NNlib: batched_mul, batched_adjoint, batched_transpose
+    using NNlib: batched_mul, batched_mul!, batched_vec, batched_adjoint, batched_transpose
 
     A = randn(Float32, 3,3,2);
     B = randn(Float32, 3,3,2);
@@ -14,6 +14,47 @@ using NNlib
 
     Ca = batched_mul(A, batched_adjoint(B))
     @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B)))
+
+    # 5-arg batched_mul!
+    C .= pi
+    batched_mul!(C, A, B, 2f0, 3f0)
+    cuCpi = CuArray(similar(C)) .= pi
+    @test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0)
+
+    # PermutedDimsArray
+    @test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B))
+
+    D = permutedims(B, (1,3,2))
+    Cp = batched_mul(batched_adjoint(A), B)
+    @test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2)))
+
+    # Methods which reshape
+    M = randn(Float32, 3,3)
+
+    Cm = batched_mul(A, M)
+    @test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M))
+
+    Cv = batched_vec(permutedims(A,(3,1,2)), M)
+    @test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M))
+end
+
+@testset "NNlib storage_type etc." begin
+    using LinearAlgebra
+    using NNlib: is_strided, are_strided, storage_type
+
+    M = cu(ones(10,10))
+
+    @test is_strided(M)
+    @test is_strided(view(M, 1:2:5,:))
+    @test is_strided(PermutedDimsArray(M, (2,1)))
+
+    @test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
+    @test !is_strided((M .+ im)')
+    @test !is_strided(Diagonal(cu(ones(3))))
+
+    @test storage_type(M) == CuArray{Float32,2}
+    @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2}
+
 end
 
 @testset "Broadcast Fix" begin
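A minimal sketch of what the relaxed batch-size assertions in lib/cublas/wrappers.jl enable (not part of the patch; assumes a CUDA-capable device). Because strideA becomes 0 when size(A, 3) == 1 and batchCount is now taken from C, a single slice of A is broadcast against every slice of B:

using CUDA

A = CUDA.rand(Float32, 3, 4, 1)   # batch dimension of size 1, reused via stride 0
B = CUDA.rand(Float32, 4, 5, 7)
C = CUDA.zeros(Float32, 3, 5, 7)  # batchCount is read from C

CUDA.CUBLAS.gemm_strided_batched!('N', 'N', 1f0, A, B, 0f0, C)

# every slice of C should equal A[:,:,1] * B[:,:,k]
@assert Array(C)[:,:,3] ≈ Array(A)[:,:,1] * Array(B)[:,:,3]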
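Likewise, a sketch of the wrapper types the new NNlib._batched_gemm! hook accepts without copying (assumes NNlib >= 0.7.7). Both calls below should reach CUBLAS.gemm_strided_batched! through NNlib.storage_type, with the Base.unsafe_convert methods added above letting CUBLAS read the parent array's memory directly:

using CUDA, NNlib

A = CUDA.rand(Float32, 3, 4, 7)
B = CUDA.rand(Float32, 5, 4, 7)

C1 = batched_mul(A, batched_adjoint(B))               # BatchedAdjoint wrapper
C2 = batched_mul(A, PermutedDimsArray(B, (2, 1, 3)))  # strided view, no copy

@assert C1 ≈ C2   # real eltype, so adjoint == transpose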