Skip to content

Commit

Permalink
Disallow Bool indices in structured matrix indexing (#54475)
Browse files Browse the repository at this point in the history
Currently,
```julia
julia> B = Bidiagonal(1:3, 1:2, :U)
3×3 Bidiagonal{Int64, UnitRange{Int64}}:
 1  1  ⋅
 ⋅  2  2
 ⋅  ⋅  3

julia> B[3, true]
0

julia> Matrix(B)[3, true]
ERROR: ArgumentError: invalid index: true of type Bool
```
This PR changes the behavior to
```julia
julia> B[3, true]
ERROR: ArgumentError: invalid index: true of type Bool
```
  • Loading branch information
jishnub authored May 16, 2024
1 parent e7a1def commit 2877cbc
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 10 deletions.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ end
end
end

@inline function getindex(A::Bidiagonal{T}, i::Integer, j::Integer) where T
@inline function getindex(A::Bidiagonal{T}, i::Int, j::Int) where T
@boundscheck checkbounds(A, i, j)
if i == j
return @inbounds A.dv[i]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end
Base.isassigned(H::UpperHessenberg, i::Int, j::Int) =
i <= j+1 ? isassigned(H.data, i, j) : true

Base.@propagate_inbounds getindex(H::UpperHessenberg{T}, i::Integer, j::Integer) where {T} =
Base.@propagate_inbounds getindex(H::UpperHessenberg{T}, i::Int, j::Int) where {T} =
i <= j+1 ? convert(T, H.data[i,j]) : zero(T)

Base.@propagate_inbounds function setindex!(A::UpperHessenberg, x, i::Integer, j::Integer)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ axes(A::HermOrSym) = axes(A.data)
end
end

@inline function getindex(A::Symmetric, i::Integer, j::Integer)
@inline function getindex(A::Symmetric, i::Int, j::Int)
@boundscheck checkbounds(A, i, j)
@inbounds if i == j
return symmetric(A.data[i, j], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
Expand All @@ -243,7 +243,7 @@ end
return transpose(A.data[j, i])
end
end
@inline function getindex(A::Hermitian, i::Integer, j::Integer)
@inline function getindex(A::Hermitian, i::Int, j::Int)
@boundscheck checkbounds(A, i, j)
@inbounds if i == j
return hermitian(A.data[i, j], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
Expand Down
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,13 @@ Base.isstored(A::UnitUpperTriangular, i::Int, j::Int) =
Base.isstored(A::UpperTriangular, i::Int, j::Int) =
i <= j ? Base.isstored(A.data, i, j) : false

@propagate_inbounds getindex(A::UnitLowerTriangular{T}, i::Integer, j::Integer) where {T} =
@propagate_inbounds getindex(A::UnitLowerTriangular{T}, i::Int, j::Int) where {T} =
i > j ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::LowerTriangular, i::Integer, j::Integer) =
@propagate_inbounds getindex(A::LowerTriangular, i::Int, j::Int) =
i >= j ? A.data[i,j] : _zero(A.data,j,i)
@propagate_inbounds getindex(A::UnitUpperTriangular{T}, i::Integer, j::Integer) where {T} =
@propagate_inbounds getindex(A::UnitUpperTriangular{T}, i::Int, j::Int) where {T} =
i < j ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::UpperTriangular, i::Integer, j::Integer) =
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
i <= j ? A.data[i,j] : _zero(A.data,j,i)

@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ end
end
end

@inline function getindex(A::SymTridiagonal{T}, i::Integer, j::Integer) where T
@inline function getindex(A::SymTridiagonal{T}, i::Int, j::Int) where T
@boundscheck checkbounds(A, i, j)
if i == j
return symmetric((@inbounds A.dv[i]), :U)::symmetric_type(eltype(A.dv))
Expand Down Expand Up @@ -680,7 +680,7 @@ end
end
end

@inline function getindex(A::Tridiagonal{T}, i::Integer, j::Integer) where T
@inline function getindex(A::Tridiagonal{T}, i::Int, j::Int) where T
@boundscheck checkbounds(A, i, j)
if i == j
return @inbounds A.d[i]
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -911,4 +911,11 @@ end
@test M == B
end

@testset "getindex with Integers" begin
dv, ev = 1:4, 1:3
B = Bidiagonal(dv, ev, :U)
@test_throws "invalid index" B[3, true]
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
end

end # module TestBidiagonal
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,5 +258,11 @@ end
@test copyto!(B, A) == A2
end

@testset "getindex with Integers" begin
M = reshape(1:9, 3, 3)
S = UpperHessenberg(M)
@test_throws "invalid index" S[3, true]
@test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)]
end

end # module TestHessenberg
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1062,4 +1062,13 @@ end
end
end

@testset "getindex with Integers" begin
M = reshape(1:4,2,2)
for ST in (Symmetric, Hermitian)
S = ST(M)
@test_throws "invalid index" S[true, true]
@test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)]
end
end

end # module TestSymmetric
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1070,4 +1070,18 @@ end
end
end

@testset "getindex with Integers" begin
M = reshape(1:4,2,2)
for Ttype in (UpperTriangular, UnitUpperTriangular)
T = Ttype(M)
@test_throws "invalid index" T[2, true]
@test T[1,2] == T[Int8(1),UInt16(2)] == T[big(1), Int16(2)]
end
for Ttype in (LowerTriangular, UnitLowerTriangular)
T = Ttype(M)
@test_throws "invalid index" T[true, 2]
@test T[2,1] == T[Int8(2),UInt16(1)] == T[big(2), Int16(1)]
end
end

end # module TestTriangular
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -859,4 +859,12 @@ end
@test axes(B) === (ax, ax)
end

@testset "getindex with Integers" begin
dv, ev = 1:4, 1:3
for S in (Tridiagonal(ev, dv, ev), SymTridiagonal(dv, ev))
@test_throws "invalid index" S[3, true]
@test S[1,2] == S[Int8(1),UInt16(2)] == S[big(1), Int16(2)]
end
end

end # module TestTridiagonal

0 comments on commit 2877cbc

Please sign in to comment.