Skip to content

Commit

Permalink
Range indexing: error with scalar bool index like all other arrays (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins authored Mar 8, 2021
1 parent 5fab42a commit 6fb3558
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 33 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ Standard library changes
* `escape_string` can now receive a collection of characters in the keyword
`keep` that are to be kept as they are. ([#38597]).
* `getindex` can now be used on `NamedTuple`s with multiple values ([#38878])
* Subtypes of `AbstractRange` now correctly follow the general array indexing
behavior when indexed by `Bool`s, erroring for scalar `Bool`s and treating
arrays (including ranges) of `Bool` as an logical index ([#31829])
* `keys(::RegexMatch)` is now defined to return the capture's keys, by name if named, or by index if not ([#37299]).
* `keys(::Generator)` is now defined to return the iterator's keys ([#34678])
* `RegexMatch` now iterate to give their captures. ([#34355]).
Expand Down
132 changes: 111 additions & 21 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,19 @@ be 1.
"""
struct OneTo{T<:Integer} <: AbstractUnitRange{T}
stop::T
OneTo{T}(stop) where {T<:Integer} = new(max(zero(T), stop))
function OneTo{T}(stop) where {T<:Integer}
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
T === Bool && throwbool(stop)
return new(max(zero(T), stop))
end

function OneTo{T}(r::AbstractRange) where {T<:Integer}
throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))")))
throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))")))
throwbool(r) = (@_noinline_meta; throw(ArgumentError("invalid index: $r of type Bool")))
first(r) == 1 || throwstart(r)
step(r) == 1 || throwstep(r)
T === Bool && throwbool(r)
return new(max(zero(T), last(r)))
end
end
Expand Down Expand Up @@ -748,6 +755,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >=

function getindex(v::UnitRange{T}, i::Integer) where T
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = convert(T, v.start + (i - 1))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
Expand All @@ -758,19 +766,22 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,

function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = v.start + (i - 1)
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val % T
end

function getindex(v::OneTo{T}, i::Integer) where T
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck ((i > 0) & (i <= v.stop)) || throw_boundserror(v, i)
convert(T, i)
end

function getindex(v::AbstractRange{T}, i::Integer) where T
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
ret = convert(T, first(v) + (i - 1)*step_hp(v))
ok = ifelse(step(v) > zero(step(v)),
(ret <= last(v)) & (ret >= first(v)),
Expand All @@ -781,22 +792,26 @@ end

function getindex(r::Union{StepRangeLen,LinRange}, i::Integer)
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck checkbounds(r, i)
unsafe_getindex(r, i)
end

# This is separate to make it useful even when running with --check-bounds=yes
function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = i - r.offset
T(r.ref + u*r.step)
end

function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = i - r.offset
r.ref + u*r.step
end

function unsafe_getindex(r::LinRange, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
lerpi(i-1, r.lendiv, r.start, r.stop)
end

Expand All @@ -808,12 +823,27 @@ end

getindex(r::AbstractRange, ::Colon) = copy(r)

function getindex(r::AbstractUnitRange, s::AbstractUnitRange{<:Integer})
function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
f = first(r)
st = oftype(f, f + first(s)-1)
range(st, length=length(s))

if T === Bool
if length(s) == 0
return r
elseif length(s) == 1
if first(s)
return r
else
return range(r[1], length=0)
end
else # length(s) == 2
return range(r[2], length=1)
end
else
f = first(r)
st = oftype(f, f + first(s)-1)
return range(st, length=length(s))
end
end

function getindex(r::OneTo{T}, s::OneTo) where T
Expand All @@ -822,36 +852,96 @@ function getindex(r::OneTo{T}, s::OneTo) where T
OneTo(T(s.stop))
end

function getindex(r::AbstractUnitRange, s::StepRange{<:Integer})
function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
st = oftype(first(r), first(r) + s.start-1)
range(st, step=step(s), length=length(s))

if T === Bool
if length(s) == 0
return range(first(r), step=one(eltype(r)), length=0)
elseif length(s) == 1
if first(s)
return range(first(r), step=one(eltype(r)), length=1)
else
return range(first(r), step=one(eltype(r)), length=0)
end
else # length(s) == 2
return range(r[2], step=one(eltype(r)), length=1)
end
else
st = oftype(first(r), first(r) + s.start-1)
return range(st, step=step(s), length=length(s))
end
end

function getindex(r::StepRange, s::AbstractRange{<:Integer})
function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
st = oftype(r.start, r.start + (first(s)-1)*step(r))
range(st, step=step(r)*step(s), length=length(s))

if T === Bool
if length(s) == 0
return range(first(r), step=step(r), length=0)
elseif length(s) == 1
if first(s)
return range(first(r), step=step(r), length=1)
else
return range(first(r), step=step(r), length=0)
end
else # length(s) == 2
return range(r[2], step=step(r), length=1)
end
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
end
end

function getindex(r::StepRangeLen{T}, s::OrdinalRange{<:Integer}) where {T}
function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)

if S === Bool
if length(s) == 0
return StepRangeLen{T}(first(r), step(r), 0, 1)
elseif length(s) == 1
if first(s)
return StepRangeLen{T}(first(r), step(r), 1, 1)
else
return StepRangeLen{T}(first(r), step(r), 0, 1)
end
else # length(s) == 2
return StepRangeLen{T}(r[2], step(r), 1, 1)
end
else
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
end
end

function getindex(r::LinRange{T}, s::OrdinalRange{<:Integer}) where {T}
function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
vfirst = unsafe_getindex(r, first(s))
vlast = unsafe_getindex(r, last(s))
return LinRange{T}(vfirst, vlast, length(s))

if S === Bool
if length(s) == 0
return LinRange(first(r), first(r), 0)
elseif length(s) == 1
if first(s)
return LinRange(first(r), first(r), 1)
else
return LinRange(first(r), first(r), 0)
end
else # length(s) == 2
return LinRange(r[2], r[2], 1)
end
else
vfirst = unsafe_getindex(r, first(s))
vlast = unsafe_getindex(r, last(s))
return LinRange{T}(vfirst, vlast, length(s))
end
end

show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
Expand Down
40 changes: 28 additions & 12 deletions base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,34 +448,50 @@ end
function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@_inline_meta
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = i - r.offset
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
x_hi, x_lo = add12(r.ref.hi, shift_hi)
T(x_hi + (x_lo + (shift_lo + r.ref.lo)))
end

function _getindex_hiprec(r::StepRangeLen{<:Any,<:TwicePrecision,<:TwicePrecision}, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = i - r.offset
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
x_hi, x_lo = add12(r.ref.hi, shift_hi)
x_hi, x_lo = add12(x_hi, x_lo + (shift_lo + r.ref.lo))
TwicePrecision(x_hi, x_lo)
end

function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{<:Integer}) where T
function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::OrdinalRange{S}) where {T, S<:Integer}
@boundscheck checkbounds(r, s)
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
soffset = clamp(soffset, 1, length(s))
ioffset = first(s) + (soffset-1)*step(s)
if step(s) == 1 || length(s) < 2
newstep = r.step
else
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
end
if ioffset == r.offset
StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
if S === Bool
if length(s) == 0
return StepRangeLen(r.ref, r.step, 0, 1)
elseif length(s) == 1
if first(s)
return StepRangeLen(r.ref, r.step, 1, 1)
else
return StepRangeLen(r.ref, r.step, 0, 1)
end
else # length(s) == 2
return StepRangeLen(r[2], step(r), 1, 1)
end
else
StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
soffset = 1 + round(Int, (r.offset - first(s))/step(s))
soffset = clamp(soffset, 1, length(s))
ioffset = first(s) + (soffset-1)*step(s)
if step(s) == 1 || length(s) < 2
newstep = r.step
else
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
end
if ioffset == r.offset
return StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
else
return StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
end
end
end

Expand Down
Loading

0 comments on commit 6fb3558

Please sign in to comment.