Skip to content

Commit

Permalink
ShapedIndex (#199)
Browse files Browse the repository at this point in the history
* Use `CartesianIndices` for conversion of `Int` -> CartesianIndex.
* The `CartesianIndex` -> `Int` conversion is managed by composing a
`StrideIndex`, where the strides are computed using `size_to_strides`,
    instead of the internal memory representation.
* Improve known_size: Previously an array that needed a unique method for `known_size` also
needed a unique one for `known_size(::Type{A}, dim)`. Now `known_size`
is called and then indexed, requring only one new method.
  • Loading branch information
Tokazama committed Sep 3, 2021
1 parent 669658f commit f33e147
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 104 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.30"
version = "3.1.31"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
8 changes: 5 additions & 3 deletions src/array_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ end
Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1
Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count

## getindex
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int)
@boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
Expand Down Expand Up @@ -274,11 +276,11 @@ end
ind.reflocalinds[p][_i] + ind.refcoords[p] - 1
end

@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex{N}) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + offset1(x)
@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)
end
@generated function _strides2int(o::O, s::S, i::I) where {O,S,I}
N = known_length(I)
N = known_length(S)
out = :()
for i in 1:N
tmp = :(((getfield(i, $i) - getfield(o, $i)) * getfield(s, $i)))
Expand Down
169 changes: 79 additions & 90 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,6 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1})
return to_index(axis, first(Tuple(arg)))
end
function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N}
inds = Tuple(arg)
o = offsets(x)
s = size(x)
return first(inds) + (static(1) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
end
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
i = ((first(inds) - first(o)) * stride)
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
end
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
return (first(inds) - first(o)) * stride
end
# trailing inbounds can only be 1 or 1:1
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)

@propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray})
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
Expand Down Expand Up @@ -215,13 +199,18 @@ end
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
end
to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - static(1)))
@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
len = first(s)
inext = div(i, len)
return (canonicalize(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...)

# TODO delete this once the layout interface is working
_array_index(::IndexLinear, a, i::CanonicalInt) = i
@inline function _array_index(::IndexStyle, a, i::CanonicalInt)
CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i]
end
_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = canonicalize(i + first(o))
_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1)
@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex)
N = ndims(a)
StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i]
end
_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i

"""
unsafe_reconstruct(A, data; kwargs...)
Expand Down Expand Up @@ -326,54 +315,50 @@ another instance of `ArrayInterface.getindex` should only be done by overloading
Changing indexing based on a given argument from `args` should be done through,
[`to_index`](@ref), or [`to_axis`](@ref).
"""
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...)
@propagate_inbounds function getindex(A; kwargs...)
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
return unsafe_getindex(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
end
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)

## unsafe_get_index ##
unsafe_get_index(A, i::Tuple{}) = unsafe_get_element(A, ())
unsafe_get_index(A, i::Tuple{CanonicalInt}) = unsafe_get_element(A, getfield(i, 1))
function unsafe_get_index(A, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
unsafe_get_element(A, NDIndex(i))
## unsafe_getindex ##
function unsafe_getindex(a::A) where {A}
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,)))
return unsafe_getindex(parent(a))
end
unsafe_get_index(A, i::Tuple) = unsafe_get_collection(A, i)

#=
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
Returns an element of `A` at the indices `inds`. This method assumes all `inds`
have been checked for being in bounds. Any new array type using `ArrayInterface.getindex`
must define `unsafe_get_element(::NewArrayType, inds)`.
=#
unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds)
_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds)
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds])
_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i)

## Array ##
unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1)
unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i))
unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, idx)
end
end
function unsafe_getindex(a::A, i::AbstractCartesianIndex) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, idx)
end
end
function unsafe_getindex(a, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(a, NDIndex(i, ii...))
end
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)

## LinearIndices ##
unsafe_get_element(A::LinearIndices, i::Integer) = Int(i)
unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))

unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i)
unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i))
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)

unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i)
function unsafe_get_element(A::ReshapedArray, i::NDIndex)
return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i))
end
unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = CartesianIndex(i)

unsafe_get_element(A::SubArray, i) = @inbounds(A[i])
function unsafe_get_element_error(@nospecialize(A), @nospecialize(i))
throw(MethodError(unsafe_get_element, (A, i)))
end
unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::AbstractCartesianIndex) = @inbounds(A[i])

# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
#=
Expand Down Expand Up @@ -402,7 +387,7 @@ function _generate_unsafe_get_index!_body(N::Int)
# the optimizer is not clever enough to split the union without it
Dy === nothing && return dest
(idx, state) = Dy
dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
dest[idx] = unsafe_getindex(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
Dy = iterate(D, state)
end
return dest
Expand Down Expand Up @@ -441,45 +426,49 @@ Store the given values at the given key or index within a collection.
"""
@propagate_inbounds function setindex!(A, val, args...)
if can_setindex(A)
return unsafe_set_index!(A, val, to_indices(A, args))
return unsafe_setindex!(A, val, to_indices(A, args)...)
else
error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
end
end
@propagate_inbounds function setindex!(A, val; kwargs...)
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
return unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
end

## unsafe_get_index ##
unsafe_set_index!(A, v, i::Tuple{}) = unsafe_set_element!(A, v, ())
unsafe_set_index!(A, v, i::Tuple{CanonicalInt}) = unsafe_set_element!(A, v, getfield(i, 1))
function unsafe_set_index!(A, v, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
unsafe_set_element!(A, v, NDIndex(i))
## unsafe_setindex! ##
function unsafe_setindex!(a::A, v) where {A}
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v)))
return unsafe_setindex!(parent(a), v)
end
unsafe_set_index!(A, v, i::Tuple) = unsafe_set_collection!(A, v, i)

#=
unsafe_set_element!(A, val, inds::Tuple)
Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds`
have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!`
must define `unsafe_set_element!(::NewArrayType, val, inds)`.
=#
unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds)
_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds)
_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val)

function _unsafe_set_element!(::False, a::AbstractArray2, val, inds)
unsafe_set_element_error(a, val, inds)
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, idx)
end
end
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))

function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T}
Base.arrayset(false, A, convert(T, val)::T, 1)
function unsafe_setindex!(a::A, v, i::AbstractCartesianIndex) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, idx)
end
end
function unsafe_setindex!(a, v, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_setindex!(a, v, NDIndex(i, ii...))
end
function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T}
return Base.arrayset(false, A, convert(T, val)::T, Int(i))
function unsafe_setindex!(A::Array{T}, v) where {T}
Base.arrayset(false, A, convert(T, v)::T, 1)
end
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
end

unsafe_setindex!(a, v, i::Vararg{Any}) = unsafe_set_collection!(a, v, i)

# This is based on Base._unsafe_setindex!.
#=
Expand All @@ -501,7 +490,7 @@ function _generate_unsafe_setindex!_body(N::Int)
# the optimizer that it does not need to emit error paths
Xy === nothing && break
(val, state) = Xy
unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
unsafe_setindex!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
Xy = iterate(x′, state)
end
A
Expand Down
9 changes: 4 additions & 5 deletions src/size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ end

size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))

@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
Expand Down Expand Up @@ -80,15 +79,15 @@ compile time. If a dimension does not have a known size along a dimension then `
returned in its position.
"""
known_size(x) = known_size(typeof(x))
known_size(::Type{T}) where {T} = eachop(known_size, nstatic(Val(ndims(T))), T)

known_size(::Type{T}) where {T} = eachop(_known_size, nstatic(Val(ndims(T))), axes_types(T))
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(_get_tuple(T, dim))
@inline known_size(x, dim) = known_size(typeof(x), dim)
@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim))
@inline function known_size(::Type{T}, dim::Integer) where {T}
@inline function known_size(::Type{T}, dim::CanonicalInt) where {T}
if ndims(T) < dim
return 1
else
return known_length(axes_types(T, dim))
return known_size(T)[dim]
end
end

2 changes: 1 addition & 1 deletion src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ Returns offsets of indices with respect to 0. If values are known at compile tim
it should return them as `Static` numbers.
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.
"""
offsets(x::StrideIndex) = getfield(x, :offsets)
@inline offsets(x, i) = static_first(indices(x, i))
offsets(::Tuple) = (One(),)
offsets(x::StrideIndex) = getfield(x, :offsets)
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
start = known_first(axes_types(X, dim))
Expand Down
9 changes: 5 additions & 4 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ end
@test @inferred(ArrayInterface.to_index(axis, CartesianIndices(()))) === CartesianIndices(())

x = LinearIndices((static(0):static(3),static(3):static(5),static(-2):static(0)));
@test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
@test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)

# @test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
# @test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)

@test_throws BoundsError ArrayInterface.to_index(axis, 4)
@test_throws BoundsError ArrayInterface.to_index(axis, 1:4)
Expand Down Expand Up @@ -125,8 +126,8 @@ end
#@test_throws ArgumentError Base._sub2ind((1:3,), 2)
#@test_throws ArgumentError Base._ind2sub((1:3,), 2)
x = Array{Int,2}(undef, (2, 2))
ArrayInterface.unsafe_set_index!(x, 1, (2, 2))
@test ArrayInterface.unsafe_get_index(x, (2, 2)) === 1
ArrayInterface.unsafe_setindex!(x, 1, 2, 2)
@test ArrayInterface.unsafe_getindex(x, 2, 2) === 1

# FIXME @test_throws MethodError ArrayInterface.unsafe_set_element!(x, 1, (:x, :x))
# FIXME @test_throws MethodError ArrayInterface.unsafe_get_element(x, (:x, :x))
Expand Down

2 comments on commit f33e147

@Tokazama
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/44170

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.1.31 -m "<description of version>" f33e147e85bf72e23efa1bcff7a7a2edf729e1a2
git push origin v3.1.31

Please sign in to comment.