Skip to content

Commit

Permalink
Add unit_diag option for sv2! functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 12, 2020
1 parent e09bf71 commit 2d7180f
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 149 deletions.
8 changes: 7 additions & 1 deletion lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}}, B::DenseCuMatrix{T}) wher
# Left-division of a dense matrix by the adjoint of a triangular sparse matrix.
# Both the upper- and lower-triangular cases unwrap the adjoint with `parent` and forward
# to `sm` with the operation character 'C' and the index character 'O'.
# NOTE(review): 'O' presumably selects one-based indexing — confirm against `sm`.
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')

# Left-division of a dense vector by a (non-unit) triangular sparse matrix.
# Each method forwards to `sv2`, passing 'N' for the plain triangular wrapper, 'T' for a
# `Transpose` wrapper and 'C' for an `Adjoint` wrapper; wrapped operands are unwrapped
# with `parent` first. The index character passed is always 'O'.
# NOTE(review): the first two definitions below are textually identical — this looks like
# the removed/added pair of a whitespace-only change in the diff; only one is needed.
Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O')
Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O')

# Left-division of a dense vector by a unit-diagonal triangular sparse matrix.
# These mirror the non-unit methods but pass `unit_diag=true` to `sv2`, so the solver is
# told that the referenced triangle has a unit diagonal. As before, 'N', 'T' and 'C' are
# passed for the plain, `Transpose` and `Adjoint` wrappers respectively, and wrapped
# operands are unwrapped with `parent`.
Base.:(\)(A::Union{UnitUpperTriangular{T, S},UnitLowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O',unit_diag=true)
Base.:(\)(transA::Transpose{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true)
Base.:(\)(transA::Transpose{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true)
Base.:(\)(adjA::Adjoint{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true)
Base.:(\)(adjA::Adjoint{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true)

# Addition and subtraction of CSR/CSC sparse matrices are forwarded to `geam`.
# Subtraction passes `-one(eltype(A))` as an extra scalar argument.
# NOTE(review): presumably that scalar is the coefficient applied to `B` — confirm
# against the `geam` method being called.
Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,B,'O','O','O')
Base.:(-)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,-one(eltype(A)),B,'O','O','O')
51 changes: 31 additions & 20 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
end

"""
sv2!(transa::SparseChar, uplo::SparseChar, alpha::BlasFloat, A::CuSparseMatrixBSR, X::CuVector, index::SparseChar)
sv2!(transa::SparseChar, uplo::SparseChar, alpha::BlasFloat, A::CuSparseMatrixBSR, X::CuVector, index::SparseChar; unit_diag::Bool=false)
Performs `X = alpha * op(A) \\ X`, where `op` can be nothing (`transa = N`), transpose
(`transa = T`) or conjugate transpose (`transa = C`). `X` is a dense vector, and `uplo`
tells `sv2!` which triangle of the block sparse matrix `A` to reference.
If the triangle has a unit diagonal, set `unit_diag` to `true`.
"""
sv2!(transa::SparseChar, uplo::SparseChar, alpha::BlasFloat, A::CuSparseMatrixBSR, X::CuVector, index::SparseChar)
sv2!(transa::SparseChar, uplo::SparseChar, alpha::BlasFloat, A::CuSparseMatrixBSR, X::CuVector, index::SparseChar; unit_diag::Bool=false)
# bsrsv2
for (bname,aname,sname,elty) in ((:cusparseSbsrsv2_bufferSize, :cusparseSbsrsv2_analysis, :cusparseSbsrsv2_solve, :Float32),
(:cusparseDbsrsv2_bufferSize, :cusparseDbsrsv2_analysis, :cusparseDbsrsv2_solve, :Float64),
Expand All @@ -62,9 +63,11 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsv2_bufferSize, :cusparseSbsrsv2_
alpha::Number,
A::CuSparseMatrixBSR{$elty},
X::CuVector{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, uplo, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
m,n = A.dims
index::SparseChar;
unit_diag::Bool=false)
DIAG_TYPE = (unit_diag ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, uplo, DIAG_TYPE, index)
m,n = A.dims
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
Expand Down Expand Up @@ -106,36 +109,40 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
alpha::Number,
A::CuSparseMatrix{$elty},
X::CuVector{$elty},
index::SparseChar)
sv2!(transa,uplo,alpha,A,copy(X),index)
index::SparseChar;
unit_diag::Bool=false)
sv2!(transa,uplo,alpha,A,copy(X),index,unit_diag=unit_diag)
end
function sv2(transa::SparseChar,
uplo::SparseChar,
A::CuSparseMatrix{$elty},
X::CuVector{$elty},
index::SparseChar)
sv2!(transa,uplo,one($elty),A,copy(X),index)
index::SparseChar;
unit_diag::Bool=false)
sv2!(transa,uplo,one($elty),A,copy(X),index,unit_diag=unit_diag)
end
function sv2(transa::SparseChar,
alpha::Number,
A::AbstractTriangular,
X::CuVector{$elty},
index::SparseChar)
index::SparseChar;
unit_diag::Bool=false)
uplo = 'U'
if istril(A)
if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular}
uplo = 'L'
end
sv2!(transa,uplo,alpha,A.data,copy(X),index)
sv2!(transa,uplo,alpha,A.data,copy(X),index,unit_diag=unit_diag)
end
function sv2(transa::SparseChar,
A::AbstractTriangular,
X::CuVector{$elty},
index::SparseChar)
index::SparseChar;
unit_diag::Bool=false)
uplo = 'U'
if istril(A)
if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular}
uplo = 'L'
end
sv2!(transa,uplo,one($elty),A.data,copy(X),index)
sv2!(transa,uplo,one($elty),A.data,copy(X),index,unit_diag=unit_diag)
end
end
end
Expand All @@ -151,8 +158,10 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
alpha::Number,
A::CuSparseMatrixCSR{$elty},
X::CuVector{$elty},
index::SparseChar)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, uplo, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
index::SparseChar;
unit_diag::Bool=false)
DIAG_TYPE = (unit_diag ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, uplo, DIAG_TYPE, index)
m,n = A.dims
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
Expand Down Expand Up @@ -198,7 +207,8 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
alpha::Number,
A::CuSparseMatrixCSC{$elty},
X::CuVector{$elty},
index::SparseChar)
index::SparseChar;
unit_diag::Bool=false)
ctransa = 'N'
cuplo = 'U'
if transa == 'N'
Expand All @@ -207,8 +217,9 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
if uplo == 'U'
cuplo = 'L'
end
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, cuplo, CUSPARSE_DIAG_TYPE_NON_UNIT, index)
n,m = A.dims
DIAG_TYPE = (unit_diag ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT)
desc = CuMatrixDescriptor(CUSPARSE_MATRIX_TYPE_GENERAL, cuplo, DIAG_TYPE, index)
n,m = A.dims
if m != n
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
end
Expand Down
Loading

0 comments on commit 2d7180f

Please sign in to comment.