From 2877cbcc72496e5daab2b14adfb0987674d144b3 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 16 May 2024 17:37:17 +0530 Subject: [PATCH] Disallow `Bool` indices in structured matrix indexing (#54475) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, ```julia julia> B = Bidiagonal(1:3, 1:2, :U) 3×3 Bidiagonal{Int64, UnitRange{Int64}}: 1 1 ⋅ ⋅ 2 2 ⋅ ⋅ 3 julia> B[3, true] 0 julia> Matrix(B)[3, true] ERROR: ArgumentError: invalid index: true of type Bool ``` This PR changes the behavior to ```julia julia> B[3, true] ERROR: ArgumentError: invalid index: true of type Bool ``` --- stdlib/LinearAlgebra/src/bidiag.jl | 2 +- stdlib/LinearAlgebra/src/hessenberg.jl | 2 +- stdlib/LinearAlgebra/src/symmetric.jl | 4 ++-- stdlib/LinearAlgebra/src/triangular.jl | 8 ++++---- stdlib/LinearAlgebra/src/tridiag.jl | 4 ++-- stdlib/LinearAlgebra/test/bidiag.jl | 7 +++++++ stdlib/LinearAlgebra/test/hessenberg.jl | 6 ++++++ stdlib/LinearAlgebra/test/symmetric.jl | 9 +++++++++ stdlib/LinearAlgebra/test/triangular.jl | 14 ++++++++++++++ stdlib/LinearAlgebra/test/tridiag.jl | 8 ++++++++ 10 files changed, 54 insertions(+), 10 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 53fbff5839630..90c91f473e066 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -156,7 +156,7 @@ end end end -@inline function getindex(A::Bidiagonal{T}, i::Integer, j::Integer) where T +@inline function getindex(A::Bidiagonal{T}, i::Int, j::Int) where T @boundscheck checkbounds(A, i, j) if i == j return @inbounds A.dv[i] diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index 2bb5573ee308d..1ab68c7d9dc70 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -87,7 +87,7 @@ end Base.isassigned(H::UpperHessenberg, i::Int, j::Int) = i <= j+1 ? isassigned(H.data, i, j) : true -Base.@propagate_inbounds getindex(H::UpperHessenberg{T}, i::Integer, j::Integer) where {T} = +Base.@propagate_inbounds getindex(H::UpperHessenberg{T}, i::Int, j::Int) where {T} = i <= j+1 ? convert(T, H.data[i,j]) : zero(T) Base.@propagate_inbounds function setindex!(A::UpperHessenberg, x, i::Integer, j::Integer) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 4c0cdb5fe787c..5154e6696f909 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -233,7 +233,7 @@ axes(A::HermOrSym) = axes(A.data) end end -@inline function getindex(A::Symmetric, i::Integer, j::Integer) +@inline function getindex(A::Symmetric, i::Int, j::Int) @boundscheck checkbounds(A, i, j) @inbounds if i == j return symmetric(A.data[i, j], sym_uplo(A.uplo))::symmetric_type(eltype(A.data)) @@ -243,7 +243,7 @@ end return transpose(A.data[j, i]) end end -@inline function getindex(A::Hermitian, i::Integer, j::Integer) +@inline function getindex(A::Hermitian, i::Int, j::Int) @boundscheck checkbounds(A, i, j) @inbounds if i == j return hermitian(A.data[i, j], sym_uplo(A.uplo))::hermitian_type(eltype(A.data)) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 469411eb77d40..153472040f6e4 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -269,13 +269,13 @@ Base.isstored(A::UnitUpperTriangular, i::Int, j::Int) = Base.isstored(A::UpperTriangular, i::Int, j::Int) = i <= j ? Base.isstored(A.data, i, j) : false -@propagate_inbounds getindex(A::UnitLowerTriangular{T}, i::Integer, j::Integer) where {T} = +@propagate_inbounds getindex(A::UnitLowerTriangular{T}, i::Int, j::Int) where {T} = i > j ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T)) -@propagate_inbounds getindex(A::LowerTriangular, i::Integer, j::Integer) = +@propagate_inbounds getindex(A::LowerTriangular, i::Int, j::Int) = i >= j ? A.data[i,j] : _zero(A.data,j,i) -@propagate_inbounds getindex(A::UnitUpperTriangular{T}, i::Integer, j::Integer) where {T} = +@propagate_inbounds getindex(A::UnitUpperTriangular{T}, i::Int, j::Int) where {T} = i < j ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T)) -@propagate_inbounds getindex(A::UpperTriangular, i::Integer, j::Integer) = +@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) = i <= j ? A.data[i,j] : _zero(A.data,j,i) @propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index ffdb56ba2dd77..b60ae6c82b7d8 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -445,7 +445,7 @@ end end end -@inline function getindex(A::SymTridiagonal{T}, i::Integer, j::Integer) where T +@inline function getindex(A::SymTridiagonal{T}, i::Int, j::Int) where T @boundscheck checkbounds(A, i, j) if i == j return symmetric((@inbounds A.dv[i]), :U)::symmetric_type(eltype(A.dv)) @@ -680,7 +680,7 @@ end end end -@inline function getindex(A::Tridiagonal{T}, i::Integer, j::Integer) where T +@inline function getindex(A::Tridiagonal{T}, i::Int, j::Int) where T @boundscheck checkbounds(A, i, j) if i == j return @inbounds A.d[i] diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index cf6c0845194d2..25338d641c58f 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -911,4 +911,11 @@ end @test M == B end +@testset "getindex with Integers" begin + dv, ev = 1:4, 1:3 + B = Bidiagonal(dv, ev, :U) + @test_throws "invalid index" B[3, true] + @test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)] +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/hessenberg.jl b/stdlib/LinearAlgebra/test/hessenberg.jl index 73786682480cb..411d761d11e4d 100644 --- a/stdlib/LinearAlgebra/test/hessenberg.jl +++ b/stdlib/LinearAlgebra/test/hessenberg.jl @@ -258,5 +258,11 @@ end @test copyto!(B, A) == A2 end +@testset "getindex with Integers" begin + M = reshape(1:9, 3, 3) + S = UpperHessenberg(M) + @test_throws "invalid index" S[3, true] + @test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)] +end end # module TestHessenberg diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 218171cab2192..683346d5a2619 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -1062,4 +1062,13 @@ end end end +@testset "getindex with Integers" begin + M = reshape(1:4,2,2) + for ST in (Symmetric, Hermitian) + S = ST(M) + @test_throws "invalid index" S[true, true] + @test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)] + end +end + end # module TestSymmetric diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 4ec6b37ea2a37..c793c5a3e9924 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1070,4 +1070,18 @@ end end end +@testset "getindex with Integers" begin + M = reshape(1:4,2,2) + for Ttype in (UpperTriangular, UnitUpperTriangular) + T = Ttype(M) + @test_throws "invalid index" T[2, true] + @test T[1,2] == T[Int8(1),UInt16(2)] == T[big(1), Int16(2)] + end + for Ttype in (LowerTriangular, UnitLowerTriangular) + T = Ttype(M) + @test_throws "invalid index" T[true, 2] + @test T[2,1] == T[Int8(2),UInt16(1)] == T[big(2), Int16(1)] + end +end + end # module TestTriangular diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 70c84faf9884f..064e0464701ae 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -859,4 +859,12 @@ end @test axes(B) === (ax, ax) end +@testset "getindex with Integers" begin + dv, ev = 1:4, 1:3 + for S in (Tridiagonal(ev, dv, ev), SymTridiagonal(dv, ev)) + @test_throws "invalid index" S[3, true] + @test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)] + end +end + end # module TestTridiagonal