Skip to content

Commit cd28425

Browse files
authored
Improve broadcasting of PDMat and PDiagMat (#197)
1 parent e0cad7c commit cd28425

File tree

5 files changed

+51
-1
lines changed

5 files changed

+51
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PDMats"
22
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
3-
version = "0.11.29"
3+
version = "0.11.30"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/pdiagmat.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Base.Matrix(a::PDiagMat) = Matrix(Diagonal(a.diag))
3232
LinearAlgebra.diag(a::PDiagMat) = copy(a.diag)
3333
LinearAlgebra.cholesky(a::PDiagMat) = Cholesky(Diagonal(map(sqrt, a.diag)), 'U', 0)
3434

35+
### Treat as a `Diagonal` matrix in broadcasting since that is better supported
36+
Base.broadcastable(a::PDiagMat) = Base.broadcastable(Diagonal(a.diag))
37+
3538
### Inheriting from AbstractMatrix
3639

3740
function Base.getindex(a::PDiagMat, i::Integer)

src/pdmat.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ Base.Matrix(a::PDMat) = Matrix(a.mat)
4646
LinearAlgebra.diag(a::PDMat) = diag(a.mat)
4747
LinearAlgebra.cholesky(a::PDMat) = a.chol
4848

49+
### Work with the underlying matrix in broadcasting
50+
Base.broadcastable(a::PDMat) = Base.broadcastable(a.mat)
51+
4952
### Inheriting from AbstractMatrix
5053

5154
Base.getindex(a::PDMat, i::Int) = getindex(a.mat, i)

test/pdmtypes.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,32 @@ using Test
256256
@test_throws DimensionMismatch PDSparseMat(A[1:(end - 1), 1:(end - 1)], C)
257257
end
258258
end
259+
260+
@testset "Subtraction" begin
261+
# This falls back to the generic method in Julia based on broadcasting
262+
dim = 4
263+
x = rand(dim, dim)
264+
A = PDMat(x' * x + I)
265+
@test Base.broadcastable(A) == A.mat
266+
267+
B = PDiagMat(rand(dim))
268+
@test Base.broadcastable(B) == Diagonal(B.diag)
269+
270+
for X in (A, B), Y in (A, B)
271+
@test X - Y isa (X === Y === B ? Diagonal{Float64, Vector{Float64}} : Matrix{Float64})
272+
@test X - Y Matrix(X) - Matrix(Y)
273+
end
274+
275+
C = ScalMat(dim, rand())
276+
@test A - C isa Matrix{Float64}
277+
@test A - C Matrix(A) - Matrix(C)
278+
@test C - A isa Matrix{Float64}
279+
@test C - A Matrix(C) - Matrix(A)
280+
281+
# ScalMat does not behave nicely with PDiagMat
282+
@test_broken B - C isa Diagonal{Float64, Vector{Float64}}
283+
@test B - C Matrix(B) - Matrix(C)
284+
@test_broken C - B isa Diagonal{Float64, Vector{Float64}}
285+
@test C - B Matrix(C) - Matrix(B)
286+
end
259287
end

test/specialarrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ using StaticArrays
8484
@test Xt_invA_X(A, Y) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}}
8585
@test Xt_invA_X(A, Y) Matrix(Y)' * (Matrix(A) \ Matrix(Y))
8686
end
87+
88+
# Subtraction falls back to the generic method in Base which is based on broadcasting
89+
@test Base.broadcastable(PDS) == PDS.mat
90+
@test Base.broadcastable(D) == Diagonal(D.diag)
91+
for A in (PDS, D), B in (PDS, D)
92+
@test A - B isa SMatrix{4, 4, Float64}
93+
@test A - B Matrix(A) - Matrix(B)
94+
end
95+
96+
# ScalMat does not behave nicely with broadcasting currently
97+
for A in (PDS, D)
98+
@test_broken A - E isa SMatrix{4, 4, Float64}
99+
@test_broken E - A isa SMatrix{4, 4, Float64}
100+
@test A - E Matrix(A) - Matrix(E)
101+
@test E - A Matrix(E) - Matrix(A)
102+
end
87103
end
88104

89105
@testset "BandedMatrices" begin

0 commit comments

Comments
 (0)