Skip to content

Commit

Permalink
Merge pull request #2574 from amontoison/test_gesvd
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Dec 11, 2024
2 parents 2dae25b + 48ca037 commit 478a952
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
21 changes: 10 additions & 11 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,31 +389,30 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
A::StridedCuMatrix{$elty})
m, n = size(A)
(m < n) && throw(ArgumentError("CUSOLVER's gesvd requires m ≥ n"))
k = min(m, n)
lda = max(1, stride(A, 2))

U = if jobu === 'A'
similar(A, $elty, (m, m))
elseif jobu == 'S' || jobu === 'O'
similar(A, $elty, (m, min(m, n)))
elseif jobu === 'N'
elseif jobu === 'S'
similar(A, $elty, (m, k))
elseif jobu === 'N' || jobu === 'O'
CU_NULL
else
error("jobu must be one of 'A', 'S', 'O', or 'N'")
end
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))

S = similar(A, $relty, min(m, n))

ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
S = similar(A, $relty, k)
Vt = if jobvt === 'A'
similar(A, $elty, (n, n))
elseif jobvt === 'S' || jobvt === 'O'
similar(A, $elty, (min(m, n), n))
elseif jobvt === 'N'
elseif jobvt === 'S'
similar(A, $elty, (k, n))
elseif jobvt === 'N' || jobvt === 'O'
CU_NULL
else
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
end
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
dh = dense_handle()

function bufferSize()
Expand Down
15 changes: 8 additions & 7 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,22 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
m, n = size(A)
R = real(T)
(m < n) && throw(ArgumentError("The number of rows of A ($m) must be greater or equal to the number of columns of A ($n)"))
k = min(m, n)
U = if jobu == 'A'
CuMatrix{T}(undef, m, m)
elseif jobu == 'S' || jobu == 'O'
CuMatrix{T}(undef, m, min(m, n))
elseif jobu == 'N'
elseif jobu == 'S'
CuMatrix{T}(undef, m, k)
elseif jobu == 'N' || jobu == 'O'
CU_NULL
else
throw(ArgumentError("jobu is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
end
Σ = CuVector{R}(undef, min(m, n))
Σ = CuVector{R}(undef, k)
Vt = if jobvt == 'A'
CuMatrix{T}(undef, n, n)
elseif jobvt == 'S' || jobvt == 'O'
CuMatrix{T}(undef, min(m, n), n)
elseif jobvt == 'N'
elseif jobvt == 'S'
CuMatrix{T}(undef, k, n)
elseif jobvt == 'N' || jobvt == 'O'
CU_NULL
else
throw(ArgumentError("jobvt is incorrect. The values accepted are 'A', 'S', 'O' and 'N'."))
Expand Down
17 changes: 17 additions & 0 deletions test/libraries/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ k = 1
@test p h_TAUP
end

@testset "gesvd!" begin
A = rand(elty,m,n)
d_A = CuMatrix(A)
U, Σ, Vt = CUSOLVER.gesvd!('A', 'A', d_A)
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)

for jobu in ('A', 'S', 'N', 'O')
for jobvt in ('A', 'S', 'N', 'O')
(jobu == 'A') && (jobvt == 'A') && continue
(jobu == 'O') && (jobvt == 'O') && continue
d_A = CuMatrix(A)
U2, Σ2, Vt2 = CUSOLVER.gesvd!(jobu, jobvt, d_A)
@test Σ Σ2
end
end
end

@testset "syevd!" begin
A = rand(elty,m,m)
A += A'
Expand Down
10 changes: 10 additions & 0 deletions test/libraries/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ p = 5
d_A = CuMatrix(A)
U, Σ, Vt = CUSOLVER.Xgesvd!('A', 'A', d_A)
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)

for jobu in ('A', 'S', 'N', 'O')
for jobvt in ('A', 'S', 'N', 'O')
(jobu == 'A') && (jobvt == 'A') && continue
(jobu == 'O') && (jobvt == 'O') && continue
d_A = CuMatrix(A)
U2, Σ2, Vt2 = CUSOLVER.Xgesvd!(jobu, jobvt, d_A)
@test Σ Σ2
end
end
end

@testset "gesvdp!" begin
Expand Down

0 comments on commit 478a952

Please sign in to comment.