From 4839281c455d507f210cb1d4de96a18407b138c1 Mon Sep 17 00:00:00 2001 From: Andy Ferris Date: Thu, 14 Sep 2017 14:00:58 +1000 Subject: [PATCH] Make RowVector obey AbstractMatrix indexing interface Fixes #23649 --- base/linalg/rowvector.jl | 31 ++++++------------------------- test/linalg/diagonal.jl | 4 ++++ test/linalg/rowvector.jl | 4 ++++ 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/base/linalg/rowvector.jl b/base/linalg/rowvector.jl index 37af261e3ea35..44b2a5d555646 100644 --- a/base/linalg/rowvector.jl +++ b/base/linalg/rowvector.jl @@ -112,32 +112,13 @@ julia> conj(v) IndexStyle(::RowVector) = IndexLinear() IndexStyle(::Type{<:RowVector}) = IndexLinear() -@propagate_inbounds getindex(rowvec::RowVector, i) = transpose(rowvec.vec[i]) -@propagate_inbounds setindex!(rowvec::RowVector, v, i) = setindex!(rowvec.vec, transpose(v), i) - -# Cartesian indexing is distorted by getindex -# Furthermore, Cartesian indexes don't have to match shape, apparently! -@inline function getindex(rowvec::RowVector, i::CartesianIndex) - @boundscheck if !(i.I[1] == 1 && i.I[2] ∈ indices(rowvec.vec)[1] && check_tail_indices(i.I...)) - throw(BoundsError(rowvec, i.I)) - end - @inbounds return transpose(rowvec.vec[i.I[2]]) -end -@inline function setindex!(rowvec::RowVector, v, i::CartesianIndex) - @boundscheck if !(i.I[1] == 1 && i.I[2] ∈ indices(rowvec.vec)[1] && check_tail_indices(i.I...)) - throw(BoundsError(rowvec, i.I)) - end - @inbounds rowvec.vec[i.I[2]] = transpose(v) -end - -@propagate_inbounds getindex(rowvec::RowVector, ::CartesianIndex{0}) = getindex(rowvec) -@propagate_inbounds getindex(rowvec::RowVector, i::CartesianIndex{1}) = getindex(rowvec, i.I[1]) - -@propagate_inbounds setindex!(rowvec::RowVector, v, ::CartesianIndex{0}) = setindex!(rowvec, v) -@propagate_inbounds setindex!(rowvec::RowVector, v, i::CartesianIndex{1}) = setindex!(rowvec, v, i.I[1]) +@propagate_inbounds getindex(rowvec::RowVector, i::Int) = transpose(rowvec.vec[i]) +@propagate_inbounds setindex!(rowvec::RowVector, v, i::Int) = setindex!(rowvec.vec, transpose(v), i) -@inline check_tail_indices(i1, i2) = true -@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false +# Keep a RowVector where appropriate +@propagate_inbounds getindex(rowvec::RowVector, ::Colon, i::Int) = transpose.(rowvec.vec[i:i]) +@propagate_inbounds getindex(rowvec::RowVector, ::Colon, inds::AbstractArray{Int}) = RowVector(rowvec.vec[inds]) +@propagate_inbounds getindex(rowvec::RowVector, ::Colon, ::Colon) = RowVector(rowvec.vec[:]) # helper function for below @inline to_vec(rowvec::RowVector) = map(transpose, transpose(rowvec)) diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl index e2696cada43cb..52beafc957d5f 100644 --- a/test/linalg/diagonal.jl +++ b/test/linalg/diagonal.jl @@ -403,3 +403,7 @@ end end end end + +@testset "Diagonal of a RowVector (#23649)" begin + @test Diagonal([1,2,3].') == Diagonal([1 2 3]) +end \ No newline at end of file diff --git a/test/linalg/rowvector.jl b/test/linalg/rowvector.jl index 832510c243ba1..dfa1e8411ab54 100644 --- a/test/linalg/rowvector.jl +++ b/test/linalg/rowvector.jl @@ -285,6 +285,10 @@ end @test_throws BoundsError getindex(rv, CartesianIndex((5, 4, 3))) @test rv[1] == 5 + @test rv[:, 2]::Vector == [v[2]] + @test rv[:, 2:3]::RowVector == v[2:3].' + @test rv[:, :]::RowVector == rv + v = [1] rv = v.' rv[CartesianIndex()] = 2