ComposedFunction type inference regression from v1.6 LTS #45715
The reason seems to be that the definition changed from `(c::ComposedFunction)(x...) = c.outer(c.inner(x...))` in Julia v1.6 to `(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))` in v1.7. This apparently kills type inference as soon as compositions are nested.
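To make the difference between the two definitions concrete, here is a minimal sketch using two hypothetical stand-in types, `Composed16` and `Composed17` (not part of Base), that copy the v1.6-style and v1.7-style call definitions quoted above. Both give the same values for plain calls; the v1.7-style definition additionally forwards keyword arguments to the innermost function. The sketch only illustrates the call structure and does not try to reproduce the inference problem itself.

```julia
# Hypothetical stand-ins for Base.ComposedFunction, used only to contrast
# the v1.6-style and v1.7-style call definitions.
struct Composed16{O,I}
    outer::O
    inner::I
end

struct Composed17{O,I}
    outer::O
    inner::I
end

# v1.6 style: positional arguments only.
(c::Composed16)(x...) = c.outer(c.inner(x...))

# v1.7 style: keyword arguments are forwarded to the inner function,
# so every nested layer is also called through the keyword-argument path.
(c::Composed17)(x...; kw...) = c.outer(c.inner(x...; kw...))

# Nested compositions: -(-(-(x))) in both styles.
c16 = Composed16(-, Composed16(-, -))
c17 = Composed17(-, Composed17(-, -))
println(c16(3), " ", c17(3))

# Only the v1.7 style can pass keywords through to the innermost function:
# string(round(2.567; digits = 1))
c17kw = Composed17(string, round)
println(c17kw(2.567; digits = 1))
```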
Something like this seems to fix it:

```julia
@inline (c::ComposedFunction)(x...; kw...) = call_composed(c, x, kw)
call_composed(c::F, x, kw) where {F<:ComposedFunction} = call_composed(c.outer, c.inner(x...; kw...))
call_composed(c::F, x) where {F<:ComposedFunction} = call_composed(c.outer, c.inner(x...))
call_composed(c::F, x) where {F} = c(x)
```

```julia
julia> VERSION
v"1.8.0-rc1"

julia> using Test

julia> @inferred ((-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-))(0)
0
```

Would a change like that be acceptable?
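For reference, `Test.@inferred` (used in the REPL session above) evaluates the call, asks the compiler for the return type it infers from the argument types alone, and throws an error if that inferred type does not match the type of the actual result. A minimal sketch with two hypothetical functions, `stable` and `unstable`:

```julia
using Test

stable(x) = x + 1               # return type depends only on the argument type
unstable(x) = x > 0 ? 1 : 1.0   # may return Int or Float64 for the same argument type

# Inferred Int64 matches the runtime Int64, so the value is returned.
println(@inferred stable(1))

# Inferred Union{Float64, Int64} does not match the runtime Int64, so this throws.
try
    @inferred unstable(1)
catch err
    println(err)  # an ErrorException describing the mismatch
end
```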
Looks reasonable to me.
I don't fully see why that helps. It looks like you just rearranged some of the symbols until you could no longer reproduce the problem. I think we may already have an open issue about this recursion pattern, though? It's not exactly clear why this one is failing currently.
A fix on the compiler side would be ideal, I guess, but that is beyond me. :-) Speaking from a user's point of view, I have code that uses composition chains for time-critical tasks. I know Julia gives no guarantees on type inference, but in practice this is a real problem in such situations, because one may have to redesign the approach (especially when things like Enzyme and/or GPUs come in). So, pragmatically, I'll take any acceptable fix. ;-)
Well, there was some method to my rearranging, though I definitely don't know exactly what happens on the compiler side. What I was going for with that approach is getting the keyword arguments out of the outer function calls, so that …

So I'm effectively replacing … I was trying to find a way to specialize on "kwargs present"; that's how I ended up with that approach.
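One way to "specialize on kwargs present", sketched here on a hypothetical `Comp` type rather than on `Base.ComposedFunction`, is to turn the keyword arguments into a `NamedTuple` with `values(kw)` and dispatch on the empty `NamedTuple{()}`. This only illustrates the dispatch idea; it is not the code proposed in this thread and makes no claim about how the compiler infers it.

```julia
# Hypothetical composition type, used only to sketch dispatching on
# whether any keyword arguments were passed.
struct Comp{O,I}
    outer::O
    inner::I
end

# Entry point: hand the keyword arguments over as a NamedTuple so we can dispatch on it.
(c::Comp)(x...; kw...) = call_comp(c, x, values(kw))

# No keywords: the chain is called with positional arguments only.
call_comp(c::Comp, x::Tuple, ::NamedTuple{()}) = c.outer(c.inner(x...))

# Keywords present: forward them to the inner function only.
call_comp(c::Comp, x::Tuple, kw::NamedTuple) = c.outer(c.inner(x...; kw...))

c = Comp(string, round)
println(c(2.567))              # string(round(2.567))
println(c(2.567; digits = 1))  # string(round(2.567; digits = 1))
```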
Here's a breadcrumb for looking at what's happening in the compiler: we hit julia/base/compiler/abstractinterpretation.jl line 559 (at commit 4c39647) with

```julia
sig == Tuple{Base.var"##_#95", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Base.ComposedFunction{typeof(Base.:(-)), typeof(Base.:(-))}, Int64}
comparison == Tuple{Base.var"##_#95", Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where T<:Tuple{Vararg{Any, N}} where names where N where V, Base.ComposedFunction{O, I} where I where O, Vararg{Any}}
```

which results in

```julia
newsig == Tuple{Base.var"##_#95", Base.Pairs{K, V, I, A} where A where I where V where K, Base.ComposedFunction{typeof(Base.:(-)), typeof(Base.:(-))}, Int64}
```

which is used as the widened signature to force convergence of the recursion. IIUC, for a direct recursion (a method calling itself directly), we might not widen as aggressively by choosing a different …
Until we have a compiler-level fix, would a PR like

```julia
@inline (c::ComposedFunction)(x...; kw...) = _call_composed(c, x, kw)
_call_composed(c::F, x, kw) where {F<:ComposedFunction} = _call_composed(c.outer, c.inner(x...; kw...))
_call_composed(c::F, x) where {F<:ComposedFunction} = _call_composed(c.outer, c.inner(x...))
_call_composed(c::F, x) where {F} = c(x)
```

be welcome? I did some more testing, and I think it does work in general, not just in special cases. Benchmarks, before:

```julia
julia> A = rand(1000); B = similar(A);

julia> @benchmark broadcast!(identity ∘ exp ∘ log, $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  42.845 μs … 775.911 μs  ┊ GC (min … max): 0.00% … 92.70%
 Time  (median):     49.461 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.914 μs ±  19.168 μs  ┊ GC (mean ± σ):  0.82% ±  2.55%
```

After:

```julia
julia> @benchmark broadcast!(identity ∘ exp ∘ log, $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  15.749 μs … 68.373 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     16.551 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   17.237 μs ±  4.346 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
```

With "cheap" functions, the performance difference is (no surprise) quite striking. Before:

```julia
julia> @benchmark broadcast!(Base.Fix1(*,2) ∘ Base.Fix1(*,2) ∘ Base.Fix1(*,2), $B, $A)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  16.633 μs … 1.752 ms   ┊ GC (min … max): 0.00% … 98.26%
 Time  (median):     21.694 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   26.261 μs ± 48.300 μs  ┊ GC (mean ± σ):  6.68% ± 3.63%
```

After:

```julia
julia> @benchmark broadcast!(Base.Fix1(*,2) ∘ Base.Fix1(*,2) ∘ Base.Fix1(*,2), $B, $A)
BenchmarkTools.Trial: 10000 samples with 930 evaluations.
 Range (min … max):  110.305 ns … 609.201 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     131.642 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   130.785 ns ±  11.182 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%
```
Your version will change the behavior if the …

```julia
(c::ComposedFunction)(x...; kw...) = call_composed(c.outer, c.inner(x...; kw...))
call_composed(c::ComposedFunction, x) = call_composed(c.outer, c.inner(x))
call_composed(c, x) = c(x)
```
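To illustrate one way the two versions differ, here is a sketch of both helper styles under hypothetical names (`apply_splat`/`call_splat` mirroring the earlier proposal, `apply_plain`/`call_plain` mirroring the revised one), with keyword arguments left out for brevity. The splatting variant spreads an intermediate result into the next function, so a middle function that receives a tuple ends up being called with several arguments instead of one.

```julia
# Mirrors the earlier proposal: after the first step, the intermediate
# value `x` is splatted into the next inner function.
apply_splat(c::ComposedFunction, x...) = call_splat(c.outer, c.inner(x...))
call_splat(c::ComposedFunction, x) = call_splat(c.outer, c.inner(x...))
call_splat(f, x) = f(x)

# Mirrors the revised version: the intermediate value is passed on as a
# single argument, matching f(g(h(x))).
apply_plain(c::ComposedFunction, x...) = call_plain(c.outer, c.inner(x...))
call_plain(c::ComposedFunction, x) = call_plain(c.outer, c.inner(x))
call_plain(f, x) = f(x)

twice(x) = (x, x)  # hypothetical inner function that returns a tuple

# Built explicitly as (length ∘ identity) ∘ twice.
c = ComposedFunction(ComposedFunction(length, identity), twice)

println(c(3))               # Base: length(identity((3, 3)))
println(apply_plain(c, 3))  # same as Base
# apply_splat(c, 3) would call identity(3, 3) and throw a MethodError.
```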
Oops, yes, that wasn't intended, of course.
I like it.
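On second thought, the above solution might also fail if the returned type of …

```julia
# Proposed for Base, where `tail` refers to `Base.tail`.
function (c::ComposedFunction)(x...; kw...)
    # Flatten the nested composition into a tuple of functions in call order (innermost first).
    fs = unwrap_composed(c)
    # Only the innermost function receives the original arguments and keywords.
    call_composed(fs[1](x...; kw...), tail(fs)...)
end
unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...)
unwrap_composed(c) = (c,)
# Thread the intermediate value through the remaining functions, one argument at a time.
call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...))
call_composed(x, f) = f(x)
```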
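A standalone sketch of the flattening idea, using hypothetical names (`flatten_composed`, `thread`, `call_flat`) so it can run outside Base without redefining `ComposedFunction`'s call method. It calls `Base.tail` explicitly and omits the `@inline` annotation; it is an illustration, not the code that was merged.

```julia
# Flatten a nested composition into a tuple of functions in call order.
flatten_composed(c::ComposedFunction) = (flatten_composed(c.inner)..., flatten_composed(c.outer)...)
flatten_composed(f) = (f,)

# Thread a value through the functions; the first one is applied first.
thread(x, f) = f(x)
thread(x, f, fs...) = thread(f(x), fs...)

function call_flat(c::ComposedFunction, x...; kw...)
    fs = flatten_composed(c)
    # Only the innermost function sees the original arguments and keywords;
    # every later function receives the previous result as a single argument.
    return thread(first(fs)(x...; kw...), Base.tail(fs)...)
end

c = ComposedFunction(ComposedFunction(string, abs), -)  # (string ∘ abs) ∘ (-)
println(flatten_composed(c))  # (-, abs, string)
println(call_flat(c, 4))      # string(abs(-(4)))
```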
Good point, @N5N3 - should we start a PR for the details? |
I have put them together into #45789.
Neat, thanks @N5N3 ! |
* Make `Fix1(f, Int)` inference-stable
* split `_xfadjoint` into `_xfadjoint_unwrap` and `_xfadjoint_wrap`
* Improve `(c::ComposedFunction)(x...)`'s inferability
* and fuse it in `Base._xfadjoint`.
* define a `Typeof` operator that will partly work around internal type-system bugs

Closes JuliaLang#45715
(cherry picked from commit d58289c)
Julia v1.6 type inference has no trouble with this:
Julia v1.7 and v1.8-rc1 can't type-infer even a three-element composition anymore:
This makes function composition very unattractive in performance-critical code.
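A minimal sketch for checking this kind of regression on your own Julia version, assuming it is acceptable to use the internal reflection function `Base.return_types`. The helper names `composed_negations` and `infers_concretely` are hypothetical, and the printed inference results depend on the Julia version, so none are shown here.

```julia
# Build a composition of n negations, e.g. n = 3 gives (-) ∘ (-) ∘ (-).
composed_negations(n) = reduce(∘, ntuple(_ -> (-), n))

# Ask the compiler for the inferred return type(s) of calling `f` with one
# argument of type T, and report whether they are all concrete.
function infers_concretely(f, ::Type{T}) where {T}
    rts = Base.return_types(f, (T,))
    return !isempty(rts) && all(isconcretetype, rts)
end

for n in 2:6
    c = composed_negations(n)
    println(n, " functions: value = ", c(1), ", concretely inferred = ", infers_concretely(c, Int))
end
```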