-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Faster mapreduce for Broadcasted #31020
Changes from all commits
f4c68ee
85f6405
cc7c05a
1a60367
ae55dc2
ece81cf
ee68a45
758e003
cec63e7
fb2ff9e
cf04cc8
cedeec6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,7 +166,7 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = | |
# methods that instead specialize on `BroadcastStyle`, | ||
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle}) | ||
|
||
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} | ||
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted | ||
f::F | ||
args::Args | ||
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`) | ||
|
@@ -193,21 +193,25 @@ function Base.show(io::IO, bc::Broadcasted{Style}) where {Style} | |
end | ||
|
||
## Allocating the output container | ||
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}) where {N,ElType} = | ||
similar(Array{ElType}, axes(bc)) | ||
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}) where N = | ||
similar(BitArray, axes(bc)) | ||
Base.similar(bc::Broadcasted, ::Type{T}) where {T} = similar(bc, T, axes(bc)) | ||
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}, dims) where {N,ElType} = | ||
similar(Array{ElType}, dims) | ||
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}, dims) where N = | ||
similar(BitArray, dims) | ||
# In cases of conflict we fall back on Array | ||
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{ElType}) where ElType = | ||
similar(Array{ElType}, axes(bc)) | ||
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) = | ||
similar(BitArray, axes(bc)) | ||
Base.similar(::Broadcasted{ArrayConflict}, ::Type{ElType}, dims) where ElType = | ||
similar(Array{ElType}, dims) | ||
Base.similar(::Broadcasted{ArrayConflict}, ::Type{Bool}, dims) = | ||
similar(BitArray, dims) | ||
|
||
@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes) | ||
_axes(::Broadcasted, axes::Tuple) = axes | ||
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...) | ||
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = () | ||
|
||
@inline Base.axes(bc::Broadcasted{<:Any, <:NTuple{N}}, d::Integer) where N = | ||
d <= N ? axes(bc)[d] : OneTo(1) | ||
|
||
BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style() | ||
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} = | ||
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned")) | ||
|
@@ -219,6 +223,12 @@ argtype(bc::Broadcasted) = argtype(typeof(bc)) | |
_eachindex(t::Tuple{Any}) = t[1] | ||
_eachindex(t::Tuple) = CartesianIndices(t) | ||
|
||
Base.IndexStyle(bc::Broadcasted) = IndexStyle(typeof(bc)) | ||
Base.IndexStyle(::Type{<:Broadcasted{<:Any,<:Tuple{Any}}}) = IndexLinear() | ||
Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian() | ||
|
||
Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = axes(bc)[1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Am right that this usually won't return a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that'd probably be better. This sort of specialization should probably be on |
||
|
||
Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N | ||
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N | ||
|
||
|
@@ -564,7 +574,13 @@ end | |
@boundscheck checkbounds(bc, I) | ||
@inbounds _broadcast_getindex(bc, I) | ||
end | ||
Base.@propagate_inbounds Base.getindex(bc::Broadcasted, i1::Integer, i2::Integer, I::Integer...) = bc[CartesianIndex((i1, i2, I...))] | ||
Base.@propagate_inbounds Base.getindex( | ||
bc::Broadcasted, | ||
i1::Union{Integer,CartesianIndex}, | ||
i2::Union{Integer,CartesianIndex}, | ||
I::Union{Integer,CartesianIndex}..., | ||
) = | ||
bc[CartesianIndex((i1, i2, I...))] | ||
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())] | ||
|
||
@inline Base.checkbounds(bc::Broadcasted, I::Union{Integer,CartesianIndex}) = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,9 @@ else | |
const SmallUnsigned = Union{UInt8,UInt16,UInt32} | ||
end | ||
|
||
abstract type AbstractBroadcasted end | ||
const AbstractArrayOrBroadcasted = Union{AbstractArray, AbstractBroadcasted} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I needed to define this abstract type AbstractIndexable end
const Indexable = Union{AbstractArray, AbstractIndexable} or even abstract type Indexable{N} end
abstract type AbstractArray{T,N} <: Indexable{N} end so that it's much easier for downstream projects to use indexing-based There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having an abstract type or a trait like Though a trait would probably be better than an abstract type, since it would be more flexible. In particular, an object returned by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we go with trait direction, one option is to use struct NonIndexable <: IndexStyle end
IndexStyle(::Type) = NonIndexable() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a clever idea. I think the union here is sensible for now — that trait could possibly come later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Differing the whole designing work for the trait makes sense. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vtjnash mentioned there may be another way around this bootstrapping issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the 1.4 feature freeze is approaching, I support merging the PR with the FWIW, this PR is giving us incredibly good performance to be able to reuse the pairwise summation algorithm with weights at JuliaStats/StatsBase.jl#518. Great work! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks a lot for testing this PR in a real-world scenario! |
||
|
||
""" | ||
Base.add_sum(x, y) | ||
|
||
|
@@ -227,7 +230,8 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...) | |
|
||
# This is a generic implementation of `mapreduce_impl()`, | ||
# certain `op` (e.g. `min` and `max`) may have their own specialized versions. | ||
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int) | ||
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, | ||
ifirst::Integer, ilast::Integer, blksize::Int) | ||
if ifirst == ilast | ||
@inbounds a1 = A[ifirst] | ||
return mapreduce_first(f, op, a1) | ||
|
@@ -250,7 +254,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...) | |
end | ||
end | ||
|
||
mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer) = | ||
mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, ifirst::Integer, ilast::Integer) = | ||
mapreduce_impl(f, op, A, ifirst, ilast, pairwise_blocksize(f, op)) | ||
|
||
""" | ||
|
@@ -383,13 +387,13 @@ The default is `reduce_first(op, f(x))`. | |
""" | ||
mapreduce_first(f, op, x) = reduce_first(op, f(x)) | ||
|
||
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A) | ||
_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A) | ||
|
||
function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T | ||
function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted) | ||
inds = LinearIndices(A) | ||
n = length(inds) | ||
if n == 0 | ||
return mapreduce_empty(f, op, T) | ||
return mapreduce_empty_iter(f, op, A, IteratorEltype(A)) | ||
elseif n == 1 | ||
@inbounds a1 = A[first(inds)] | ||
return mapreduce_first(f, op, a1) | ||
|
@@ -410,7 +414,7 @@ end | |
|
||
mapreduce(f, op, a::Number) = mapreduce_first(f, op, a) | ||
|
||
_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A) | ||
_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A) | ||
|
||
""" | ||
reduce(op, itr; [init]) | ||
|
@@ -560,7 +564,7 @@ isgoodzero(::typeof(max), x) = isbadzero(min, x) | |
isgoodzero(::typeof(min), x) = isbadzero(max, x) | ||
|
||
function mapreduce_impl(f, op::Union{typeof(max), typeof(min)}, | ||
A::AbstractArray, first::Int, last::Int) | ||
A::AbstractArrayOrBroadcasted, first::Int, last::Int) | ||
a1 = @inbounds A[first] | ||
v1 = mapreduce_first(f, op, a1) | ||
v2 = v3 = v4 = v1 | ||
|
@@ -856,7 +860,7 @@ function count(pred, itr) | |
end | ||
return n | ||
end | ||
function count(pred, a::AbstractArray) | ||
function count(pred, a::AbstractArrayOrBroadcasted) | ||
n = 0 | ||
for i in eachindex(a) | ||
@inbounds n += pred(a[i])::Bool | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It turned out my previous approach (combine
IndexStyle
of each argument) was wrong since it didn't work with, e.g.,broadcasted(+, zeros(5), zeros(1, 4))
. Each argument can beIndexLinear
even though the broadcasted indexing is not (that's the main job of broadcasted).So now the new approach is to return an equivalent of
IndexStyle(bc.axes[1])
iflength(bc.axes) == 1
and then returnIndexCartesian()
otherwise. As I needAxes
type parameter now, this implementation requires theBroadcasted
to beinstantiate
d first. I guess this is OK since that's the usual broadcasting pipeline.Is this implementation fine? Is
IndexStyle(bc.axes[1]) == IndexLinear() && length(bc.axes) == 1
the only possibleIndexLinear()
case that is detectable in a type-stable manner? That is to say, I'm ignoring the case likebroadcasted(+, zeros(1, 4), zeros(1, 4))
as I need to look at the value to check that it supports linear indexing.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't really require the
Broadcasted
to beinstantiate
d to get anIndexStyle
— it just means that if it's not instantiated that it falls back to a cartesian implementation. AllBroadcasted
s should support cartesian indexing; only one-dimensional ones will allow straight integers.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's much more accurate way to put it. I wanted to say that "for the broadcasted to be dispatched to the optimized method", it has to be
instantiate
d first.