Skip to content

Commit

Permalink
Make RowVector obey AbstractMatrix indexing interface (#23706)
Browse files Browse the repository at this point in the history
Fixes #23649
  • Loading branch information
andyferris authored and andreasnoack committed Sep 15, 2017
1 parent 3477b88 commit b9aed45
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 25 deletions.
31 changes: 6 additions & 25 deletions base/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,13 @@ julia> conj(v)
IndexStyle(::RowVector) = IndexLinear()
IndexStyle(::Type{<:RowVector}) = IndexLinear()

@propagate_inbounds getindex(rowvec::RowVector, i) = transpose(rowvec.vec[i])
@propagate_inbounds setindex!(rowvec::RowVector, v, i) = setindex!(rowvec.vec, transpose(v), i)

# Cartesian indexing is distorted by getindex
# Furthermore, Cartesian indexes don't have to match shape, apparently!
@inline function getindex(rowvec::RowVector, i::CartesianIndex)
@boundscheck if !(i.I[1] == 1 && i.I[2] indices(rowvec.vec)[1] && check_tail_indices(i.I...))
throw(BoundsError(rowvec, i.I))
end
@inbounds return transpose(rowvec.vec[i.I[2]])
end
@inline function setindex!(rowvec::RowVector, v, i::CartesianIndex)
@boundscheck if !(i.I[1] == 1 && i.I[2] indices(rowvec.vec)[1] && check_tail_indices(i.I...))
throw(BoundsError(rowvec, i.I))
end
@inbounds rowvec.vec[i.I[2]] = transpose(v)
end

@propagate_inbounds getindex(rowvec::RowVector, ::CartesianIndex{0}) = getindex(rowvec)
@propagate_inbounds getindex(rowvec::RowVector, i::CartesianIndex{1}) = getindex(rowvec, i.I[1])

@propagate_inbounds setindex!(rowvec::RowVector, v, ::CartesianIndex{0}) = setindex!(rowvec, v)
@propagate_inbounds setindex!(rowvec::RowVector, v, i::CartesianIndex{1}) = setindex!(rowvec, v, i.I[1])
@propagate_inbounds getindex(rowvec::RowVector, i::Int) = transpose(rowvec.vec[i])
@propagate_inbounds setindex!(rowvec::RowVector, v, i::Int) = setindex!(rowvec.vec, transpose(v), i)

@inline check_tail_indices(i1, i2) = true
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false
# Keep a RowVector where appropriate
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, i::Int) = transpose.(rowvec.vec[i:i])
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, inds::AbstractArray{Int}) = RowVector(rowvec.vec[inds])
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, ::Colon) = RowVector(rowvec.vec[:])

# helper function for below
@inline to_vec(rowvec::RowVector) = map(transpose, transpose(rowvec))
Expand Down
4 changes: 4 additions & 0 deletions test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,7 @@ end
end
end
end

@testset "Diagonal of a RowVector (#23649)" begin
@test Diagonal([1,2,3].') == Diagonal([1 2 3])
end
4 changes: 4 additions & 0 deletions test/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ end
@test_throws BoundsError getindex(rv, CartesianIndex((5, 4, 3)))
@test rv[1] == 5

@test rv[:, 2]::Vector == [v[2]]
@test rv[:, 2:3]::RowVector == v[2:3].'
@test rv[:, :]::RowVector == rv

v = [1]
rv = v.'
rv[CartesianIndex()] = 2
Expand Down

0 comments on commit b9aed45

Please sign in to comment.