Skip to content

Commit

Permalink
Add a transducer for Flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Oct 12, 2019
1 parent 3a266b3 commit dfcd792
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
35 changes: 24 additions & 11 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,28 @@ mul_prod(x::Real, y::Real)::Real = x * y

## foldl && mapfoldl

mapfoldl_impl(f, op, nt, itr) = foldl_impl(op, nt, Generator(f, itr))
function mapfoldl_impl(f, op, nt, itr)
op′, itr′ = _xfadjoint(BottomRF(op), Generator(f, itr))
return foldl_impl(op′, nt, itr′)
end

function foldl_impl(op, nt, itr)
op′, itr′ = _xfadjoint(BottomRF(op), itr)
return _foldl_impl(op′, nt, itr′)
v = _foldl_impl(op, get(nt, :init, _InitialValue()), itr)
v isa _InitialValue && return reduce_empty_iter(op, itr)
return v
end

function _foldl_impl(op, nt, itr)
init = get(nt, :init, _InitialValue())
function _foldl_impl(op, init, itr)
# Unroll the while loop once; if init is known, the call to op may
# be evaluated at compile time
y = iterate(itr)
if y === nothing
init isa _InitialValue && return reduce_empty_iter(op, itr)
return init
end
y === nothing && return init
v = op(init, y[1])
while true
y = iterate(itr, y[2])
y === nothing && break
v = op(v, y[1])
end
v isa _InitialValue && return reduce_empty_iter(op, itr)
return v
end

Expand Down Expand Up @@ -102,6 +101,18 @@ end

@inline (op::FilteringRF)(acc, x) = op.f(x) ? op.rf(acc, x) : acc

"""
FlatteningRF(rf) -> rf′
Create a flattening reducing function that is roughly equivalent to
`rf′(acc, x) = foldl(rf, x; init=acc)`.
"""
struct FlatteningRF{T}
rf::T
end

@inline (op::FlatteningRF)(acc, x) = _foldl_impl(op.rf, acc, x)

"""
_xfadjoint(op, itr) -> op′, itr′
Expand Down Expand Up @@ -130,6 +141,8 @@ _xfadjoint(op, itr::Generator) =
end
_xfadjoint(op, itr::Filter) =
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
_xfadjoint(op, itr::Flatten) =
_xfadjoint(FlatteningRF(op), itr.it)

"""
mapfoldl(f, op, itr; [init])
Expand Down Expand Up @@ -162,7 +175,7 @@ foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)

function mapfoldr_impl(f, op, nt, itr)
op′, itr′ = _xfadjoint(BottomRF(FlipArgs(op)), Generator(f, itr))
return _foldl_impl(op′, nt, Iterators.reverse(itr′))
return foldl_impl(op′, nt, Iterators.reverse(itr′))
end

struct FlipArgs{F}
Expand Down
6 changes: 1 addition & 5 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,7 @@ function map(f, t1::Any16, t2::Any16, ts::Any16...)
(A...,)
end

function _foldl_impl(op, nt, itr::Tuple)
init = get(nt, :init, _InitialValue())
y = afoldl(op, init, itr...)
return y isa _InitialValue ? reduce_empty_iter(op, itr) : y
end
_foldl_impl(op, init, itr::Tuple) = afoldl(op, init, itr...)

# type-stable padding
fill_to_length(t::NTuple{N,Any}, val, ::Val{N}) where {N} = t
Expand Down

0 comments on commit dfcd792

Please sign in to comment.