Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make StepRangeLen in operations which may produce zero step #40320

Merged
merged 14 commits into from
May 15, 2021
14 changes: 7 additions & 7 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,15 +1153,15 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = StepRangeLen(x*first(r), x*step(r), length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
# separate in case of noncommutative multiplication
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractFloat, r::OrdinalRange{<:Integer}) = Base.range_start_step_length(x*first(r), x, length(r))
# separate in case of noncommutative multiplication:
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = StepRangeLen(first(r)*x, step(r)*x, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::OrdinalRange{<:Integer}, x::AbstractFloat) = Base.range_start_step_length(first(r)*x, x, length(r))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last method here handles the case needed for AbstractFFTs tests, above, of float * UnitRange -> range with twiceprecision.

And (-2:2) * 0.0 works because for floats range_start_step_length does sometimes allow zero step:

julia> Base.range_start_step_length(1.0, 0, 5)
StepRangeLen(1.0, 0.0, 5)

julia> Base.range_start_step_length(1, 0, 5)  # goes to 1st method not 4th
ERROR: ArgumentError: step cannot be zero

But probably it could just be OrdinalRange. Are there other weird cases we should think about?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether Base.range_start_step_length(1, 0, 5) should just return a StepRangeLen. Do you expect that to be too breaking?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't thought about it. Was hoping to change as little as possible, focused on calculations that may produce zero step, rather than constructors.

This one is called by range(1, step=0, length=10) which would otherwise make an integer StepRange. Whereas range(1, 1, length=10) works, and makes a floating point range. It is a bit weird that some constructors work and some don't, but this isn't changed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, it's probably better not to change this behavior in this PR. Could you just add a short comment here explaining why we can't use range_start_step_length?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, see what you think of 56d9c27

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Let's run PkgEval one more time once tests pass, but otherwise this should be good to go.


broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
Expand Down
10 changes: 8 additions & 2 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,15 @@ function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
end
end

show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")
function show(io::IO, r::AbstractRange)
if step(r) != 0
print(io, repr(first(r)), ':', repr(step(r)), ':', repr(last(r)))
else
print(io, "StepRangeLen(", repr(first(r)), ", ", repr(step(r)), ", ", repr(length(r)), ")")
end
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

function ==(r::T, s::T) where {T<:AbstractRange}
isempty(r) && return isempty(s)
Expand Down Expand Up @@ -1238,7 +1244,7 @@ function _define_range_op(@nospecialize f)
r1l = length(r1)
(r1l == length(r2) ||
throw(DimensionMismatch("argument dimensions must match: length of r1 is $r1l, length of r2 is $(length(r2))")))
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
StepRangeLen($f(first(r1), first(r2)), $f(step(r1), step(r2)), r1l)
end

function $f(r1::LinRange{T}, r2::LinRange{T}) where T
Expand Down
8 changes: 5 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,8 @@ end
end
end

# issue 40309
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) === 4:2:6
@testset "Issue #40309: still gives a range after #40320" begin
@test Base.broadcasted_kwsyntax(+, [1], [2]) isa Broadcast.Broadcasted{<:Any, <:Any, typeof(+)}
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) == 4:2:6
@test Broadcast.BroadcastFunction(+)(2:3, 2:3) isa AbstractRange
end
43 changes: 30 additions & 13 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ end
@test sprint(show, StepRange(1, 2, 5)) == "1:2:5"
end

@testset "Issue 11049 and related" begin
@testset "Issue 11049, and related" begin
@test promote(range(0f0, stop=1f0, length=3), range(0., stop=5., length=2)) ===
(range(0., stop=1., length=3), range(0., stop=5., length=2))
@test convert(LinRange{Float64}, range(0., stop=1., length=3)) === LinRange(0., 1., 3)
Expand Down Expand Up @@ -1144,6 +1144,7 @@ end
@test [reverse(range(1.0, stop=27.0, length=1275));] ==
reverse([range(1.0, stop=27.0, length=1275);])
end

@testset "PR 12200 and related" begin
for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0,
range(1, stop=100, length=10), range(1f0, stop=100f0, length=10))
Expand Down Expand Up @@ -1288,8 +1289,8 @@ end
@test_throws BoundsError r[4]
@test_throws BoundsError r[0]
@test broadcast(+, r, 1) === 2:4
@test 2*r === 2:2:6
@test r + r === 2:2:6
@test 2*r == 2:2:6
@test r + r == 2:2:6
k = 0
for i in r
@test i == (k += 1)
Expand Down Expand Up @@ -1432,14 +1433,14 @@ end
@test @inferred(r .+ x) === 3:7
@test @inferred(r .- x) === -1:3
@test @inferred(x .- r) === 1:-1:-3
@test @inferred(x .* r) === 2:2:10
@test @inferred(r .* x) === 2:2:10
@test @inferred(x .* r) == 2:2:10
@test @inferred(r .* x) == 2:2:10
@test @inferred(r ./ x) === 0.5:0.5:2.5
@test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64})
@test @inferred(x .\ r) === 0.5:0.5:2.5

@test @inferred(2 .* (r .+ 1) .+ 2) === 6:2:14
@test @inferred(2 .* (r .+ 1) .+ 2) == 6:2:14
end

@testset "Bad range calls" begin
Expand Down Expand Up @@ -1564,17 +1565,22 @@ end # module NonStandardIntegerRangeTest
end

@testset "constant-valued ranges (issues #10391 and #29052)" begin
for r in ((1:4), (1:1:4), (1.0:4.0))
is_int = eltype(r) === Int
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0] broken=is_int
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0] broken=is_int
@testset "with $(nameof(typeof(r))) of $(eltype(r))" for r in ((1:4), (1:1:4), StepRangeLen(1,1,4), (1.0:4.0))
@test @inferred(0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .* 0) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r + (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r .+ (4:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(r - r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r .- r) == [0.0, 0.0, 0.0, 0.0]

@test @inferred(r .+ (4.0:-1:1)) == [5.0, 5.0, 5.0, 5.0]
@test @inferred(0.0 * r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(0.0 .* r) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r / Inf) == [0.0, 0.0, 0.0, 0.0]
@test @inferred(r ./ Inf) == [0.0, 0.0, 0.0, 0.0]

@test eval(Meta.parse(repr(0 * r))) == [0.0, 0.0, 0.0, 0.0]
end

@test_broken @inferred(range(0, step=0, length=4)) == [0, 0, 0, 0]
Expand All @@ -1587,7 +1593,7 @@ end
@test @inferred(range(0.0, stop=0, length=4)) == [0.0, 0.0, 0.0, 0.0]

z4 = 0.0 * (1:4)
@test @inferred(z4 .+ (1:4)) === 1.0:1.0:4.0
@test @inferred(z4 .+ (1:4)) == 1.0:1.0:4.0
@test @inferred(z4 .+ z4) === z4
end

Expand Down Expand Up @@ -1889,3 +1895,14 @@ end
@test_throws BoundsError r[true:true:false]
@test_throws BoundsError r[true:true:true]
end

@testset "PR 40320 nanosoldier" begin
@test 0.2 * (-2:2) == -0.4:0.2:0.4 # from tests of AbstractFFTs, needs Base.TwicePrecision
@test 0.2f0 * (-2:2) == Float32.(-0.4:0.2:0.4) # likewise needs Float64

@test 0.2 * (-2:1:2) == -0.4:0.2:0.4
@test 0.0 * (-2:2) == fill(0.0, 5)
@test 0.0 * (-2:1:2) == fill(0.0, 5)
@test 0 * (-2:2) == fill(0.0, 5)
@test 0 * (-2:1:2) == fill(0.0, 5)
end