Skip to content

Commit f4c68ee

Browse files
committed
Better mapreduce for Broadcasted
1 parent 51cd56f commit f4c68ee

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

base/broadcast.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
166166
# methods that instead specialize on `BroadcastStyle`,
167167
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})
168168

169-
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple}
169+
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
170170
f::F
171171
args::Args
172172
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
@@ -219,6 +219,18 @@ argtype(bc::Broadcasted) = argtype(typeof(bc))
219219
_eachindex(t::Tuple{Any}) = t[1]
220220
_eachindex(t::Tuple) = CartesianIndices(t)
221221

222+
_combined_indexstyle(::Type{Tuple{A}}) where A = IndexStyle(A)
223+
_combined_indexstyle(::Type{TT}) where TT =
224+
IndexStyle(IndexStyle(Base.tuple_type_head(TT)),
225+
_combined_indexstyle(Base.tuple_type_tail(TT)))
226+
227+
Base.IndexStyle(bc::Broadcasted) = IndexStyle(typeof(bc))
228+
Base.IndexStyle(::Type{<:Broadcasted{<:Any,Nothing,<:Any,Args}}) where {Args <: Tuple} =
229+
_combined_indexstyle(Args)
230+
Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()
231+
232+
Base.LinearIndices(bc::Broadcasted) = LinearIndices(eachindex(bc))
233+
222234
Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
223235
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N
224236

base/reduce.jl

+12-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ else
1212
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
1313
end
1414

15+
abstract type AbstractBroadcasted end
16+
const AbstractArrayOrBroadcasted = Union{AbstractArray, AbstractBroadcasted}
17+
1518
"""
1619
Base.add_sum(x, y)
1720
@@ -152,7 +155,8 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
152155

153156
# This is a generic implementation of `mapreduce_impl()`,
154157
# certain `op` (e.g. `min` and `max`) may have their own specialized versions.
155-
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
158+
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted,
159+
ifirst::Integer, ilast::Integer, blksize::Int)
156160
if ifirst == ilast
157161
@inbounds a1 = A[ifirst]
158162
return mapreduce_first(f, op, a1)
@@ -175,7 +179,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
175179
end
176180
end
177181

178-
mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer) =
182+
mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, ifirst::Integer, ilast::Integer) =
179183
mapreduce_impl(f, op, A, ifirst, ilast, pairwise_blocksize(f, op))
180184

181185
"""
@@ -296,13 +300,13 @@ The default is `reduce_first(op, f(x))`.
296300
"""
297301
mapreduce_first(f, op, x) = reduce_first(op, f(x))
298302

299-
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
303+
_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A)
300304

301-
function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
305+
function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted)
302306
inds = LinearIndices(A)
303307
n = length(inds)
304308
if n == 0
305-
return mapreduce_empty(f, op, T)
309+
return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
306310
elseif n == 1
307311
@inbounds a1 = A[first(inds)]
308312
return mapreduce_first(f, op, a1)
@@ -323,7 +327,7 @@ end
323327

324328
mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)
325329

326-
_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
330+
_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A)
327331

328332
"""
329333
reduce(op, itr; [init])
@@ -473,7 +477,7 @@ isgoodzero(::typeof(max), x) = isbadzero(min, x)
473477
isgoodzero(::typeof(min), x) = isbadzero(max, x)
474478

475479
function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
476-
A::AbstractArray, first::Int, last::Int)
480+
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
477481
a1 = @inbounds A[first]
478482
v1 = mapreduce_first(f, op, a1)
479483
v2 = v3 = v4 = v1
@@ -746,7 +750,7 @@ function count(pred, itr)
746750
end
747751
return n
748752
end
749-
function count(pred, a::AbstractArray)
753+
function count(pred, a::AbstractArrayOrBroadcasted)
750754
n = 0
751755
for i in eachindex(a)
752756
@inbounds n += pred(a[i])::Bool

base/reducedim.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,19 @@ julia> mapreduce(isodd, |, a, dims=1)
301301
1 1 1 1
302302
```
303303
"""
304-
mapreduce(f, op, A::AbstractArray; dims=:, kw...) = _mapreduce_dim(f, op, kw.data, A, dims)
304+
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims=:, kw...) =
305+
_mapreduce_dim(f, op, kw.data, A, dims)
305306

306-
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, ::Colon) = mapfoldl(f, op, A; nt...)
307+
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, ::Colon) =
308+
mapfoldl(f, op, A; nt...)
307309

308-
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, ::Colon) = _mapreduce(f, op, IndexStyle(A), A)
310+
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, ::Colon) =
311+
_mapreduce(f, op, IndexStyle(A), A)
309312

310-
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, dims) =
313+
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, dims) =
311314
mapreducedim!(f, op, reducedim_initarray(A, dims, nt.init), A)
312315

313-
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, dims) =
316+
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, dims) =
314317
mapreducedim!(f, op, reducedim_init(f, op, A, dims), A)
315318

316319
"""

0 commit comments

Comments
 (0)