diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index d81faf83f1..0a95050188 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -923,10 +923,20 @@ for (fname, elty) in
         function gemm_strided_batched!(transA::Char,
                                        transB::Char,
                                        alpha::Number,
-                                       A::AbstractArray{$elty, 3},
+                                       A::DenseCuArray{$elty, 3},
+                                       B::DenseCuArray{$elty, 3},
+                                       beta::Number,
+                                       C::DenseCuArray{$elty, 3})
+            # Forward to the internal method; note the `!` — the internal
+            # mutating helper is named `_gemm_strided_batched!`, and calling
+            # `_gemm_strided_batched` (undefined) would throw at runtime.
+            _gemm_strided_batched!(transA, transB, alpha, A, B, beta, C)
+        end
+        function _gemm_strided_batched!(transA::Char,
+                                        transB::Char,
+                                        alpha::Number,
+                                        A::AbstractArray{$elty, 3}, # allows PermutedDimsArray
                                        B::AbstractArray{$elty, 3},
                                        beta::Number,
                                        C::AbstractArray{$elty, 3})
+
             m = size(A, transA == 'N' ? 1 : 2)
             k = size(A, transA == 'N' ? 2 : 1)
             n = size(B, transB == 'N' ? 2 : 1)
@@ -952,15 +962,15 @@ for (fname, elty) in
         function gemm_strided_batched(transA::Char,
                                       transB::Char,
                                       alpha::Number,
-                                      A::AbstractArray{$elty, 3},
-                                      B::AbstractArray{$elty, 3})
+                                      A::DenseCuArray{$elty, 3},
+                                      B::DenseCuArray{$elty, 3})
             C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ?
                 2 : 1), max(size(A, 3), size(B, 3))))
             gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
         end
         function gemm_strided_batched(transA::Char, transB::Char,
-                                      A::AbstractArray{$elty, 3},
-                                      B::AbstractArray{$elty, 3})
+                                      A::DenseCuArray{$elty, 3},
+                                      B::DenseCuArray{$elty, 3})
             gemm_strided_batched(transA, transB, one($elty), A, B)
         end
 end
diff --git a/src/nnlib.jl b/src/nnlib.jl
index 01d0b3a95b..2da7149223 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -23,7 +23,5 @@ end
 
 # Batched matrix multiplication
 
-# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191
-
-  NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
-    CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
+NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
+    CUBLAS._gemm_strided_batched!(transA, transB, α, A, B, β, C)