Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Feb 25, 2024
1 parent 7ecb42d commit 51e52ff
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 145 deletions.
6 changes: 3 additions & 3 deletions src/PermMatrix.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type AbstractPermMatrix{Tv} <: AbstractMatrix{Tv} end
abstract type AbstractPermMatrix{Tv, Ti} <: AbstractMatrix{Tv} end
"""
PermMatrix{Tv, Ti}(perm::AbstractVector{Ti}, vals::AbstractVector{Tv}) where {Tv, Ti<:Integer}
PermMatrix(perm::Vector{Ti}, vals::Vector{Tv}) where {Tv, Ti}
Expand All @@ -25,7 +25,7 @@ julia> PermMatrix([2,1,4,3], rand(4))
```
"""
struct PermMatrix{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractPermMatrix{Tv}
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

Expand Down Expand Up @@ -56,7 +56,7 @@ end

# the column major version of `PermMatrix`
struct PermMatrixCSC{Tv,Ti<:Integer,Vv<:AbstractVector{Tv},Vi<:AbstractVector{Ti}} <:
AbstractPermMatrix{Tv}
AbstractPermMatrix{Tv,Ti}
perm::Vi # new orders
vals::Vv # multiplied values.

Expand Down
47 changes: 24 additions & 23 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,33 @@ Base.imag(M::IMatrix{T}) where {T} = Diagonal(zeros(T, M.n))

# PermMatrix
for func in (:conj, :real, :imag)
@eval (Base.$func)(M::PermMatrix) = PermMatrix(M.perm, ($func)(M.vals))
@eval (Base.$func)(M::AbstractPermMatrix) = basetype(M)(M.perm, ($func)(M.vals))
end
Base.copy(M::PermMatrix) = PermMatrix(copy(M.perm), copy(M.vals))
Base.copy(M::AbstractPermMatrix) = basetype(M)(copy(M.perm), copy(M.vals))
Base.conj!(M::AbstractPermMatrix) = (conj!(M.vals); M)

function Base.transpose(M::PermMatrix)
function Base.transpose(M::AbstractPermMatrix)
new_perm = fast_invperm(M.perm)
return PermMatrix(new_perm, M.vals[new_perm])
return basetype(M)(new_perm, M.vals[new_perm])
end

Base.adjoint(S::PermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::PermMatrix{<:Complex}) = conj(transpose(S))
Base.adjoint(S::AbstractPermMatrix{<:Real}) = transpose(S)
Base.adjoint(S::AbstractPermMatrix{<:Complex}) = conj!(transpose(S))

# scalar
Base.:*(A::IMatrix{T}, B::Number) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:*(B::Number, A::IMatrix{T}) where {T} = Diagonal(fill(promote_type(T, eltype(B))(B), A.n))
Base.:/(A::IMatrix{T}, B::Number) where {T} =
Diagonal(fill(promote_type(T, eltype(B))(1 / B), A.n))

Base.:*(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals * B)
Base.:*(B::Number, A::PermMatrix) = A * B
Base.:/(A::PermMatrix, B::Number) = PermMatrix(A.perm, A.vals / B)
Base.:*(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals * B)
Base.:*(B::Number, A::AbstractPermMatrix) = A * B
Base.:/(A::AbstractPermMatrix, B::Number) = basetype(A)(A.perm, A.vals / B)
#+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev)
#-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev)

for op in [:+, :-]
for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval begin
# IMatrix, PermMatrix - SparseMatrixCSC
Base.$op(A::$MT, B::SparseMatrixCSC) = $op(SparseMatrixCSC(A), B)
Expand All @@ -45,12 +46,12 @@ for op in [:+, :-]
# IMatrix, PermMatrix - Diagonal
Base.$op(d1::IMatrix, d2::Diagonal) = Diagonal($op(diag(d1), d2.diag))
Base.$op(d1::Diagonal, d2::IMatrix) = Diagonal($op(d1.diag, diag(d2)))
Base.$op(d1::PermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::PermMatrix) = $op(d1, SparseMatrixCSC(d2))
Base.$op(d1::AbstractPermMatrix, d2::Diagonal) = $op(SparseMatrixCSC(d1), d2)
Base.$op(d1::Diagonal, d2::AbstractPermMatrix) = $op(d1, SparseMatrixCSC(d2))
# PermMatrix - IMatrix
Base.$op(A::PermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::PermMatrix, B::PermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::IMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::IMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
Base.$op(A::AbstractPermMatrix, B::AbstractPermMatrix) = $op(SparseMatrixCSC(A), SparseMatrixCSC(B))
end
end
# NOTE: promote to integer
Expand All @@ -59,22 +60,22 @@ Base.:+(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
Base.:-(d1::IMatrix{Ta}, d2::IMatrix{Tb}) where {Ta,Tb} =
d1 == d2 ? spzeros(promote_type(Ta, Tb), d1.n, d1.n) : throw(DimensionMismatch())

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.:(==)(A::$MT, B::SparseMatrixCSC) = SparseMatrixCSC(A) == B
@eval Base.:(==)(A::SparseMatrixCSC, B::$MT) = A == SparseMatrixCSC(B)
end
Base.:(==)(d1::IMatrix, d2::Diagonal) = all(isone, d2.diag)
Base.:(==)(d1::Diagonal, d2::IMatrix) = all(isone, d1.diag)
Base.:(==)(d1::PermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::PermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::PermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::PermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(d1::AbstractPermMatrix, d2::Diagonal) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(d1::Diagonal, d2::AbstractPermMatrix) = SparseMatrixCSC(d1) == SparseMatrixCSC(d2)
Base.:(==)(A::IMatrix, B::AbstractPermMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)
Base.:(==)(A::AbstractPermMatrix, B::IMatrix) = SparseMatrixCSC(A) == SparseMatrixCSC(B)

for MT in [:IMatrix, :PermMatrix]
for MT in [:IMatrix, :AbstractPermMatrix]
@eval Base.isapprox(A::$MT, B::SparseMatrixCSC; kwargs...) = isapprox(SparseMatrixCSC(A), B)
@eval Base.isapprox(A::SparseMatrixCSC, B::$MT; kwargs...) = isapprox(A, SparseMatrixCSC(B))
@eval Base.isapprox(d1::$MT, d2::Diagonal; kwargs...) = isapprox(diag(d1), d2.diag)
@eval Base.isapprox(d1::Diagonal, d2::$MT; kwargs...) = isapprox(d1.diag, diag(d2))
end
Base.isapprox(A::IMatrix, B::PermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::PermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::IMatrix, B::AbstractPermMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
Base.isapprox(A::AbstractPermMatrix, B::IMatrix; kwargs...) = isapprox(SparseMatrixCSC(A), SparseMatrixCSC(B); kwargs...)
44 changes: 16 additions & 28 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,41 @@ Broadcast.broadcasted(
) = Diagonal(fill(a, b.n))

# specialize perm matrix
function _broadcast_perm_prod(A::PermMatrix, B::AbstractMatrix)
function _broadcast_perm_prod(A::AbstractPermMatrix, B::AbstractMatrix)
dest = similar(A, Base.promote_op(*, eltype(A), eltype(B)))
i = 1
@inbounds for j in dest.perm
dest[i, j] = A[i, j] * B[i, j]
i += 1
@inbounds for ((i, j), a) in IterNz(A)
dest[i, j] = a * B[i, j]
end
return dest
end

Broadcast.broadcasted(
::AbstractArrayStyle{2},
::typeof(*),
A::PermMatrix,
A::AbstractPermMatrix,
B::AbstractMatrix,
) = _broadcast_perm_prod(A, B)
Broadcast.broadcasted(
::AbstractArrayStyle{2},
::typeof(*),
A::AbstractMatrix,
B::PermMatrix,
B::AbstractPermMatrix,
) = _broadcast_perm_prod(B, A)
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::PermMatrix) =
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::AbstractPermMatrix) =
_broadcast_perm_prod(A, B)

Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::IMatrix) =
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::IMatrix) =
Diagonal(A)
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::PermMatrix) =
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::IMatrix, B::AbstractPermMatrix) =
Diagonal(B)

function _broadcast_diag_perm_prod(A::Diagonal, B::PermMatrix)
dest = similar(A)
i = 1
@inbounds for j in B.perm
if i == j
dest[i, i] = A[i, i] * B[i, i]
else
dest[i, i] = 0
end
i += 1
end
return dest
function _broadcast_diag_perm_prod(A::Diagonal, B::AbstractPermMatrix)
Diagonal(A.diag .* getindex.(Ref(B), 1:size(A, 1)))
end

Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::PermMatrix, B::Diagonal) =
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::AbstractPermMatrix, B::Diagonal) =
_broadcast_diag_perm_prod(B, A)
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::PermMatrix) =
Broadcast.broadcasted(::AbstractArrayStyle{2}, ::typeof(*), A::Diagonal, B::AbstractPermMatrix) =
_broadcast_diag_perm_prod(A, B)

# TODO: commit this upstream
Expand All @@ -110,13 +98,13 @@ Broadcast.broadcasted(
Broadcast.broadcasted(
::AbstractArrayStyle{2},
::typeof(*),
a::PermMatrix,
a::AbstractPermMatrix,
b::Number,
) = PermMatrix(a.perm, a.vals .* b)
) = basetype(a)(a.perm, a.vals .* b)

Broadcast.broadcasted(
::AbstractArrayStyle{2},
::typeof(*),
a::Number,
b::PermMatrix,
) = PermMatrix(b.perm, a .* b.vals)
b::AbstractPermMatrix,
) = basetype(b)(b.perm, a .* b.vals)
35 changes: 17 additions & 18 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@ SparseMatrixCSC{Tv,Ti}(A::IMatrix) where {Tv,Ti<:Integer} =
SparseMatrixCSC{Tv,Ti}(I, A.n, A.n)
SparseMatrixCSC{Tv}(A::IMatrix) where {Tv} = SparseMatrixCSC{Tv,Int}(A)
SparseMatrixCSC(A::IMatrix{T}) where {T} = SparseMatrixCSC{T,Int}(I, A.n, A.n)
function SparseMatrixCSC(M::PermMatrix)
function SparseMatrixCSC(M::AbstractPermMatrix)
n = size(M, 1)
order = invperm(M.perm)
SparseMatrixCSC(n, n, collect(1:n+1), order, M.vals[order])
end
function SparseMatrixCSC(M::PermMatrixCSC)
n = size(M, 1)
SparseMatrixCSC(n, n, collect(1:n+1), M.perm, M.vals[order])
MC = PermMatrixCSC(M)
SparseMatrixCSC(n, n, collect(1:n+1), MC.perm, MC.vals)
end

@static if VERSION < v"1.3-"
Expand Down Expand Up @@ -69,14 +65,20 @@ function Matrix(coo::SparseMatrixCOO{T}) where {T}
end

################## To PermMatrix ######################
PermMatrix(pc::PermMatrixCSC) = PermMatrix(invperm(pc.perm), pc.vals)
PermMatrixCSC(pc::PermMatrix) = PermMatrixCSC(invperm(pc.perm), pc.vals)
function PermMatrix(pc::PermMatrixCSC)
order = fast_invperm(pc.perm)
PermMatrix(order, pc.vals[order])
end
function PermMatrixCSC(pc::PermMatrix)
order = fast_invperm(pc.perm)
PermMatrixCSC(order, pc.vals[order])
end
for MT in [:PermMatrix, :PermMatrixCSC]
$MT{Tv,Ti}(A::IMatrix) where {Tv,Ti} =
@eval $MT{Tv,Ti}(A::IMatrix) where {Tv,Ti} =
$MT{Tv,Ti}(Vector{Ti}(1:A.n), ones(Tv, A.n))
$MT{Tv}(X::IMatrix) where {Tv} = $MT{Tv,Int}(X)
$MT(X::IMatrix{T}) where {T} = $MT{T,Int}(X)
$MT{Tv,Ti}(A::$MT) where {Tv,Ti} =
@eval $MT{Tv}(X::IMatrix) where {Tv} = $MT{Tv,Int}(X)
@eval $MT(X::IMatrix{T}) where {T} = $MT{T,Int}(X)
@eval $MT{Tv,Ti}(A::$MT) where {Tv,Ti} =
$MT(Vector{Ti}(A.perm), Vector{Tv}(A.vals))
end

Expand All @@ -89,15 +91,12 @@ end
_findnz(A::AbstractSparseArray) = findnz(A)

function PermMatrix{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti}
i, j, v = _findnz(A)
j == collect(1:size(A, 2)) || throw(ArgumentError("This is not a PermMatrix"))
order = invperm(i)
PermMatrix{Tv,Ti}(Vector{Ti}(order), Vector{Tv}(v[order]))
PermMatrix(PermMatrixCSC(A))
end
function PermMatrixCSC{Tv,Ti}(A::AbstractMatrix) where {Tv,Ti}
i, j, v = _findnz(A)
j == collect(1:size(A, 2)) || throw(ArgumentError("This is not a PermMatrix"))
PermMatrix{Tv,Ti}(Vector{Ti}(i), Vector{Tv}(v[order]))
PermMatrix{Tv,Ti}(Vector{Ti}(i), Vector{Tv}(v))
end

for MT in [:PermMatrix, :PermMatrixCSC]
Expand Down
Loading

0 comments on commit 51e52ff

Please sign in to comment.