ComposedFunction type inference regression from v1.6 LTS #45715

Closed
oschulz opened this issue Jun 16, 2022 · 15 comments
Labels
compiler:inference Type inference

Comments

@oschulz
Contributor

oschulz commented Jun 16, 2022

Julia v1.6 type inference has no trouble with this:

julia> VERSION
v"1.6.6"

julia> using Test

julia> @inferred ((-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-))(0)
0

Julia v1.7 and v1.8-rc1 can't type-infer even a three-element composition anymore:

julia> VERSION  # same for v"1.8.0-rc1"
v"1.7.3"

julia> using Test

julia> @inferred ((-) ∘ (-) ∘ (-))(0)
ERROR: return type Int64 does not match inferred return type Any

This makes function composition very unattractive in performance-critical code.
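The regression can also be checked without `@inferred` by asking inference directly. This is a minimal sketch, assuming the internal reflection utility `Base.return_types` (long-standing but not public API); the helper names `compose_n` and `inferred_return` are hypothetical. Its printed types are version-dependent: on an affected release such as v1.7.3 one would expect `Any` for longer chains, while a fixed release should report `Int64`.

```julia
# Build an n-fold composition f ∘ f ∘ ... ∘ f, left-nested like writing it out with ∘.
compose_n(f, n) = foldl(∘, ntuple(_ -> f, n))

# Ask inference for the return type of calling `c` with arguments of the given types.
# `only` assumes exactly one matching method, which holds for ComposedFunction.
inferred_return(c, argtypes) = only(Base.return_types(c, argtypes))

for n in 2:6
    c = compose_n(-, n)
    # Affected versions may print `Any` here; fixed versions should print `Int64`.
    println(n, " functions: inferred ", inferred_return(c, (Int,)), ", value ", c(1))
end
```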

@oschulz
Contributor Author

oschulz commented Jun 17, 2022

The reason seems to be that from

(c::ComposedFunction)(x...) = c.outer(c.inner(x...))

in Julia v1.6 it changed to

(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))

in v1.7. This apparently kills type inference here as soon as things get nested.
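To compare the two definitions side by side, here is a minimal sketch that reproduces them on two hypothetical structs, `CompV16` and `CompV17`, instead of redefining methods on `Base.ComposedFunction`. Both compute the same values; the difference pointed out above is that in the v1.7-style definition every nested `outer` call re-enters a keyword-accepting method with an empty `kw...`. Whether inference gives up on that depends on the Julia version, so the sketch only checks results.

```julia
# Hypothetical stand-ins for Base.ComposedFunction, so both call
# definitions can be compared without redefining methods in Base.
struct CompV16{O,I}
    outer::O
    inner::I
end
struct CompV17{O,I}
    outer::O
    inner::I
end

# v1.6 style: positional arguments only.
(c::CompV16)(x...) = c.outer(c.inner(x...))

# v1.7 style: keyword arguments are accepted and forwarded to `inner`.
# When `outer` is itself a CompV17, it is called through this same
# keyword-accepting method again, just with an empty `kw...`.
(c::CompV17)(x...; kw...) = c.outer(c.inner(x...; kw...))

# Left-nested like (-) ∘ (-) ∘ (-): the `outer` field holds the nested layer.
a = CompV16(CompV16(-, -), -)
b = CompV17(CompV17(-, -), -)
@show a(1) b(1)   # both evaluate -(-(-(1)))

# Only the v1.7-style definition can pass keywords to the innermost function.
scaled = CompV17(abs, (x; s = 1) -> s * x)
@show scaled(-2; s = 3)
```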

@oschulz
Contributor Author

oschulz commented Jun 17, 2022

Something like this seems to fix it:

@inline (c::ComposedFunction)(x...; kw...) = call_composed(c, x, kw)
call_composed(c::F, x, kw) where {F<:ComposedFunction} = call_composed(c.outer, c.inner(x...; kw...))
call_composed(c::F, x) where {F<:ComposedFunction} = call_composed(c.outer, c.inner(x...))
call_composed(c::F, x) where {F} = c(x)

julia> VERSION
v"1.8.0-rc1"

julia> using Test

julia> @inferred ((-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-))(0)
0

Would a change like that be acceptable?

@oscardssmith
Member

looks reasonable to me

@vtjnash
Member

vtjnash commented Jun 17, 2022

I don’t see why that helps fully. It looks like you just rearranged some of the symbols until you didn’t know how to reproduce anymore. Pretty sure we may have an open issue about this recursion pattern already though? Not exactly clear why this one is failing currently.

@oschulz
Contributor Author

oschulz commented Jun 17, 2022

A fix on the compiler side would be ideal, I guess, but that is beyond me. :-)

Speaking from a user point of view I have code that uses composition chains for time-critical tasks. I know Julia gives no guarantees on type inference, but in practice this is a problem in such situations, of course, because one may have to redesign the approach (esp. when things like Enzyme and/or GPUs come in). So pragmatically I'll take any kind of acceptable fix. ;-)

@oschulz
Contributor Author

oschulz commented Jun 17, 2022

It looks like you just rearranged some of the symbols until you didn’t know how to reproduce anymore.

Well, there was some method to my rearranging, though I definitely don't know exactly what happens on the compiler side. What I was going for with that approach is getting the keyword args out of the outer function calls, so that c.outer (and outer-outer and so on) will not be called with an empty kw..., which the compiler no longer seems able to infer.

@oschulz
Contributor Author

oschulz commented Jun 17, 2022

So I'm effectively replacing (c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...)) with (c::ComposedFunction)(x...) = c.outer(c.inner(x...)) for all outer function calls. Since the outer functions don't get passed keyword args anyway it's the same computation, of course.

I was trying to find a way to specialize on "kwargs present"; that's how I ended up with that approach.
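One way to make "keyword arguments present" visible to dispatch is to collect them into a `NamedTuple` and dispatch on `NamedTuple{()}`. The sketch below does that on a hypothetical `KwSplitComp` struct (the helpers `_with_kw` and `_positional` are hypothetical too) and then evaluates every outer layer through a keyword-free helper. It illustrates the idea described above and is not the code that was eventually merged.

```julia
# Hypothetical stand-in for Base.ComposedFunction.
struct KwSplitComp{O,I}
    outer::O
    inner::I
end

# Entry point: collect keyword arguments into a NamedTuple so that
# "keyword arguments present" becomes a question of dispatch.
(c::KwSplitComp)(x...; kw...) = _with_kw(c, x, (; kw...))

# No keyword arguments: evaluate the whole chain positionally.
_with_kw(c::KwSplitComp, x::Tuple, ::NamedTuple{()}) = _positional(c, x...)
# Keyword arguments present: only the innermost function receives them;
# the outer part of the chain continues on the positional-only path.
_with_kw(c::KwSplitComp, x::Tuple, kw::NamedTuple) = _positional(c.outer, c.inner(x...; kw...))

# Positional-only evaluation: no (empty) keyword arguments are passed along.
# Each intermediate result is handed on as one argument, never splatted.
_positional(c::KwSplitComp, x...) = _positional(c.outer, _positional(c.inner, x...))
_positional(f, x...) = f(x...)

c = KwSplitComp(KwSplitComp(-, -), -)
k = KwSplitComp(abs, (x; s = 1) -> s * x)
@show c(1) k(-2; s = 3) k(-2)
```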

@martinholters
Member

Here's a breadcrumb for looking at what's happening in the compiler: We hit

newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, InferenceParams(interp).TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)

with

sig == Tuple{Base.var"##_#95", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Base.ComposedFunction{typeof(Base.:(-)), typeof(Base.:(-))}, Int64}
comparison == Tuple{Base.var"##_#95", Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where T<:Tuple{Vararg{Any, N}} where names where N where V, Base.ComposedFunction{O, I} where I where O, Vararg{Any}}

which results in

newsig == Tuple{Base.var"##_#95", Base.Pairs{K, V, I, A} where A where I where V where K, Base.ComposedFunction{typeof(Base.:(-)), typeof(Base.:(-))}, Int64}

which is used as widened signature to force convergence of the recursion.

IIUC, for a direct recursion (a method calling itself directly), we might not widen as aggressively, because a different comparison would be chosen; but with the kwarg machinery we get another indirection, leading to the observed behavior. That explains why getting rid of the kwargs before recursing helps. @vtjnash does that analysis make sense?
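The recursion being widened comes from the shape of the chain itself. This small sketch (the helper `layers` is hypothetical) shows that `∘` nests to the left, so evaluating a chain calls the `ComposedFunction` method again on its `outer` field, whose type is one layer smaller. With keyword arguments, each such call presumably also passes through the keyword body method (the `Base.var"##_#95"` in the signatures above), which would be the extra indirection described here.

```julia
# `f ∘ g ∘ h` is built as (f ∘ g) ∘ h, so in a longer chain the `outer`
# field is itself a ComposedFunction.
c3 = (-) ∘ (-) ∘ (-)
@show typeof(c3)
@show typeof(c3.outer)

# Hypothetical helper: count the ComposedFunction layers in a chain.
layers(c::ComposedFunction) = 1 + layers(c.outer) + layers(c.inner)
layers(f) = 0

@show layers(c3)                                  # 2 layers for 3 functions
@show layers((-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-))  # 5 layers for 6 functions
```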

@oschulz
Contributor Author

oschulz commented Jun 23, 2022

Until we have a compiler-level fix, would a PR like

@inline (c::ComposedFunction)(x...; kw...) = _call_composed(c, x, kw)
_call_composed(c::F, x, kw) where {F<:ComposedFunction} = _call_composed(c.outer, c.inner(x...; kw...))
_call_composed(c::F, x) where {F<:ComposedFunction} = _call_composed(c.outer, c.inner(x...))
_call_composed(c::F, x) where {F} = c(x)

be welcome? I did some more testing; I think it does work in general, not just in special cases. Benchmarks before:

julia> A = rand(1000); B = similar(A);

julia> @benchmark broadcast!(identity ∘ exp ∘ log, $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  42.845 μs … 775.911 μs  ┊ GC (min … max): 0.00% … 92.70%
 Time  (median):     49.461 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.914 μs ±  19.168 μs  ┊ GC (mean ± σ):  0.82% ±  2.55%

After:

julia> @benchmark broadcast!(identity ∘ exp ∘ log, $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  15.749 μs … 68.373 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     16.551 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   17.237 μs ±  4.346 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

With "cheap" functions, the performance difference is (no surprise) quite striking. Before:

julia> @benchmark broadcast!(Base.Fix1(*,2) ∘ Base.Fix1(*,2) ∘ Base.Fix1(*,2), $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  16.633 μs …   1.752 ms  ┊ GC (min … max): 0.00% … 98.26%
 Time  (median):     21.694 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   26.261 μs ± 48.300 μs  ┊ GC (mean ± σ):  6.68% ±  3.63%

After:

julia> @benchmark broadcast!(Base.Fix1(*,2) ∘ Base.Fix1(*,2) ∘ Base.Fix1(*,2), $B, $A)
BenchmarkTools.Trial: 10000 samples with 930 evaluations.
 Range (min … max):  110.305 ns … 609.201 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     131.642 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   130.785 ns ±  11.182 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

@N5N3
Member

N5N3 commented Jun 23, 2022

Your version will change the behavior if inner returns something like (1, 2): the docs say we should call outer((1, 2)), not outer(1, 2).
Can't we just

(c::ComposedFunction)(x...; kw...) = call_composed(c.outer, c.inner(x...; kw...))
call_composed(c::ComposedFunction, x) = call_composed(c.outer, c.inner(x))
call_composed(c, x) = c(x)

Edit: I think we'd better do this, as it would also help solve #45748 (through some code reuse).
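A self-contained version of this suggestion might look like the sketch below. It uses a hypothetical `MyComposed` struct so that nothing in `Base` is redefined. The tuple semantics mentioned above are preserved because the result of `inner` is handed to `call_composed` as a single argument and never splatted.

```julia
# Hypothetical stand-in for Base.ComposedFunction (avoids redefining Base methods).
struct MyComposed{O,I}
    outer::O
    inner::I
end

# Only the innermost call sees keyword arguments; every outer step goes
# through the keyword-free `call_composed`.
(c::MyComposed)(x...; kw...) = call_composed(c.outer, c.inner(x...; kw...))

# The intermediate result `x` is passed on as one argument, so a
# tuple-returning inner function still yields outer((a, b)).
call_composed(c::MyComposed, x) = call_composed(c.outer, c.inner(x))
call_composed(f, x) = f(x)

c = MyComposed(MyComposed(-, -), -)                # like (-) ∘ (-) ∘ (-)
p = MyComposed(t -> t isa Tuple, x -> (x, 2x))     # outer receives the tuple
@show c(1) p(3)
```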

@oschulz
Contributor Author

oschulz commented Jun 23, 2022

Your version will change the behavior if the

Ooops, yes, that wasn't intended of course.

Can't we just [...]

I like it.

@N5N3 N5N3 added the compiler:inference Type inference label Jun 24, 2022
@N5N3
Member

N5N3 commented Jun 24, 2022

On second thought, the above solution might also fail if the return type of inner becomes more and more complex, or if c.inner is itself a ComposedFunction.
If we want to ensure eager inference on nested ComposedFunctions, a trick similar to #45789 seems reasonable.

function (c::ComposedFunction)(x...; kw...)
    fs = unwrap_composed(c)
    call_composed(fs[1](x...; kw...), tail(fs)...)
end
unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...)
unwrap_composed(c) = (c,)
call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...))
call_composed(x, f) = f(x)
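Outside of `Base`, the unwrap-then-fold idea can be tried as a plain function on the real `Base.ComposedFunction` without overriding its call method. In this sketch `unwrap_chain`, `fold_chain` and `call_chain` are hypothetical names standing in for `unwrap_composed` and `call_composed`; `@inline` is placed on the method definitions rather than inside the body.

```julia
# Flatten a (possibly nested) composition into a tuple, innermost function first.
unwrap_chain(c::ComposedFunction) = (unwrap_chain(c.inner)..., unwrap_chain(c.outer)...)
unwrap_chain(f) = (f,)

# Apply the remaining functions one at a time; each receives a single argument.
@inline fold_chain(x) = x
@inline fold_chain(x, f, fs...) = fold_chain(f(x), fs...)

# Keyword arguments go to the innermost function only.
function call_chain(c, x...; kw...)
    fs = unwrap_chain(c)
    return fold_chain(first(fs)(x...; kw...), Base.tail(fs)...)
end

@show call_chain((-) ∘ (-) ∘ (-), 1)
@show call_chain(abs ∘ ((x; s = 1) -> s * x), -2; s = 3)
```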

@oschulz
Contributor Author

oschulz commented Jun 24, 2022

Good point, @N5N3 - should we start a PR for the details?

@N5N3
Member

N5N3 commented Jun 25, 2022

I have put them together into #45789. (Base._xfadjoint could fuse call_composed and the test case there could also be used to test ComposedFunction's inference)

@oschulz
Contributor Author

oschulz commented Jun 25, 2022

Neat, thanks @N5N3 !

pcjentsch pushed a commit to pcjentsch/julia that referenced this issue Aug 18, 2022
* Make `Fix1(f, Int)` inference-stable
* split `_xfadjoint` into `_xfadjoint_unwrap` and `_xfadjoint_wrap`
* Improve `(c::ComposedFunction)(x...)`'s inferability
* and fuse it in `Base._xfadjoint`.
* define a `Typeof` operator that will partly work around internal type-system bugs

Closes JuliaLang#45715
KristofferC pushed a commit that referenced this issue Aug 30, 2022
* Make `Fix1(f, Int)` inference-stable
* split `_xfadjoint` into `_xfadjoint_unwrap` and `_xfadjoint_wrap`
* Improve `(c::ComposedFunction)(x...)`'s inferability
* and fuse it in `Base._xfadjoint`.
* define a `Typeof` operator that will partly work around internal type-system bugs

Closes #45715

(cherry picked from commit d58289c)
5 participants