Skip to content

Commit

Permalink
Make RowVector obey AbstractMatrix indexing interface (#23706)
Browse files Browse the repository at this point in the history
Fixes #23649

(cherry picked from commit b9aed45)
  • Loading branch information
andyferris authored and ararslan committed Nov 14, 2017
1 parent 67ea744 commit cca2e70
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
31 changes: 6 additions & 25 deletions base/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,13 @@ julia> conj(v)
IndexStyle(::RowVector) = IndexLinear()
IndexStyle(::Type{<:RowVector}) = IndexLinear()

@propagate_inbounds getindex(rowvec::RowVector, i) = transpose(rowvec.vec[i])
@propagate_inbounds setindex!(rowvec::RowVector, v, i) = setindex!(rowvec.vec, transpose(v), i)

# Cartesian indexing is distorted by getindex
# Furthermore, Cartesian indexes don't have to match shape, apparently!
@inline function getindex(rowvec::RowVector, i::CartesianIndex)
@boundscheck if !(i.I[1] == 1 && i.I[2] indices(rowvec.vec)[1] && check_tail_indices(i.I...))
throw(BoundsError(rowvec, i.I))
end
@inbounds return transpose(rowvec.vec[i.I[2]])
end
@inline function setindex!(rowvec::RowVector, v, i::CartesianIndex)
@boundscheck if !(i.I[1] == 1 && i.I[2] indices(rowvec.vec)[1] && check_tail_indices(i.I...))
throw(BoundsError(rowvec, i.I))
end
@inbounds rowvec.vec[i.I[2]] = transpose(v)
end

@propagate_inbounds getindex(rowvec::RowVector, ::CartesianIndex{0}) = getindex(rowvec)
@propagate_inbounds getindex(rowvec::RowVector, i::CartesianIndex{1}) = getindex(rowvec, i.I[1])

@propagate_inbounds setindex!(rowvec::RowVector, v, ::CartesianIndex{0}) = setindex!(rowvec, v)
@propagate_inbounds setindex!(rowvec::RowVector, v, i::CartesianIndex{1}) = setindex!(rowvec, v, i.I[1])
@propagate_inbounds getindex(rowvec::RowVector, i::Int) = transpose(rowvec.vec[i])
@propagate_inbounds setindex!(rowvec::RowVector, v, i::Int) = setindex!(rowvec.vec, transpose(v), i)

@inline check_tail_indices(i1, i2) = true
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false
# Keep a RowVector where appropriate
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, i::Int) = transpose.(rowvec.vec[i:i])
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, inds::AbstractArray{Int}) = RowVector(rowvec.vec[inds])
@propagate_inbounds getindex(rowvec::RowVector, ::Colon, ::Colon) = RowVector(rowvec.vec[:])

# helper function for below
@inline to_vec(rowvec::RowVector) = map(transpose, transpose(rowvec))
Expand Down
4 changes: 4 additions & 0 deletions test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,7 @@ end
end
end
end

@testset "Diagonal of a RowVector (#23649)" begin
@test Diagonal([1,2,3].') == Diagonal([1 2 3])
end
20 changes: 20 additions & 0 deletions test/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,23 @@ end
@testset "ambiguity between * methods with RowVectors and ConjRowVectors (#20971)" begin
@test RowVector(ConjArray(ones(4))) * ones(4) == 4
end

@testset "setindex!/getindex" begin
v = [2, 3, 4]
rv = v.'
@test_throws BoundsError setindex!(rv, 5, CartesianIndex((5, 4, 3)))
rv[CartesianIndex((1, 1, 1))] = 5
@test_throws BoundsError getindex(rv, CartesianIndex((5, 4, 3)))
@test rv[1] == 5

@test rv[:, 2]::Vector == [v[2]]
@test rv[:, 2:3]::RowVector == v[2:3].'
@test rv[:, :]::RowVector == rv

v = [1]
rv = v.'
rv[CartesianIndex()] = 2
@test rv[CartesianIndex()] == 2
rv[CartesianIndex(1)] = 1
@test rv[CartesianIndex(1)] == 1
end

0 comments on commit cca2e70

Please sign in to comment.