Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow PermutedDimsArray in gemm_strided_batched #539

Merged
merged 17 commits into the base branch from the feature branch
Nov 17, 2020
6 changes: 3 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "a8180fd1445e31c0b1add98dae8da694ac2c23fd"
deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "1ae42464fea5258fd2ff49f1c4a40fc41cba3860"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.6"
version = "0.7.7"

[[OrderedCollections]]
git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ GPUArrays = "6.1.0"
GPUCompiler = "0.8.1"
LLVM = "3"
MacroTools = "0.5"
NNlib = "0.6.5, 0.7"
NNlib = "0.7.7"
Reexport = "0.2"
Requires = "0.5, 1.0"
TimerOutputs = "0.5"
Expand Down
25 changes: 13 additions & 12 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,16 @@ for (fname, elty) in
function gemm_strided_batched!(transA::Char,
transB::Char,
alpha::Number,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3},
A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
B::AbstractArray{$elty, 3},
beta::Number,
C::DenseCuArray{$elty, 3})
C::AbstractArray{$elty, 3})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)

@assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch"
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"

if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
throw(DimensionMismatch(""))
Expand All @@ -940,26 +941,26 @@ for (fname, elty) in
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))

strideA = stride(A, 3)
strideB = stride(B, 3)
strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
strideC = stride(C, 3)
batchCount = size(A, 3)
batchCount = size(C, 3)
$fname(handle(), transA, transB, m, n, k, alpha, A, lda, strideA, B,
ldb, strideB, beta, C, ldc, strideC, batchCount)
C
end
function gemm_strided_batched(transA::Char,
transB::Char,
alpha::Number,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3)))
A::AbstractArray{$elty, 3},
B::AbstractArray{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
end
function gemm_strided_batched(transA::Char,
transB::Char,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3})
A::AbstractArray{$elty, 3},
B::AbstractArray{$elty, 3})
gemm_strided_batched(transA, transB, one($elty), A, B)
end
end
Expand Down
6 changes: 6 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
end


## PermutedDimsArray

# Let a PermutedDimsArray wrapper be passed straight to CUBLAS routines:
# the wrapper shares storage with its parent, so the device pointer is simply
# the parent's pointer. Stride/layout bookkeeping is the caller's job
# (e.g. gemm_strided_batched! reads strides from the wrapper itself).
Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} =
Base.unsafe_convert(CuPtr{T}, parent(A))


## reshape

# optimize reshape to return a CuArray
Expand Down
17 changes: 5 additions & 12 deletions src/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@ end


# Batched matrix multiplication
# 1st argument is produced by NNlib.storage_type(A)
# Hook NNlib's generic batched_mul machinery into CUBLAS: when the storage
# type is a CuArray, forward directly to gemm_strided_batched!.
# transA/transB are 'N'/'T'/'C' flags; α, β are the usual GEMM scalars.
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)

const batched_gemm_args = [
(:(CuArray{T, 3}), 'N'),
(:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'),
(:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C')
]

for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args
@eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
C
end
end
# NNlib's lazy batched adjoint/transpose wrappers also share storage with
# their parent, so the raw device pointer is the parent's; the 'T'/'C'
# transpose flag (not the memory layout) encodes the transposition for CUBLAS.
Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
Base.unsafe_convert(CuPtr{T}, parent(A))
43 changes: 42 additions & 1 deletion test/nnlib.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using NNlib

@testset "batched_mul" begin
using NNlib: batched_mul, batched_adjoint, batched_transpose
using NNlib: batched_mul, batched_mul!, batched_vec, batched_adjoint, batched_transpose

A = randn(Float32, 3,3,2);
B = randn(Float32, 3,3,2);
Expand All @@ -14,6 +14,47 @@ using NNlib

Ca = batched_mul(A, batched_adjoint(B))
@test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B)))

# 5-arg batched_mul!
C .= pi
batched_mul!(C, A, B, 2f0, 3f0)
cuCpi = CuArray(similar(C)) .= pi
@test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0)

# PermutedDimsArray
@test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B))

D = permutedims(B, (1,3,2))
Cp = batched_mul(batched_adjoint(A), B)
@test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2)))

# Methods which reshape
M = randn(Float32, 3,3)

Cm = batched_mul(A, M)
@test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M))

Cv = batched_vec(permutedims(A,(3,1,2)), M)
@test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M))
end

@testset "NNlib storage_type etc." begin
using LinearAlgebra
using NNlib: is_strided, are_strided, storage_type

M = cu(ones(10,10))

@test is_strided(M)
@test is_strided(view(M, 1:2:5,:))
@test is_strided(PermutedDimsArray(M, (2,1)))

@test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
@test !is_strided((M .+ im)')
@test !is_strided(Diagonal(cu(ones(3))))

@test storage_type(M) == CuArray{Float32,2}
@test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2}

end

@testset "Broadcast Fix" begin
Expand Down