Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ShapedIndex #199

Merged
merged 6 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.30"
version = "3.1.31"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
8 changes: 5 additions & 3 deletions src/array_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ end
Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1
Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count

## getindex
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int)
@boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
Expand Down Expand Up @@ -274,11 +276,11 @@ end
ind.reflocalinds[p][_i] + ind.refcoords[p] - 1
end

@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex{N}) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + offset1(x)
@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)
end
@generated function _strides2int(o::O, s::S, i::I) where {O,S,I}
N = known_length(I)
N = known_length(S)
out = :()
for i in 1:N
tmp = :(((getfield(i, $i) - getfield(o, $i)) * getfield(s, $i)))
Expand Down
169 changes: 79 additions & 90 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,6 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1})
return to_index(axis, first(Tuple(arg)))
end
function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N}
inds = Tuple(arg)
o = offsets(x)
s = size(x)
return first(inds) + (static(1) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
end
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
i = ((first(inds) - first(o)) * stride)
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
end
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
return (first(inds) - first(o)) * stride
end
# trailing inbounds can only be 1 or 1:1
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)

@propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray})
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
Expand Down Expand Up @@ -215,13 +199,18 @@ end
@boundscheck checkbounds(x, arg)
return LogicalIndex{Int}(arg)
end
to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - static(1)))
@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
len = first(s)
inext = div(i, len)
return (canonicalize(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...)

# TODO delete this once the layout interface is working
_array_index(::IndexLinear, a, i::CanonicalInt) = i
@inline function _array_index(::IndexStyle, a, i::CanonicalInt)
CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i]
end
_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = canonicalize(i + first(o))
_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1)
@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex)
N = ndims(a)
StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i]
end
_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i

"""
unsafe_reconstruct(A, data; kwargs...)
Expand Down Expand Up @@ -326,54 +315,50 @@ another instance of `ArrayInterface.getindex` should only be done by overloading
Changing indexing based on a given argument from `args` should be done through,
[`to_index`](@ref), or [`to_axis`](@ref).
"""
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...)
@propagate_inbounds function getindex(A; kwargs...)
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
return unsafe_getindex(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
end
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)

## unsafe_get_index ##
unsafe_get_index(A, i::Tuple{}) = unsafe_get_element(A, ())
unsafe_get_index(A, i::Tuple{CanonicalInt}) = unsafe_get_element(A, getfield(i, 1))
function unsafe_get_index(A, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
unsafe_get_element(A, NDIndex(i))
## unsafe_getindex ##
function unsafe_getindex(a::A) where {A}
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,)))
return unsafe_getindex(parent(a))
end
unsafe_get_index(A, i::Tuple) = unsafe_get_collection(A, i)

#=
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T

Returns an element of `A` at the indices `inds`. This method assumes all `inds`
have been checked for being in bounds. Any new array type using `ArrayInterface.getindex`
must define `unsafe_get_element(::NewArrayType, inds)`.
=#
unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds)
_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds)
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds])
_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i)

## Array ##
unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1)
unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i))
unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, idx)
end
end
function unsafe_getindex(a::A, i::AbstractCartesianIndex) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, idx)
end
end
function unsafe_getindex(a, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(a, NDIndex(i, ii...))
end
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)

## LinearIndices ##
unsafe_get_element(A::LinearIndices, i::Integer) = Int(i)
unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))

unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i)
unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i))
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)

unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i)
function unsafe_get_element(A::ReshapedArray, i::NDIndex)
return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i))
end
unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = CartesianIndex(i)

unsafe_get_element(A::SubArray, i) = @inbounds(A[i])
function unsafe_get_element_error(@nospecialize(A), @nospecialize(i))
throw(MethodError(unsafe_get_element, (A, i)))
end
unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::AbstractCartesianIndex) = @inbounds(A[i])

# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
#=
Expand Down Expand Up @@ -402,7 +387,7 @@ function _generate_unsafe_get_index!_body(N::Int)
# the optimizer is not clever enough to split the union without it
Dy === nothing && return dest
(idx, state) = Dy
dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
dest[idx] = unsafe_getindex(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
Dy = iterate(D, state)
end
return dest
Expand Down Expand Up @@ -441,45 +426,49 @@ Store the given values at the given key or index within a collection.
"""
@propagate_inbounds function setindex!(A, val, args...)
if can_setindex(A)
return unsafe_set_index!(A, val, to_indices(A, args))
return unsafe_setindex!(A, val, to_indices(A, args)...)
else
error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
end
end
@propagate_inbounds function setindex!(A, val; kwargs...)
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
return unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
end

## unsafe_get_index ##
unsafe_set_index!(A, v, i::Tuple{}) = unsafe_set_element!(A, v, ())
unsafe_set_index!(A, v, i::Tuple{CanonicalInt}) = unsafe_set_element!(A, v, getfield(i, 1))
function unsafe_set_index!(A, v, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
unsafe_set_element!(A, v, NDIndex(i))
## unsafe_setindex! ##
function unsafe_setindex!(a::A, v) where {A}
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v)))
return unsafe_setindex!(parent(a), v)
end
unsafe_set_index!(A, v, i::Tuple) = unsafe_set_collection!(A, v, i)

#=
unsafe_set_element!(A, val, inds::Tuple)

Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds`
have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!`
must define `unsafe_set_element!(::NewArrayType, val, inds)`.
=#
unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds)
_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds)
_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val)

function _unsafe_set_element!(::False, a::AbstractArray2, val, inds)
unsafe_set_element_error(a, val, inds)
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, idx)
end
end
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))

function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T}
Base.arrayset(false, A, convert(T, val)::T, 1)
function unsafe_setindex!(a::A, v, i::AbstractCartesianIndex) where {A}
idx = _array_index(IndexStyle(A), a, i)
if idx === i
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, idx)
end
end
function unsafe_setindex!(a, v, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_setindex!(a, v, NDIndex(i, ii...))
end
function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T}
return Base.arrayset(false, A, convert(T, val)::T, Int(i))
function unsafe_setindex!(A::Array{T}, v) where {T}
Base.arrayset(false, A, convert(T, v)::T, 1)
end
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
end

unsafe_setindex!(a, v, i::Vararg{Any}) = unsafe_set_collection!(a, v, i)

# This is based on Base._unsafe_setindex!.
#=
Expand All @@ -501,7 +490,7 @@ function _generate_unsafe_setindex!_body(N::Int)
# the optimizer that it does not need to emit error paths
Xy === nothing && break
(val, state) = Xy
unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
unsafe_setindex!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
Xy = iterate(x′, state)
end
A
Expand Down
9 changes: 4 additions & 5 deletions src/size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ end

size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))

@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
Expand Down Expand Up @@ -80,15 +79,15 @@ compile time. If a dimension does not have a known size along a dimension then `
returned in its position.
"""
known_size(x) = known_size(typeof(x))
known_size(::Type{T}) where {T} = eachop(known_size, nstatic(Val(ndims(T))), T)

known_size(::Type{T}) where {T} = eachop(_known_size, nstatic(Val(ndims(T))), axes_types(T))
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(_get_tuple(T, dim))
@inline known_size(x, dim) = known_size(typeof(x), dim)
@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim))
@inline function known_size(::Type{T}, dim::Integer) where {T}
@inline function known_size(::Type{T}, dim::CanonicalInt) where {T}
if ndims(T) < dim
return 1
else
return known_length(axes_types(T, dim))
return known_size(T)[dim]
end
end

2 changes: 1 addition & 1 deletion src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ Returns offsets of indices with respect to 0. If values are known at compile tim
it should return them as `Static` numbers.
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.
"""
offsets(x::StrideIndex) = getfield(x, :offsets)
@inline offsets(x, i) = static_first(indices(x, i))
offsets(::Tuple) = (One(),)
offsets(x::StrideIndex) = getfield(x, :offsets)
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
start = known_first(axes_types(X, dim))
Expand Down
9 changes: 5 additions & 4 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ end
@test @inferred(ArrayInterface.to_index(axis, CartesianIndices(()))) === CartesianIndices(())

x = LinearIndices((static(0):static(3),static(3):static(5),static(-2):static(0)));
@test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
@test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)

# @test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
# @test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)

@test_throws BoundsError ArrayInterface.to_index(axis, 4)
@test_throws BoundsError ArrayInterface.to_index(axis, 1:4)
Expand Down Expand Up @@ -125,8 +126,8 @@ end
#@test_throws ArgumentError Base._sub2ind((1:3,), 2)
#@test_throws ArgumentError Base._ind2sub((1:3,), 2)
x = Array{Int,2}(undef, (2, 2))
ArrayInterface.unsafe_set_index!(x, 1, (2, 2))
@test ArrayInterface.unsafe_get_index(x, (2, 2)) === 1
ArrayInterface.unsafe_setindex!(x, 1, 2, 2)
@test ArrayInterface.unsafe_getindex(x, 2, 2) === 1

# FIXME @test_throws MethodError ArrayInterface.unsafe_set_element!(x, 1, (:x, :x))
# FIXME @test_throws MethodError ArrayInterface.unsafe_get_element(x, (:x, :x))
Expand Down