Skip to content

Commit 120c431

Browse files
jishnubKristofferC
authored andcommitted
Specialize one for the SizedArray test helper (#58209)
Since the size of the array is encoded in the type, we may define `one` on the type. This is useful in certain linear algebra contexts. (cherry picked from commit d9fafab)
1 parent effc637 commit 120c431

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

test/abstractarray.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,3 +1907,36 @@ end
19071907
@test r2[i] == z[j]
19081908
end
19091909
end
1910+
1911+
@testset "one" begin
1912+
@test one([1 2; 3 4]) == [1 0; 0 1]
1913+
@test one([1 2; 3 4]) isa Matrix{Int}
1914+
1915+
struct Mat <: AbstractMatrix{Int}
1916+
p::Matrix{Int}
1917+
end
1918+
Base.size(m::Mat) = size(m.p)
1919+
Base.IndexStyle(::Type{<:Mat}) = IndexLinear()
1920+
Base.getindex(m::Mat, i::Int) = m.p[i]
1921+
Base.setindex!(m::Mat, v, i::Int) = m.p[i] = v
1922+
Base.similar(::Mat, ::Type{Int}, size::NTuple{2,Int}) = Mat(Matrix{Int}(undef, size))
1923+
1924+
@test one(Mat([1 2; 3 4])) == Mat([1 0; 0 1])
1925+
@test one(Mat([1 2; 3 4])) isa Mat
1926+
1927+
@testset "SizedArray" begin
1928+
S = [1 2; 3 4]
1929+
A = SizedArrays.SizedArray{(2,2)}(S)
1930+
@test one(A) == one(typeof(A))
1931+
@test oneunit(A) == oneunit(typeof(A))
1932+
M = fill(A, 2, 2)
1933+
O = one(M)
1934+
for I in CartesianIndices(M)
1935+
if I[1] == I[2]
1936+
@test O[I] == one(S)
1937+
else
1938+
@test O[I] == zero(S)
1939+
end
1940+
end
1941+
end
1942+
end

test/testhelpers/SizedArrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ Base.size(a::SizedArray) = size(typeof(a))
2929
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
3030
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
3131
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
32+
function Base.one(::Type{SizedMatrix{SZ,T,A}}) where {SZ,T,A}
33+
allequal(SZ) || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
34+
D = diagm(fill(one(T), SZ[1]))
35+
SizedArray{SZ}(convert(A, D))
36+
end
37+
Base.parent(S::SizedArray) = S.data
3238
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
3339
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
3440
function *(S1::SizedArray, S2::SizedArray)

0 commit comments

Comments
 (0)