Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

properly handle colon in offset reshape #228

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ end

export OffsetArray, OffsetMatrix, OffsetVector

include("axes.jl")
include("utils.jl")
include("origin.jl")

# Technically we know the length of CartesianIndices but we need to convert it first, so here we
# don't put it in OffsetAxisKnownLength.
const OffsetAxisKnownLength = Union{Integer, AbstractUnitRange}
const OffsetAxis = Union{OffsetAxisKnownLength, Colon}
const ArrayInitializer = Union{UndefInitializer, Missing, Nothing}

include("axes.jl")
include("utils.jl")
include("origin.jl")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to let utils.jl know the OffsetAxisKnownLength and OffsetAxis.

## OffsetArray
"""
OffsetArray(A, indices...)
Expand Down Expand Up @@ -280,13 +280,15 @@ end
Base.reshape(A::AbstractArray, inds::OffsetAxis...) = reshape(A, inds)
function Base.reshape(A::AbstractArray, inds::Tuple{OffsetAxis,Vararg{OffsetAxis}})
AR = reshape(A, map(_indexlength, inds))
return OffsetArray(AR, map(_offset, axes(AR), inds))
return OffsetArray(AR, _offset_reshape_uncolon(A, inds))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is equivalent to

    return OffsetArray(AR, map(_offset, axes(AR), _offset_reshape_uncolon(A, inds)))

By removing the map part, it calls another constructor.

# convert ranges to offsets
@eval @inline function $FT(A::AbstractArray, inds::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}})
_checkindices(A, inds, "indices")
# Performance gain by wrapping the error in a function: see https://github.com/JuliaLang/julia/issues/37558
throw_dimerr(lA, lI) = throw(DimensionMismatch("supplied axes do not agree with the size of the array (got size $lA for the array and $lI for the indices"))
lA = size(A)
lI = map(length, inds)
lA == lI || throw_dimerr(lA, lI)
$FT(A, map(_offset, axes(A), inds))
end

end

# Reshaping OffsetArrays can "pop" the original OffsetArray wrapper and return
# an OffsetArray(reshape(...)) instead of an OffsetArray(reshape(OffsetArray(...)))
Base.reshape(A::OffsetArray, inds::Tuple{OffsetAxis,Vararg{OffsetAxis}}) =
OffsetArray(reshape(parent(A), map(_indexlength, inds)), map(_indexoffset, inds))
function Base.reshape(A::OffsetArray, inds::Tuple{OffsetAxis,Vararg{OffsetAxis}})
AR = reshape(parent(A), map(_indexlength, inds))
OffsetArray(AR, _offset_reshape_uncolon(A, inds))
end
# And for non-offset axes, we can just return a reshape of the parent directly
Base.reshape(A::OffsetArray, inds::Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}) = reshape(parent(A), inds)
Base.reshape(A::OffsetArray, inds::Dims) = reshape(parent(A), inds)
Expand Down
32 changes: 32 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,37 @@ _indexlength(i::Colon) = Colon()
_offset(axparent::AbstractUnitRange, ax::AbstractUnitRange) = first(ax) - first(axparent)
_offset(axparent::AbstractUnitRange, ax::Integer) = 1 - first(axparent)

_offset_reshape_uncolon(::AbstractArray, I::Tuple{OffsetAxisKnownLength,Vararg{OffsetAxisKnownLength}}) = I
function _offset_reshape_uncolon(A::AbstractArray, I::Tuple{OffsetAxis,Vararg{OffsetAxis}})
@noinline throw1(I) = throw(DimensionMismatch(string("new dimensions $(I) ",
"may have at most one omitted dimension specified by `Colon()`")))
@noinline throw2(A, I) = throw(DimensionMismatch(string("array size $(length(A)) ",
"must be divisible by the product of the new dimensions $I")))

pre = _before_colon(I...)
post = _after_colon(I...)
_any_colon(post...) && throw1(dims)
nprev = isempty(pre) ? 1 : mapreduce(_length, *, pre)
npost = isempty(post) ? 1 : mapreduce(_length, *, post)
sz, remainder = divrem(length(A), nprev * npost)
remainder == 0 || throw2(A, dims)

# preserve the offset information, the default offset beyond `ndims(A)` is 0
n = length(pre)
r = Base.OneTo(Int(sz))
Δ = n < ndims(A) ? _offset(axes(A, n+1), r) : 0
(pre..., IdOffsetRange(r, -Δ), post...)
end
@inline _any_colon() = false
@inline _any_colon(r::Colon, tail...) = true
@inline _any_colon(r::Any, tail...) = _any_colon(tail...)
@inline _before_colon(r::Any, tail...) = (r, _before_colon(tail...)...)
@inline _before_colon(r::Colon, tail...) = ()
@inline _after_colon(r::Any, tail...) = _after_colon(tail...)
@inline _after_colon(r::Colon, tail...) = tail
@inline _length(r::AbstractUnitRange) = length(r)
@inline _length(n::Int) = n

"""
OffsetArrays.AxisConversionStyle(typeof(indices))

Expand Down Expand Up @@ -59,6 +90,7 @@ AxisConversionStyle(::Type) = SingleRange()
AxisConversionStyle(::Type{<:CartesianIndices}) = TupleOfRanges()

_convertTupleAbstractUnitRange(x) = _convertTupleAbstractUnitRange(AxisConversionStyle(typeof(x)), x)
_convertTupleAbstractUnitRange(::SingleRange, x::Int) = (Base.OneTo(x), )
_convertTupleAbstractUnitRange(::SingleRange, x) = (convert(AbstractUnitRange{Int}, x),)
_convertTupleAbstractUnitRange(::TupleOfRanges, x) = convert(Tuple{Vararg{AbstractUnitRange{Int}}}, x)

Expand Down
118 changes: 118 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,124 @@ end
Arsc = reshape(A, :, 1)
Arsc[1,1] = 5
@test first(A) == 5

# reshape with one Colon for AbstractArrays
# If possible, offset is preserved for OffsetArrays
# we test the following cases:
# - (RangeOrInt, :)
# - (RangeOrInt, :, RangeOrInt)
# - (: RangeOrInt)
@testset "Colon" begin
A0 = [1 2 3; 4 5 6]
A = OffsetArray(A0, -1, -1)

# (Range, :)
B = reshape(A0, -10:-9, :)
@test B isa OffsetArray{Int,2}
@test parent(B) === A0
@test axes(B) == (-10:-9, 1:3)
B = reshape(A, -10:-9, :)
@test B isa OffsetArray{Int,2}
@test parent(B) === A0
@test axes(B) == (-10:-9, 0:2)

# (Range, Int, :)
B = reshape(A0, -10:-9, 1, :)
@test B isa OffsetArray{Int,3}
@test same_value(B, A0)
@test axes(B) == (-10:-9, 1:1, 1:3)
# The position of `:` exceeds `ndims(A)` thus the offset is assumed to be 0 there
B = reshape(A, -10:-9, 1, :)
@test B isa OffsetArray{Int,3}
@test same_value(B, A0)
@test axes(B) == (-10:-9, 1:1, 1:3)

# (Range, Range, :)
B = reshape(A0, -10:-9, 3, :)
@test B isa OffsetArray{Int,3}
@test same_value(B, A0)
@test axes(B) == (-10:-9, 1:3, 1:1)
# The position of `:` exceeds `ndims(A)` thus the offset is assumed to be 0 there
B = reshape(A, -10:-9, -1:1, :)
@test B isa OffsetArray{Int,3}
@test same_value(B, A0)
@test axes(B) == (-10:-9, -1:1, 1:1)

# (:, Int) uses Base implementaion
# skip it

# (:, Range)
B = reshape(A0, :, -10:-8)
@test B isa OffsetArray{Int,2}
@test parent(B) === A0
@test axes(B) == (1:2, -10:-8)
B = reshape(A, :, -10:-8)
@test B isa OffsetArray{Int,2}
@test parent(B) === A0
@test axes(B) == (0:1, -10:-8)

# (Range, :, Int)
B = reshape(A0, -10:-9, :, 1)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (-10:-9, 1:3, 1:1)
B = reshape(A0, -10:-9, :, 3)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (-10:-9, 1:1, 1:3)

B = reshape(A, -10:-9, :, 1)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (-10:-9, 0:2, 1:1)
B = reshape(A, -10:-9, :, 3)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (-10:-9, 0:0, 1:3)

# (Range, :, Range)
B = reshape(A0, -10:-9, :, -10:-10)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (-10:-9, 1:3, -10:-10)
B = reshape(A0, -10:-9, :, -10:-8)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (-10:-9, 1:1, -10:-8)

B = reshape(A, -10:-9, :, -10:-10)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (-10:-9, 0:2, -10:-10)
B = reshape(A, -10:-9, :, -10:-8)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (-10:-9, 0:0, -10:-8)

# (Int, :, Range)
B = reshape(A0, 1, :, -10:-9)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (1:1, 1:3, -10:-9)
B = reshape(A0, 3, :, -10:-9)
@test B isa OffsetArray{Int,3}
@test same_value(A0, B)
@test axes(B) == (1:3, 1:1, -10:-9)

B = reshape(A, 1, :, -10:-9)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (1:1, 0:2, -10:-9)
B = reshape(A, 3, :, -10:-9)
@test B isa OffsetArray{Int,3}
@test same_value(A, B)
@test axes(B) == (1:3, 0:0, -10:-9)

@test_throws DimensionMismatch reshape(A0, 1:2, :, :)
@test_throws DimensionMismatch reshape(A0, 1:2, 2, :)
@test_throws DimensionMismatch reshape(A, 1:2, :, :)
@test_throws DimensionMismatch reshape(A, 1:2, 2, :)
end
end

@testset "Indexing with OffsetArray axes" begin
Expand Down