diff --git a/src/PermMatrix.jl b/src/PermMatrix.jl index c6becbd..d6d0d70 100644 --- a/src/PermMatrix.jl +++ b/src/PermMatrix.jl @@ -1,4 +1,4 @@ -abstract type AbstractPermMatrix{Tv} <: AbstractMatrix{Tv} end +abstract type AbstractPermMatrix{Tv, Ti} <: AbstractMatrix{Tv} end """ PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer} PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti} @@ -25,7 +25,7 @@ julia> PermMatrix([2,1,4,3], rand(4)) ``` """ struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <: - AbstractPermMatrix{Tv} + AbstractPermMatrix{Tv,Ti} perm::Vi # new orders vals::Vv # multiplied values. @@ -56,7 +56,7 @@ end # the column major version of `PermMatrix` struct PermMatrixCSC{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <: - AbstractPermMatrix{Tv} + AbstractPermMatrix{Tv,Ti} perm::Vi # new orders vals::Vv # multiplied values. diff --git a/src/arraymath.jl b/src/arraymath.jl index d26045d..471eb2f 100644 --- a/src/arraymath.jl +++ b/src/arraymath.jl @@ -9,17 +9,18 @@ Base.imag(M::IMatrix{T}) where {T} = Diagonal(zeros(T, M.n)) # PermMatrix for func in (:conj, :real, :imag) - @eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals)) + @eval (Base.$func)(M::AbstractPermMatrix) = basetype(M)(M.perm, ($func)(M.vals)) end -Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals)) +Base.copy(M::AbstractPermMatrix) = basetype(M)(copy(M.perm), copy(M.vals)) +Base.conj!(M::AbstractPermMatrix) = (conj!(M.vals); M) -function Base.transpose(M::PermMatrix) +function Base.transpose(M::AbstractPermMatrix) new_perm = fast_invperm(M.perm) - return PermMatrix(new_perm, M.vals[new_perm]) + return basetype(M)(new_perm, M.vals[new_perm]) end -Base.adjoint(S::PermMatrix{<:Real}) = transpose(S) -Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S)) +Base.adjoint(S::AbstractPermMatrix{<:Real}) = transpose(S) +Base.adjoint(S::AbstractPermMatrix{<:Complex}) = conj!(transpose(S)) # scalar Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n)) @@ -27,14 +28,14 @@ Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, elty Base.:/(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n)) -Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B) -Base.:*(B::Number, A::PermMatrix) = A * B -Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B) +Base.:*(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals * B) +Base.:*(B::Number, A::AbstractPermMatrix) = A * B +Base.:/(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals / B) #+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev) #-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev) for op in [:+, :-] - for MT in [:IMatrix, :PermMatrix] + for MT in [:IMatrix, :AbstractPermMatrix] @eval begin # IMatrix, PermMatrix - SparseMatrixCSC Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B) @@ -45,12 +46,12 @@ for op in [:+, :-] # IMatrix, PermMatrix - Diagonal Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag)) Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2))) - Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2) - Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2)) + Base.$op(d1::AbstractPermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2) + Base.$op(d1::Diagonal, d2::AbstractPermMatrix) = $op(d1, SparseMatrixCSC(d2)) # PermMatrix - IMatrix - Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) - Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) - Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::AbstractPermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::IMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) + Base.$op(A::AbstractPermMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B)) end end # NOTE: promote to integer @@ -59,22 +60,22 @@ Base.:+(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} = Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} = d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch()) -for MT in [:IMatrix, :PermMatrix] +for MT in [:IMatrix, :AbstractPermMatrix] @eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B @eval Base.:(==)(A::SparseMatrixCSC, B::$MT) = A == SparseMatrixCSC(B) end Base.:(==)(d1::IMatrix, d2::Diagonal) = all(isone, d2.diag) Base.:(==)(d1::Diagonal, d2::IMatrix) = all(isone, d1.diag) -Base.:(==)(d1::PermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) -Base.:(==)(d1::Diagonal, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) -Base.:(==)(A::IMatrix, B::PermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) -Base.:(==)(A::PermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) +Base.:(==)(d1::AbstractPermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) +Base.:(==)(d1::Diagonal, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2) +Base.:(==)(A::IMatrix, B::AbstractPermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) +Base.:(==)(A::AbstractPermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B) -for MT in [:IMatrix, :PermMatrix] +for MT in [:IMatrix, :AbstractPermMatrix] @eval Base.isapprox(A::$MT, B::SparseMatrixCSC; kwargs...) = isapprox(SparseMatrixCSC(A), B) @eval Base.isapprox(A::SparseMatrixCSC, B::$MT; kwargs...) = isapprox(A, SparseMatrixCSC(B)) @eval Base.isapprox(d1::$MT, d2::Diagonal; kwargs...) = isapprox(diag(d1), d2.diag) @eval Base.isapprox(d1::Diagonal, d2::$MT; kwargs...) = isapprox(d1.diag, diag(d2)) end -Base.isapprox(A::IMatrix, B::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) -Base.isapprox(A::PermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) +Base.isapprox(A::IMatrix, B::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) +Base.isapprox(A::AbstractPermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...) diff --git a/src/broadcast.jl b/src/broadcast.jl index 77c5c0a..ba476e8 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -42,12 +42,10 @@ Broadcast.broadcasted( ) = Diagonal(fill(a, b.n)) # specialize perm matrix -function _broadcast_perm_prod(A::PermMatrix, B::AbstractMatrix) +function _broadcast_perm_prod(A::AbstractPermMatrix, B::AbstractMatrix) dest = similar(A, Base.promote_op(*, eltype(A), eltype(B))) - i = 1 - @inbounds for j in dest.perm - dest[i, j] = A[i, j] * B[i, j] - i += 1 + @inbounds for ((i, j), a) in IterNz(A) + dest[i, j] = a * B[i, j] end return dest end @@ -55,40 +53,30 @@ end Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), - A::PermMatrix, + A::AbstractPermMatrix, B::AbstractMatrix, ) = _broadcast_perm_prod(A, B) Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), A::AbstractMatrix, - B::PermMatrix, + B::AbstractPermMatrix, ) = _broadcast_perm_prod(B, A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::AbstractPermMatrix) = _broadcast_perm_prod(A, B) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::IMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::IMatrix) = Diagonal(A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::AbstractPermMatrix) = Diagonal(B) -function _broadcast_diag_perm_prod(A::Diagonal, B::PermMatrix) - dest = similar(A) - i = 1 - @inbounds for j in B.perm - if i == j - dest[i, i] = A[i, i] * B[i, i] - else - dest[i, i] = 0 - end - i += 1 - end - return dest +function _broadcast_diag_perm_prod(A::Diagonal, B::AbstractPermMatrix) + Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1))) end -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::Diagonal) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::Diagonal) = _broadcast_diag_perm_prod(B, A) -Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::PermMatrix) = +Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::AbstractPermMatrix) = _broadcast_diag_perm_prod(A, B) # TODO: commit this upstream @@ -110,13 +98,13 @@ Broadcast.broadcasted( Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), - a::PermMatrix, + a::AbstractPermMatrix, b::Number, -) = PermMatrix(a.perm, a.vals .* b) +) = basetype(a)(a.perm, a.vals .* b) Broadcast.broadcasted( ::AbstractArrayStyle{2}, ::typeof(*), a::Number, - b::PermMatrix, -) = PermMatrix(b.perm, a .* b.vals) + b::AbstractPermMatrix, +) = basetype(b)(b.perm, a .* b.vals) diff --git a/src/conversions.jl b/src/conversions.jl index c06cc24..a4f167c 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -16,14 +16,10 @@ SparseMatrixCSC{Tv,Ti}(A::IMatrix) where {Tv,Ti<:Integer} = SparseMatrixCSC{Tv,Ti}(I, A.n, A.n) SparseMatrixCSC{Tv}(A::IMatrix) where {Tv} = SparseMatrixCSC{Tv,Int}(A) SparseMatrixCSC(A::IMatrix{T}) where {T} = SparseMatrixCSC{T,Int}(I, A.n, A.n) -function SparseMatrixCSC(M::PermMatrix) +function SparseMatrixCSC(M::AbstractPermMatrix) n = size(M, 1) - order = invperm(M.perm) - SparseMatrixCSC(n, n, collect(1:n+1), order, M.vals[order]) -end -function SparseMatrixCSC(M::PermMatrixCSC) - n = size(M, 1) - SparseMatrixCSC(n, n, collect(1:n+1), M.perm, M.vals[order]) + MC = PermMatrixCSC(M) + SparseMatrixCSC(n, n, collect(1:n+1), MC.perm, MC.vals) end @static if VERSION < v"1.3-" @@ -69,14 +65,20 @@ function Matrix(coo::SparseMatrixCOO{T}) where {T} end ################## To PermMatrix ###################### -PermMatrix(pc::PermMatrixCSC) = PermMatrix(invperm(pc.perm), pc.vals) -PermMatrixCSC(pc::PermMatrix) = PermMatrixCSC(invperm(pc.perm), pc.vals) +function PermMatrix(pc::PermMatrixCSC) + order = fast_invperm(pc.perm) + PermMatrix(order, pc.vals[order]) +end +function PermMatrixCSC(pc::PermMatrix) + order = fast_invperm(pc.perm) + PermMatrixCSC(order, pc.vals[order]) +end for MT in [:PermMatrix, :PermMatrixCSC] - $MT{Tv,Ti}(A::IMatrix) where {Tv,Ti} = + @eval $MT{Tv,Ti}(A::IMatrix) where {Tv,Ti} = $MT{Tv,Ti}(Vector{Ti}(1:A.n), ones(Tv, A.n)) - $MT{Tv}(X::IMatrix) where {Tv} = $MT{Tv,Int}(X) - $MT(X::IMatrix{T}) where {T} = $MT{T,Int}(X) - $MT{Tv,Ti}(A::$MT) where {Tv,Ti} = + @eval $MT{Tv}(X::IMatrix) where {Tv} = $MT{Tv,Int}(X) + @eval $MT(X::IMatrix{T}) where {T} = $MT{T,Int}(X) + @eval $MT{Tv,Ti}(A::$MT) where {Tv,Ti} = $MT(Vector{Ti}(A.perm), Vector{Tv}(A.vals)) end @@ -89,15 +91,12 @@ end _findnz(A::AbstractSparseArray) = findnz(A) function PermMatrix{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti} - i, j, v = _findnz(A) - j == collect(1:size(A, 2)) || throw(ArgumentError("This is not a PermMatrix")) - order = invperm(i) - PermMatrix{Tv,Ti}(Vector{Ti}(order), Vector{Tv}(v[order])) + PermMatrix(PermMatrixCSC(A)) end function PermMatrixCSC{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti} i, j, v = _findnz(A) j == collect(1:size(A, 2)) || throw(ArgumentError("This is not a PermMatrix")) - PermMatrix{Tv,Ti}(Vector{Ti}(i), Vector{Tv}(v[order])) + PermMatrix{Tv,Ti}(Vector{Ti}(i), Vector{Tv}(v)) end for MT in [:PermMatrix, :PermMatrixCSC] diff --git a/src/kronecker.jl b/src/kronecker.jl index 81ed79a..7d7e909 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -31,12 +31,6 @@ LinearAlgebra.kron(A::IMatrix{Ta}, B::IMatrix{Tb}) where {Ta<:Number,Tb<:Number} LinearAlgebra.kron(A::IMatrix{<:Number}, B::Diagonal{<:Number}) = A.n == 1 ? B : Diagonal(orepeat(B.diag, A.n)) LinearAlgebra.kron(B::Diagonal{<:Number}, A::IMatrix) = A.n == 1 ? B : Diagonal(irepeat(B.diag, A.n)) -####### diagonal kron ######## -LinearAlgebra.kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrix(A), B) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrix(A), B) -LinearAlgebra.kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) - function LinearAlgebra.kron(A::AbstractMatrix{Tv}, B::IMatrix) where {Tv<:Number} B.n == 1 && return A mA, nA = size(A) @@ -127,7 +121,7 @@ function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::IMatrix) where {T<:Number} SparseMatrixCSC(mA * B.n, nA * B.n, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{T}, B::IMatrix) where {T<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{T}, B::IMatrix) where {T<:Number} nA = size(A, 1) nB = size(B, 1) nB == 1 && return A @@ -142,10 +136,10 @@ function LinearAlgebra.kron(A::PermMatrix{T}, B::IMatrix) where {T<:Number} vals[start+j] = val end end - PermMatrix(perm, vals) + basetype(A)(perm, vals) end -function LinearAlgebra.kron(A::IMatrix, B::PermMatrix{Tv,Ti}) where {Tv<:Number,Ti<:Integer} +function LinearAlgebra.kron(A::IMatrix, B::AbstractPermMatrix{Tv,Ti}) where {Tv<:Number,Ti<:Integer} nA = size(A, 1) nB = size(B, 1) nA == 1 && return B @@ -158,14 +152,14 @@ function LinearAlgebra.kron(A::IMatrix, B::PermMatrix{Tv,Ti}) where {Tv<:Number, vals[start+j] = B.vals[j] end end - PermMatrix(perm, vals) + basetype(B)(perm, vals) end - -function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<:Number,Tb<:Number} +function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::AbstractPermMatrix{Tb}) where {Tv<:Number,Tb<:Number} mA, nA = size(A) nB = size(B, 1) - perm = fast_invperm(B.perm) + BC = PermMatrixCSC(B) + perm, vals = BC.perm, BC.vals nzval = Vector{promote_type(Tv, Tb)}(undef, mA * nA * nB) rowval = Vector{Int}(undef, mA * nA * nB) colptr = collect(1:mA:nA*nB*mA+1) @@ -173,7 +167,7 @@ function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<: @inbounds for j = 1:nA @inbounds for j2 = 1:nB p2 = perm[j2] - val2 = B.vals[p2] + val2 = vals[j2] ir = p2 @inbounds @simd for i = 1:mA nzval[z] = A[i, j] * val2 # merge @@ -186,18 +180,18 @@ function LinearAlgebra.kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<: SparseMatrixCSC(mA * nB, nA * nB, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number,Ta<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number,Ta<:Number} mB, nB = size(B) nA = size(A, 1) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = Vector{promote_type(Ta, Tb)}(undef, mB * nA * nB) rowval = Vector{Int}(undef, mB * nA * nB) colptr = collect(1:mB:nA*nB*mB+1) z = 0 @inbounds for j = 1:nA - colbase = (j - 1) * nB p1 = perm[j] - val2 = A.vals[p1] + val2 = vals[j] ir = (p1 - 1) * mB for j2 = 1:nB @inbounds @simd for i2 = 1:mB @@ -210,7 +204,8 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<: SparseMatrixCSC(nA * mB, nA * nB, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number}) +function LinearAlgebra.kron(A::AbstractPermMatrix{<:Number}, B::AbstractPermMatrix{<:Number}) + @assert basetype(A) == basetype(B) nA = size(A, 1) nB = size(B, 1) vals = kron(A.vals, B.vals) @@ -222,17 +217,18 @@ function LinearAlgebra.kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number}) perm[start+j] = permAi + B.perm[j] end end - PermMatrix(perm, vals) + basetype(A)(perm, vals) end -LinearAlgebra.kron(A::PermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) -LinearAlgebra.kron(A::Diagonal{<:Number}, B::PermMatrix{<:Number}) = kron(PermMatrix(A), B) +LinearAlgebra.kron(A::AbstractPermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, basetype(A)(B)) +LinearAlgebra.kron(A::Diagonal{<:Number}, B::AbstractPermMatrix{<:Number}) = kron(basetype(B)(A), B) -function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number,Tb<:Number} +function LinearAlgebra.kron(A::AbstractPermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number,Tb<:Number} nA = size(A, 1) mB, nB = size(B) nV = nnz(B) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = Vector{promote_type(Ta, Tb)}(undef, nA * nV) rowval = Vector{Int}(undef, nA * nV) colptr = Vector{Int}(undef, nA * nB + 1) @@ -240,7 +236,7 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta @inbounds @simd for i = 1:nA start_row = (i - 1) * nV start_ri = (perm[i] - 1) * mB - v0 = A.vals[perm[i]] + v0 = vals[i] @inbounds @simd for j = 1:nV nzval[start_row+j] = B.nzval[j] * v0 rowval[start_row+j] = B.rowval[j] + start_ri @@ -254,11 +250,12 @@ function LinearAlgebra.kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta SparseMatrixCSC(mB * nA, nB * nA, colptr, rowval, nzval) end -function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<:Number,Tb<:Number} +function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::AbstractPermMatrix{Tb}) where {T<:Number,Tb<:Number} nB = size(B, 1) mA, nA = size(A) nV = nnz(A) - perm = fast_invperm(B.perm) + BC = PermMatrixCSC(B) + perm, vals = BC.perm, BC.vals rowval = Vector{Int}(undef, nB * nV) colptr = Vector{Int}(undef, nA * nB + 1) nzval = Vector{promote_type(T, Tb)}(undef, nB * nV) @@ -269,7 +266,7 @@ function LinearAlgebra.kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<: rend = A.colptr[i+1] - 1 @inbounds for k = 1:nB irow = perm[k] - bval = B.vals[irow] + bval = vals[k] irow_nB = irow - nB @inbounds @simd for r = rstart:rend rowval[z] = A.rowval[r] * nB + irow_nB diff --git a/src/linalg.jl b/src/linalg.jl index 39da4da..edc3a53 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -3,15 +3,14 @@ Base.inv(M::IMatrix) = M LinearAlgebra.det(M::IMatrix) = 1 LinearAlgebra.diag(M::IMatrix{T}) where {T} = ones(T, M.n) LinearAlgebra.logdet(M::IMatrix) = 0 -Base.sqrt(x::PermMatrix) = sqrt(Matrix(x)) +Base.sqrt(x::AbstractPermMatrix) = sqrt(Matrix(x)) Base.sqrt(x::IMatrix) = x -Base.exp(x::PermMatrix) = exp(Matrix(x)) +Base.exp(x::AbstractPermMatrix) = exp(Matrix(x)) Base.exp(x::IMatrix) = ℯ * x -#det(M::PermMatrix) = parity(M.perm)*prod(M.vals) -function Base.inv(M::PermMatrix) +function Base.inv(M::AbstractPermMatrix) new_perm = fast_invperm(M.perm) - return PermMatrix(new_perm, 1.0 ./ M.vals[new_perm]) + return basetype(M)(new_perm, 1.0 ./ M.vals[new_perm]) end ####### multiply ########### @@ -24,7 +23,7 @@ Base.:*(A::IMatrix, B::AbstractVector) = ) for MATTYPE in - [:AbstractMatrix, :StridedMatrix, :Diagonal, :SparseMatrixCSC, :Matrix, :PermMatrix] + [:AbstractMatrix, :StridedMatrix, :Diagonal, :SparseMatrixCSC, :Matrix, :AbstractPermMatrix] @eval Base.:*(A::IMatrix, B::$MATTYPE) = A.n == size(B, 1) ? B : throw( @@ -61,12 +60,12 @@ Base.:*(A::IMatrix, B::IMatrix) = ########## Multiplication ############# -function LinearAlgebra.mul!(Y::AbstractVector, A::PermMatrix, X::AbstractVector, alpha::Number, beta::Number) - length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match PermMatrix A")) - length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match PermMatrix A")) +function LinearAlgebra.mul!(Y::AbstractVector, A::AbstractPermMatrix, X::AbstractVector, alpha::Number, beta::Number) + length(X) == size(A, 2) || throw(DimensionMismatch("input X length does not match permutation matrix A")) + length(Y) == size(A, 2) || throw(DimensionMismatch("output Y length does not match permutation matrix A")) - @inbounds for I in eachindex(X) - Y[I] = A.vals[I] * X[A.perm[I]] * alpha + beta * Y[I] + @inbounds for ((i, j), p) in IterNz(A) + Y[i] = p * X[j] * alpha + beta * Y[i] end return Y end @@ -75,54 +74,50 @@ end function Base.:*(D::Diagonal{Td}, A::PermMatrix{Ta}) where {Td,Ta} PermMatrix(A.perm, A.vals .* D.diag) end - +function Base.:*(D::Diagonal{Td}, A::PermMatrixCSC{Ta}) where {Td,Ta} + PermMatrixCSC(A.perm, view(D.diag, A.perm) .* A.vals) +end function Base.:*(A::PermMatrix{Ta}, D::Diagonal{Td}) where {Td,Ta} PermMatrix(A.perm, A.vals .* view(D.diag, A.perm)) end +function Base.:*(A::PermMatrixCSC{Ta}, D::Diagonal{Td}) where {Td,Ta} + PermMatrixCSC(A.perm, A.vals .* D.diag) +end # to self -function Base.:*(A::PermMatrix, B::PermMatrix) +function Base.:*(A::AbstractPermMatrix, B::AbstractPermMatrix) + @assert basetype(A) == basetype(B) size(A, 1) == size(B, 1) || throw(DimensionMismatch()) PermMatrix(B.perm[A.perm], A.vals .* view(B.vals, A.perm)) end # to matrix -function Base.:*(A::PermMatrix, X::AbstractMatrix) +function LinearAlgebra.:mul!(C::AbstractMatrix, A::AbstractPermMatrix, X::AbstractMatrix, alpha::Number, beta::Number) size(X, 1) == size(A, 2) || throw(DimensionMismatch()) - return A.vals .* view(X,A.perm,:) # this may be inefficient for sparse CSC matrix. + AR = PermMatrix(A) + C .= C .* beta .+ AR.vals .* view(X,AR.perm,:) .* alpha end - -function Base.:*(X::AbstractMatrix, A::PermMatrix) - mX, nX = size(X) - nX == size(A, 1) || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) - return transpose(view(A.vals, perm)) .* view(X, :, perm) +function LinearAlgebra.mul!(C::AbstractMatrix, X::AbstractMatrix, A::AbstractPermMatrix, alpha::Number, beta::Number) + size(X, 2) == size(A, 1) || throw(DimensionMismatch()) + AC = PermMatrixCSC(A) + C .= C .* beta .+ reshape(AC.vals, 1, :) .* view(X, :, perm) .* alpha end # NOTE: this is just a temperory fix for v0.7. We should overload mul! in # the future (when we start to drop v0.6) to enable buildin lazy evaluation. -Base.:*(x::Adjoint{<:Any,<:AbstractVector}, D::PermMatrix) = Matrix(x) * D -Base.:*(x::Transpose{<:Any,<:AbstractVector}, D::PermMatrix) = Matrix(x) * D -Base.:*(A::Adjoint{<:Any,<:AbstractArray}, D::PermMatrix) = Adjoint(adjoint(D) * parent(A)) -Base.:*(A::Transpose{<:Any,<:AbstractArray}, D::PermMatrix) = Transpose(transpose(D) * parent(A)) -Base.:*(A::Adjoint{<:Any,<:PermMatrix}, D::PermMatrix) = adjoint(parent(A)) * D -Base.:*(A::Transpose{<:Any,<:PermMatrix}, D::PermMatrix) = transpose(parent(A)) * D -Base.:*(A::PermMatrix, D::Adjoint{<:Any,<:PermMatrix}) = A * adjoint(parent(D)) -Base.:*(A::PermMatrix, D::Transpose{<:Any,<:PermMatrix}) = A * transpose(parent(D)) - -# for MAT in [:AbstractArray, :Matrix, :SparseMatrixCSC, :PermMatrix] -# @eval begin -# *(A::Adjoint{<:Any, <:$MAT}, D::PermMatrix) = copy(A) * D -# *(A::Transpose{<:Any, <:$MAT}, D::PermMatrix) = copy(A) * D -# *(A::PermMatrix, D::Adjoint{<:Any, <:$MAT}) = A * copy(D) -# *(A::PermMatrix, D::Transpose{<:Any, <:$MAT}) = A * copy(D) -# end -# end +Base.:*(x::Adjoint{<:Any,<:AbstractVector}, D::AbstractPermMatrix) = Matrix(x) * D +Base.:*(x::Transpose{<:Any,<:AbstractVector}, D::AbstractPermMatrix) = Matrix(x) * D +Base.:*(A::Adjoint{<:Any,<:AbstractArray}, D::AbstractPermMatrix) = Adjoint(adjoint(D) * parent(A)) +Base.:*(A::Transpose{<:Any,<:AbstractArray}, D::AbstractPermMatrix) = Transpose(transpose(D) * parent(A)) +Base.:*(A::Adjoint{<:Any,<:AbstractPermMatrix}, D::AbstractPermMatrix) = adjoint(parent(A)) * D +Base.:*(A::Transpose{<:Any,<:AbstractPermMatrix}, D::AbstractPermMatrix) = transpose(parent(A)) * D +Base.:*(A::AbstractPermMatrix, D::Adjoint{<:Any,<:AbstractPermMatrix}) = A * adjoint(parent(D)) +Base.:*(A::AbstractPermMatrix, D::Transpose{<:Any,<:AbstractPermMatrix}) = A * transpose(parent(D)) ############### Transpose, Adjoint for IMatrix ############### for MAT in - [:AbstractArray, :AbstractVector, :Matrix, :SparseMatrixCSC, :PermMatrix, :IMatrix] + [:AbstractArray, :AbstractVector, :Matrix, :SparseMatrixCSC, :AbstractPermMatrix, :IMatrix] @eval Base.:*(A::Adjoint{<:Any,<:$MAT}, D::IMatrix) = Adjoint(D * parent(A)) @eval Base.:*(A::Transpose{<:Any,<:$MAT}, D::IMatrix) = Transpose(D * parent(A)) if MAT != :AbstactVector @@ -132,17 +127,18 @@ for MAT in end # to sparse -function Base.:*(A::PermMatrix, X::SparseMatrixCSC) +function Base.:*(A::AbstractPermMatrix, X::SparseMatrixCSC) nA = size(A, 1) mX, nX = size(X) mX == nA || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = similar(X.nzval) rowval = similar(X.rowval) @inbounds for j = 1:nX @inbounds for k = X.colptr[j]:X.colptr[j+1]-1 r = perm[X.rowval[k]] - nzval[k] = X.nzval[k] * A.vals[r] + nzval[k] = X.nzval[k] * vals[X.rowval[k]] rowval[k] = r end end @@ -150,11 +146,12 @@ function Base.:*(A::PermMatrix, X::SparseMatrixCSC) SparseMatrixCSC(sp')' end -function Base.:*(X::SparseMatrixCSC, A::PermMatrix) +function Base.:*(X::SparseMatrixCSC, A::AbstractPermMatrix) nA = size(A, 1) mX, nX = size(X) nX == nA || throw(DimensionMismatch()) - perm = fast_invperm(A.perm) + AC = PermMatrixCSC(A) + perm, vals = AC.perm, AC.vals nzval = similar(X.nzval) colptr = similar(X.colptr) rowval = similar(X.rowval) @@ -162,7 +159,7 @@ function Base.:*(X::SparseMatrixCSC, A::PermMatrix) z = 1 @inbounds for j = 1:nA pk = perm[j] - va = A.vals[pk] + va = vals[j] @inbounds @simd for k = X.colptr[pk]:X.colptr[pk+1]-1 nzval[z] = X.nzval[k] * va rowval[z] = X.rowval[k]