Skip to content

Commit

Permalink
Improve foldl's stability on nested Iterators (#45789)
Browse files Browse the repository at this point in the history
* Make `Fix1(f, Int)` inference-stable
* split `_xfadjoint` into `_xfadjoint_unwrap` and `_xfadjoint_wrap`
* Improve `(c::ComposedFunction)(x...)`'s inferability
* and fuse it in `Base._xfadjoint`.
* define a `Typeof` operator that will partly work around internal type-system bugs

Closes #45715

(cherry picked from commit d58289c)
  • Loading branch information
N5N3 authored and KristofferC committed Aug 30, 2022
1 parent 98efbdf commit 8421c03
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 17 deletions.
27 changes: 21 additions & 6 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,9 @@ julia> [1:5;] |> (x->x.^2) |> sum |> inv
"""
|>(x, f) = f(x)

_stable_typeof(x) = typeof(x)
_stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType

"""
f = Returns(value)
Expand All @@ -936,7 +939,7 @@ julia> f.value
struct Returns{V} <: Function
value::V
Returns{V}(value) where {V} = new{V}(value)
Returns(value) = new{Core.Typeof(value)}(value)
Returns(value) = new{_stable_typeof(value)}(value)
end

(obj::Returns)(args...; kw...) = obj.value
Expand Down Expand Up @@ -1027,7 +1030,19 @@ struct ComposedFunction{O,I} <: Function
ComposedFunction(outer, inner) = new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner)
end

(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))
function (c::ComposedFunction)(x...; kw...)
fs = unwrap_composed(c)
call_composed(fs[1](x...; kw...), tail(fs)...)
end
unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...)
unwrap_composed(c) = (maybeconstructor(c),)
call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...))
call_composed(x, f) = f(x)

struct Constructor{F} <: Function end
(::Constructor{F})(args...; kw...) where {F} = (@inline; F(args...; kw...))
maybeconstructor(::Type{F}) where {F} = Constructor{F}()
maybeconstructor(f) = f

(f) = f
(f, g) = ComposedFunction(f, g)
Expand Down Expand Up @@ -1074,8 +1089,8 @@ struct Fix1{F,T} <: Function
f::F
x::T

Fix1(f::F, x::T) where {F,T} = new{F,T}(f, x)
Fix1(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
end

(f::Fix1)(y) = f.f(f.x, y)
Expand All @@ -1091,8 +1106,8 @@ struct Fix2{F,T} <: Function
f::F
x::T

Fix2(f::F, x::T) where {F,T} = new{F,T}(f, x)
Fix2(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
end

(f::Fix2)(y) = f.f(y, f.x)
Expand Down
30 changes: 19 additions & 11 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,25 @@ what is returned is `itr′` and
op′ = (xfₙ ∘ ... ∘ xf₂ ∘ xf₁)(op)
"""
_xfadjoint(op, itr) = (op, itr)
_xfadjoint(op, itr::Generator) =
if itr.f === identity
_xfadjoint(op, itr.iter)
else
_xfadjoint(MappingRF(itr.f, op), itr.iter)
end
_xfadjoint(op, itr::Filter) =
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
_xfadjoint(op, itr::Flatten) =
_xfadjoint(FlatteningRF(op), itr.it)
function _xfadjoint(op, itr)
itr′, wrap = _xfadjoint_unwrap(itr)
wrap(op), itr′
end

_xfadjoint_unwrap(itr) = itr, identity
function _xfadjoint_unwrap(itr::Generator)
itr′, wrap = _xfadjoint_unwrap(itr.iter)
itr.f === identity && return itr′, wrap
return itr′, wrap Fix1(MappingRF, itr.f)
end
function _xfadjoint_unwrap(itr::Filter)
itr′, wrap = _xfadjoint_unwrap(itr.itr)
return itr′, wrap Fix1(FilteringRF, itr.flt)
end
function _xfadjoint_unwrap(itr::Flatten)
itr′, wrap = _xfadjoint_unwrap(itr.it)
return itr′, wrap FlatteningRF
end

"""
mapfoldl(f, op, itr; [init])
Expand Down
12 changes: 12 additions & 0 deletions test/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ Base.promote_rule(::Type{T19714}, ::Type{Int}) = T19714

end

@testset "Nested ComposedFunction's stability" begin
f(x) = (1, 1, x...)
g = (f (f f)) (f f f)
@test (@inferred (gg)(1)) == ntuple(Returns(1), 25)
@test (@inferred g(1)) == ntuple(Returns(1), 13)
h = (-) (-) (-) (-) (-) (-) sum
@test (@inferred h((1, 2, 3); init = 0.0)) == 6.0
end

@testset "function negation" begin
str = randstring(20)
@test filter(!isuppercase, str) == replace(str, r"[A-Z]" => "")
Expand Down Expand Up @@ -302,6 +311,9 @@ end
val = [1,2,3]
@test Returns(val)(1) === val
@test sprint(show, Returns(1.0)) == "Returns{Float64}(1.0)"

illtype = Vector{Core._typevar(:T, Union{}, Any)}
@test Returns(illtype) == Returns{DataType}(illtype)
end

@testset "<= (issue #46327)" begin
Expand Down
13 changes: 13 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,16 @@ end
@test mapreduce(+, +, oa, oa) == 2len
end
end

# issue #45748
@testset "foldl's stability for nested Iterators" begin
a = Iterators.flatten((1:3, 1:3))
b = (2i for i in a if i > 0)
c = Base.Generator(Float64, b)
d = (sin(i) for i in c if i > 0)
@test @inferred(sum(d)) == sum(collect(d))
@test @inferred(extrema(d)) == extrema(collect(d))
@test @inferred(maximum(c)) == maximum(collect(c))
@test @inferred(prod(b)) == prod(collect(b))
@test @inferred(minimum(a)) == minimum(collect(a))
end

0 comments on commit 8421c03

Please sign in to comment.