Skip to content

Commit

Permalink
support broadcasting on ranges (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin authored Nov 29, 2021
1 parent 496b898 commit b17b1c5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ _colon(::Any, ::Any, start::T, step, stop::T) where {T} =
# Opt into TwicePrecision functionality
*(x::Base.TwicePrecision, y::Units) = Base.TwicePrecision(x.hi*y, x.lo*y)
*(x::Base.TwicePrecision, y::Quantity) = (x * ustrip(y)) * unit(y)
uconvert(y, x::Base.TwicePrecision) = Base.TwicePrecision(uconvert(y, x.hi), uconvert(y, x.lo))

function colon(start::T, step::T, stop::T) where (T<:Quantity{S}
where S<:Union{Float16,Float32,Float64})
# This will always return a StepRangeLen
Expand All @@ -103,6 +105,15 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::AbstractQu
broadcasted(DefaultArrayStyle{1}(), *, r, ustrip(x)) * unit(x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::AbstractQuantity, r::AbstractRange) =
broadcasted(DefaultArrayStyle{1}(), *, ustrip(x), r) * unit(x)

const BCAST_PROPAGATE_CALLS = Union{typeof(upreferred), typeof(ustrip), Units}
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Ref{<:Units}) = r * x[]
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Ref{<:Units}, r::AbstractRange) = x[] * r
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRangeLen) = StepRangeLen(x(r.ref), x(r.step), r.len, r.offset)
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::StepRange) = StepRange(x(r.start), x(r.step), x(r.stop))
broadcasted(::DefaultArrayStyle{1}, x::BCAST_PROPAGATE_CALLS, r::LinRange) = LinRange(x(r.start), x(r.stop), r.len)
broadcasted(::DefaultArrayStyle{1}, ::typeof(|>), r::AbstractRange, x::Ref{<:BCAST_PROPAGATE_CALLS}) = broadcasted(DefaultArrayStyle{1}(), x[], r)

# for ambiguity resolution
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::AbstractQuantity) where T =
broadcasted(DefaultArrayStyle{1}(), *, r, ustrip(x)) * unit(x)
Expand Down
20 changes: 20 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,26 @@ end
@test @inferred(3f0m * LinRange(0.0, 1.0, 3)) === LinRange(0.0m, 3.0m, 3)
@test @inferred(1.0s * range(0.1, step=0.1, length=3)) === @inferred(range(0.1, step=0.1, length=3) * 1.0s)
end
@testset ">> broadcasting" begin
@test @inferred((1:5) .* mm) === 1mm:1mm:5mm
@test @inferred(mm .* (1:5)) === 1mm:1mm:5mm
@test @inferred((1:2:5) .* mm) === 1mm:2mm:5mm
@test @inferred((1.0:2.0:5.01) .* mm) === 1.0mm:2.0mm:5.0mm
r = @inferred(range(0.1, step=0.1, length=3) .* 1.0s)
@test r[3] === 0.3s
@test @inferred((0:2) .* 3f0m) === StepRangeLen{typeof(0f0m)}(0.0m, 3.0m, 3) # issue #477
@test @inferred(3f0m .* (0:2)) === StepRangeLen{typeof(0f0m)}(0.0m, 3.0m, 3) # issue #477
@test @inferred((0f0:2f0) .* 3f0m) === 0f0m:3f0m:6f0m
@test @inferred(3f0m .* (0.0:2.0)) === 0.0m:3.0m:6.0m
@test @inferred(LinRange(0f0, 1f0, 3) .* 3f0m) === LinRange(0f0m, 3f0m, 3)
@test @inferred(3f0m .* LinRange(0.0, 1.0, 3)) === LinRange(0.0m, 3.0m, 3)
@test @inferred(1.0s .* range(0.1, step=0.1, length=3)) === @inferred(range(0.1, step=0.1, length=3) * 1.0s)

@test @inferred((1:2:5) .* cm .|> mm) === 10mm:20mm:50mm
@test mm.((1:2:5) .* cm) === 10mm:20mm:50mm
@test @inferred((1:2:5) .* km .|> upreferred) === 1000m:2000m:5000m
@test @inferred((1:2:5) .* cm .|> mm .|> ustrip) === 10:20:50
end
end
@testset "> Arrays" begin
@testset ">> Array multiplication" begin
Expand Down

0 comments on commit b17b1c5

Please sign in to comment.