Skip to content

Commit

Permalink
Merge pull request #9701 from spencerlyon2/ordschur_gen
Browse files Browse the repository at this point in the history
RFC: Ordering by Generalized Eigenvalues for Generalized Schur methods
  • Loading branch information
andreasnoack committed Jan 15, 2015
2 parents 64ad746 + 0223400 commit 2c63b5d
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 76 deletions.
5 changes: 5 additions & 0 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,11 @@ schurfact!{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = Generalized
schurfact{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T}) = schurfact!(copy(A),copy(B))
schurfact{TA,TB}(A::StridedMatrix{TA}, B::StridedMatrix{TB}) = (S = promote_type(Float32,typeof(one(TA)/norm(one(TA))),TB); schurfact!(S != TA ? convert(AbstractMatrix{S},A) : copy(A), S != TB ? convert(AbstractMatrix{S},B) : copy(B)))

ordschur!{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = GeneralizedSchur(LinAlg.LAPACK.tgsen!(select, S, T, Q, Z)...)
ordschur{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = ordschur!(copy(S), copy(T), copy(Q), copy(Z), select)
ordschur!{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = (res=ordschur!(gschur.S, gschur.T, gschur.Q, gschur.Z, select); gschur[:alpha][:]=res[:alpha]; gschur[:beta][:]=res[:beta]; res)
ordschur{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = ordschur(gschur.S, gschur.T, gschur.Q, gschur.Z, select)

function getindex(F::GeneralizedSchur, d::Symbol)
d == :S && return F.S
d == :T && return F.T
Expand Down
170 changes: 146 additions & 24 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,8 @@ for (geevx, ggev, elty) in
chkstride1(A,B)
n, m = chksquare(A,B)
n==m || throw(DimensionMismatch("matrices must have same size"))
lda = max(1, n)
ldb = max(1, n)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
alphar = similar(A, $elty, n)
alphai = similar(A, $elty, n)
beta = similar(A, $elty, n)
Expand Down Expand Up @@ -1351,7 +1351,8 @@ for (geevx, ggev, elty, relty) in
chkstride1(A, B)
n, m = chksquare(A, B)
n==m || throw(DimensionMismatch("matrices must have same size"))
lda = ldb = max(1, n)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
alpha = similar(A, $elty, n)
beta = similar(A, $elty, n)
ldvl = jobvl == 'V' ? n : 1
Expand Down Expand Up @@ -2920,7 +2921,8 @@ for (syev, syevr, sygvd, elty) in
chkstride1(A, B)
n, m = chksquare(A, B)
n==m || throw(DimensionMismatch("Matrices must have same size"))
lda = ldb = max(1, n)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
w = similar(A, $elty, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
Expand Down Expand Up @@ -3071,7 +3073,8 @@ for (syev, syevr, sygvd, elty, relty) in
chkstride1(A, B)
n, m = chksquare(A, B)
n==m || throw(DimensionMismatch("Matrices must have same size"))
lda = ldb = max(1, n)
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
w = similar(A, $relty, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
Expand Down Expand Up @@ -3307,7 +3310,7 @@ for (gehrd, elty) in
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}),
&n, &ilo, &ihi, A,
&max(1,n), tau, work, &lwork,
&max(1, stride(A, 2)), tau, work, &lwork,
info)
@lapackerror
if lwork < 0
Expand Down Expand Up @@ -3346,7 +3349,7 @@ for (orghr, elty) in
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}),
&n, &ilo, &ihi, A,
&max(1,n), tau, work, &lwork,
&max(1, stride(A, 2)), tau, work, &lwork,
info)
@lapackerror
if lwork < 0
Expand Down Expand Up @@ -3389,7 +3392,7 @@ for (gees, gges, elty) in
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{Void}, Ptr{BlasInt}),
&jobvs, &'N', C_NULL, &n,
A, &max(1, n), sdim, wr,
A, &max(1, stride(A, 2)), sdim, wr,
wi, vs, &ldvs, work,
&lwork, C_NULL, info)
@lapackerror
Expand Down Expand Up @@ -3433,8 +3436,8 @@ for (gees, gges, elty) in
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void},
Ptr{BlasInt}),
&jobvsl, &jobvsr, &'N', C_NULL,
&n, A, &max(1,n), B,
&max(1,n), &sdim, alphar, alphai,
&n, A, &max(1,stride(A, 2)), B,
&max(1,stride(B, 2)), &sdim, alphar, alphai,
beta, vsl, &ldvsl, vsr,
&ldvsr, work, &lwork, C_NULL,
info)
Expand Down Expand Up @@ -3479,7 +3482,7 @@ for (gees, gges, elty, relty) in
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$relty}, Ptr{Void}, Ptr{BlasInt}),
&jobvs, &sort, C_NULL, &n,
A, &max(1, n), &sdim, w,
A, &max(1, stride(A, 2)), &sdim, w,
vs, &ldvs, work, &lwork,
rwork, C_NULL, info)
@lapackerror
Expand Down Expand Up @@ -3524,8 +3527,8 @@ for (gees, gges, elty, relty) in
Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty}, Ptr{Void},
Ptr{BlasInt}),
&jobvsl, &jobvsr, &'N', C_NULL,
&n, A, &max(1,n), B,
&max(1,n), &sdim, alpha, beta,
&n, A, &max(1, stride(A, 2)), B,
&max(1, stride(B, 2)), &sdim, alpha, beta,
vsl, &ldvsl, vsr, &ldvsr,
work, &lwork, rwork, C_NULL,
info)
Expand All @@ -3540,9 +3543,9 @@ for (gees, gges, elty, relty) in
end
end
# Reorder Schur forms
for (trsen, elty) in
((:dtrsen_,:Float64),
(:strsen_,:Float32))
for (trsen, tgsen, elty) in
((:dtrsen_, :dtgsen_, :Float64),
(:strsen_, :stgsen_, :Float32))
@eval begin
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
# * .. Scalar Arguments ..
Expand All @@ -3556,7 +3559,8 @@ for (trsen, elty) in
# DOUBLE PRECISION Q( LDQ, * ), T( LDT, * ), WI( * ), WORK( * ), WR( * )
chkstride1(T, Q)
n = chksquare(T)
ld = max(1, n)
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
wr = similar(T, $elty, n)
wi = similar(T, $elty, n)
m = sum(select)
Expand All @@ -3572,10 +3576,10 @@ for (trsen, elty) in
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr {BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&'N', &'V', select, &n,
T, &ld, Q, &ld,
T, &ldt, Q, &ldq,
wr, wi, &m, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
Expand All @@ -3589,12 +3593,71 @@ for (trsen, elty) in
end
T, Q, all(wi .== 0) ? wr : complex(wr, wi)
end
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
# * .. Scalar Arguments ..
# * LOGICAL WANTQ, WANTZ
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
# * $ M, N
# * DOUBLE PRECISION PL, PR
# * ..
# * .. Array Arguments ..
# * LOGICAL SELECT( * )
# * INTEGER IWORK( * )
# * DOUBLE PRECISION A( LDA, * ), ALPHAI( * ), ALPHAR( * ),
# * $ B( LDB, * ), BETA( * ), DIF( * ), Q( LDQ, * ),
# * $ WORK( * ), Z( LDZ, * )
# * ..
chkstride1(S, T, Q, Z)
n, nt, nq, nz = chksquare(S, T, Q, Z)
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
lds = max(1, stride(S, 2))
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
ldz = max(1, stride(Z, 2))
m = sum(select)
alphai = similar(T, $elty, n)
alphar = similar(T, $elty, n)
beta = similar(T, $elty, n)
lwork = blas_int(-1)
work = Array($elty, 1)
liwork = blas_int(-1)
iwork = Array(BlasInt, 1)
info = Array(BlasInt, 1)
select = convert(Array{BlasInt}, select)

for i = 1:2
ccall(($(blasfunc(tgsen)), liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&0, &1, &1, select,
&n, S, &lds, T,
&ldt, alphar, alphai, beta,
Q, &ldq, Z, &ldz,
&m, C_NULL, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
@lapackerror
if i == 1 # only estimated optimal lwork, liwork
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
liwork = blas_int(real(iwork[1]))
iwork = Array(BlasInt, liwork)
end
end
S, T, complex(alphar, alphai), beta, Q, Z
end
end
end

for (trsen, elty) in
((:ztrsen_,:Complex128),
(:ctrsen_,:Complex64))
for (trsen, tgsen, elty) in
((:ztrsen_, :ztgsen_, :Complex128),
(:ctrsen_, :ctgsen_, :Complex64))
@eval begin
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
# * .. Scalar Arguments ..
Expand All @@ -3607,7 +3670,8 @@ for (trsen, elty) in
# COMPLEX Q( LDQ, * ), T( LDT, * ), W( * ), WORK( * )
chkstride1(T, Q)
n = chksquare(T)
ld = max(1, n)
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
w = similar(T, $elty, n)
m = sum(select)
work = Array($elty, 1)
Expand All @@ -3623,7 +3687,7 @@ for (trsen, elty) in
Ptr{$elty}, Ptr {BlasInt},
Ptr{BlasInt}),
&'N', &'V', select, &n,
T, &ld, Q, &ld,
T, &ldt, Q, &ldq,
w, &m, C_NULL, C_NULL,
work, &lwork,
info)
Expand All @@ -3635,6 +3699,64 @@ for (trsen, elty) in
end
T, Q, w
end
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
# * .. Scalar Arguments ..
# * LOGICAL WANTQ, WANTZ
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
# * $ M, N
# * DOUBLE PRECISION PL, PR
# * ..
# * .. Array Arguments ..
# * LOGICAL SELECT( * )
# * INTEGER IWORK( * )
# * DOUBLE PRECISION DIF( * )
# * COMPLEX*16 A( LDA, * ), ALPHA( * ), B( LDB, * ),
# * $ BETA( * ), Q( LDQ, * ), WORK( * ), Z( LDZ, * )
# * ..
chkstride1(S, T, Q, Z)
n, nt, nq, nz = chksquare(S, T, Q, Z)
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
lds = max(1, stride(S, 2))
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
ldz = max(1, stride(Z, 2))
m = sum(select)
alpha = similar(T, $elty, n)
beta = similar(T, $elty, n)
lwork = blas_int(-1)
work = Array($elty, 1)
liwork = blas_int(-1)
iwork = Array(BlasInt, 1)
info = Array(BlasInt, 1)
select = convert(Array{BlasInt}, select)

for i = 1:2
ccall(($(blasfunc(tgsen)), liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&0, &1, &1, select,
&n, S, &lds, T,
&ldt, alpha, beta,
Q, &ldq, Z, &ldz,
&m, C_NULL, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
@lapackerror
if i == 1 # only estimated optimal lwork, liwork
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
liwork = blas_int(real(iwork[1]))
iwork = Array(BlasInt, liwork)
end
end
S, T, alpha, beta, Q, Z
end
end
end

Expand Down
Loading

0 comments on commit 2c63b5d

Please sign in to comment.