Skip to content

Commit

Permalink
improve performance issue of @nospecialize-d keyword func call
Browse files Browse the repository at this point in the history
This commit tries to fix and improve the performance of calling keyword
funcs whose argument types are not fully known but are `@nospecialize`-d.

The final result would look like (this particular example is taken from
our Julia-level compiler implementation):
```julia
abstract type CallInfo end
struct NoCallInfo <: CallInfo end
# Example type whose first three fields are abstractly typed (`Any`, `Any`, `CallInfo`);
# it is the subject of the keyword-call benchmark below.
struct NewInstruction
    stmt::Any
    type::Any
    info::CallInfo
    line::Union{Int32,Nothing} # if nothing, copy the line from previous statement in the insertion location
    flag::Union{UInt8,Nothing} # if nothing, IR flags will be recomputed on insertion
    # Positional inner constructor: `stmt`, `type` and `info` are marked `@nospecialize`,
    # and all five arguments are stored into the fields as given.
    function NewInstruction(@nospecialize(stmt), @nospecialize(type), @nospecialize(info::CallInfo),
                            line::Union{Int32,Nothing}, flag::Union{UInt8,Nothing})
        return new(stmt, type, info, line, flag)
    end
end
@nospecialize
# Copy-style keyword constructor: builds a new `NewInstruction` from `newinst`.
# Every keyword defaults to the corresponding field of `newinst`, so a caller only
# passes the fields it wants to override.
function NewInstruction(newinst::NewInstruction;
    stmt=newinst.stmt,
    type=newinst.type,
    info::CallInfo=newinst.info,
    line::Union{Int32,Nothing}=newinst.line,
    flag::Union{UInt8,Nothing}=newinst.flag)
    # forward the (possibly overridden) values to the five-argument positional constructor
    return NewInstruction(stmt, type, info, line, flag)
end
@specialize

using BenchmarkTools
struct VirtualKwargs
    stmt::Any
    type::Any
    info::CallInfo
end
vkws = VirtualKwargs(nothing, Any, NoCallInfo())
newinst = NewInstruction(nothing, Any, NoCallInfo(), nothing, nothing)
runner(newinst, vkws) = NewInstruction(newinst; vkws.stmt, vkws.type, vkws.info)
@benchmark runner($newinst, $vkws)
```

> on master
```
BenchmarkTools.Trial: 10000 samples with 186 evaluations.
 Range (min … max):  559.898 ns …   4.173 μs  ┊ GC (min … max): 0.00% … 85.29%
 Time  (median):     605.608 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   638.170 ns ± 125.080 ns  ┊ GC (mean ± σ):  0.06% ±  0.85%

  █▇▂▆▄  ▁█▇▄▂                                                  ▂
  ██████▅██████▇▇▇██████▇▇▇▆▆▅▄▅▄▂▄▄▅▇▆▆▆▆▆▅▆▆▄▄▅▅▄▃▄▄▄▅▃▅▅▆▅▆▆ █
  560 ns        Histogram: log(frequency) by time       1.23 μs <

 Memory estimate: 32 bytes, allocs estimate: 2.
```

> on this commit
```julia
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  3.080 ns … 83.177 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.098 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.118 ns ±  0.885 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

       ▂▅▇█▆▅▄▂
  ▂▄▆▆▇████████▆▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▂▂▂▁▂▂▂▂▂▂▁▁▂▁▂▂▂▂▂▂▂▂▂ ▃
  3.08 ns        Histogram: frequency by time        3.19 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
```

So for this particular case it achieves a roughly 200x speedup.
This is because this commit allows the call to the keyword sorter to be
inlined and the `NamedTuple` call to be removed.

Specifically, this commit is composed of the following improvements:
- Add an early return case for `structdiff`:
  This change improves return type inference for the case where the
  compared `NamedTuple`s are type unstable but there is no difference
  in their names, e.g. given two `NamedTuple{(:a,:b),T} where T<:Tuple{Any,Any}`s.
  In such a case the optimizer will remove the `structdiff` and succeeding
  `pairs` calls, letting the keyword sorter be inlined.
- Tweak the core `NamedTuple{names}(args::Tuple)` constructor so that it
  directly forms a `:splatnew` allocation rather than redirecting to the
  general `NamedTuple` constructor, which could be confused by an abstract
  input tuple type.
- Improve `nfields_tfunc` accuracy for abstract `NamedTuple` types.
  This improvement lets `inline_splatnew` handle more abstract
  `NamedTuple`s, especially those whose names are fully known but whose
  field tuple type is abstract.

These improvements combine to allow our SROA pass to optimize away the
`NamedTuple` and `tuple` calls generated for keyword argument handling.
E.g. the IR for the example `NewInstruction` constructor is now fairly
well optimized:
```julia
julia> Base.code_ircode((NewInstruction,Any,Any,CallInfo)) do newinst, stmt, type, info
           NewInstruction(newinst; stmt, type, info)
       end |> only
2 1 ── %1  = Base.getfield(_2, :line)::Union{Nothing, Int32}                    │╻╷  Type##kw
  │    %2  = Base.getfield(_2, :flag)::Union{Nothing, UInt8}                    ││┃   getproperty
  │    %3  = (isa)(%1, Nothing)::Bool                                           ││
  │    %4  = (isa)(%2, Nothing)::Bool                                           ││
  │    %5  = (Core.Intrinsics.and_int)(%3, %4)::Bool                            ││
  └───       goto #3 if not %5                                                  ││
  2 ── %7  = %new(Main.NewInstruction, _3, _4, _5, nothing, nothing)::NewInstruction   NewInstruction
  └───       goto #10                                                           ││
  3 ── %9  = (isa)(%1, Int32)::Bool                                             ││
  │    %10 = (isa)(%2, Nothing)::Bool                                           ││
  │    %11 = (Core.Intrinsics.and_int)(%9, %10)::Bool                           ││
  └───       goto #5 if not %11                                                 ││
  4 ── %13 = π (%1, Int32)                                                      ││
  │    %14 = %new(Main.NewInstruction, _3, _4, _5, %13, nothing)::NewInstruction│││╻   NewInstruction
  └───       goto #10                                                           ││
  5 ── %16 = (isa)(%1, Nothing)::Bool                                           ││
  │    %17 = (isa)(%2, UInt8)::Bool                                             ││
  │    %18 = (Core.Intrinsics.and_int)(%16, %17)::Bool                          ││
  └───       goto #7 if not %18                                                 ││
  6 ── %20 = π (%2, UInt8)                                                      ││
  │    %21 = %new(Main.NewInstruction, _3, _4, _5, nothing, %20)::NewInstruction│││╻   NewInstruction
  └───       goto #10                                                           ││
  7 ── %23 = (isa)(%1, Int32)::Bool                                             ││
  │    %24 = (isa)(%2, UInt8)::Bool                                             ││
  │    %25 = (Core.Intrinsics.and_int)(%23, %24)::Bool                          ││
  └───       goto #9 if not %25                                                 ││
  8 ── %27 = π (%1, Int32)                                                      ││
  │    %28 = π (%2, UInt8)                                                      ││
  │    %29 = %new(Main.NewInstruction, _3, _4, _5, %27, %28)::NewInstruction    │││╻   NewInstruction
  └───       goto #10                                                           ││
  9 ──       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
  └───       unreachable                                                        ││
  10 ┄ %33 = φ (#2 => %7, #4 => %14, #6 => %21, #8 => %29)::NewInstruction      ││
  └───       goto #11                                                           ││
  11 ─       return %33                                                         │
   => NewInstruction
```
  • Loading branch information
aviatesk committed Oct 8, 2022
1 parent 4c0f8de commit a86b36b
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 10 deletions.
3 changes: 2 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,8 @@ end

NamedTuple() = NamedTuple{(),Tuple{}}(())

NamedTuple{names}(args::Tuple) where {names} = NamedTuple{names,typeof(args)}(args)
eval(Core, :(NamedTuple{names}(args::Tuple) where {names} =
$(Expr(:splatnew, :(NamedTuple{names,typeof(args)}), :args))))

using .Intrinsics: sle_int, add_int

Expand Down
6 changes: 3 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2109,16 +2109,16 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
elseif ehead === :splatnew
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
nothrow = false # TODO: More precision
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
if length(e.args) == 2 && isconcretedispatch(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
n = fieldcount(t)
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
let t = t, at = at; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
nothrow = isexact && isconcretedispatch(t)
nothrow = isexact
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
elseif isa(at, PartialStruct) && at ⊑ᵢ Tuple && n == length(at.fields::Vector{Any}) &&
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] ⊑ᵢ fieldtype(t, i), 1:n); end
nothrow = isexact && isconcretedispatch(t)
nothrow = isexact
t = PartialStruct(t, at.fields::Vector{Any})
end
end
Expand Down
12 changes: 11 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ function lift_leaves(compact::IncrementalCompact,
end
lift_arg!(compact, leaf, cache_key, def, 1+field, lifted_leaves)
continue
# NOTE we can enable this, but most `:splatnew` expressions are transformed into
# `:new` expressions by the inliner
# elseif isexpr(def, :splatnew) && length(def.args) == 2 && isa(def.args[2], AnySSAValue)
# tplssa = def.args[2]::AnySSAValue
# tplexpr = compact[tplssa][:inst]
# if is_known_call(tplexpr, tuple, compact) && 1 ≤ field < length(tplexpr.args)
# lift_arg!(compact, tplssa, cache_key, tplexpr, 1+field, lifted_leaves)
# continue
# end
# return nothing
elseif is_getfield_captures(def, compact)
# Walk to new_opaque_closure
ocleaf = def.args[2]
Expand Down Expand Up @@ -469,7 +479,7 @@ function lift_arg!(
end
end
lifted_leaves[cache_key] = LiftedValue(lifted)
nothing
return nothing
end

function walk_to_def(compact::IncrementalCompact, @nospecialize(leaf))
Expand Down
16 changes: 14 additions & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,15 @@ function nfields_tfunc(@nospecialize(x))
x = unwrap_unionall(widenconst(x))
isconstType(x) && return Const(nfields(x.parameters[1]))
if isa(x, DataType) && !isabstracttype(x)
if !(x.name === Tuple.name && isvatuple(x)) &&
!(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
if x.name === Tuple.name
isvatuple(x) && return Int
return Const(length(x.types))
elseif x.name === _NAMEDTUPLE_NAME
length(x.parameters) == 2 || return Int
names = x.parameters[1]
isa(names, Tuple{Vararg{Symbol}}) || return nfields_tfunc(x.parameters[2])
return Const(length(names))
else
return Const(isdefined(x, :types) ? length(x.types) : length(x.name.names))
end
end
Expand Down Expand Up @@ -1594,6 +1601,11 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
end
if istuple
return Type{<:appl}
elseif isa(appl, DataType) && appl.name === _NAMEDTUPLE_NAME && appl.parameters[1] === ()
# if the first parameter of `NamedTuple` is known to be empty tuple,
# the second argument should also be empty tuple type,
# so refine it here
return Const(NamedTuple{(),Tuple{}})
end
ans = Type{appl}
for i = length(outervars):-1:1
Expand Down
11 changes: 8 additions & 3 deletions base/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,22 +335,27 @@ reverse(nt::NamedTuple) = NamedTuple{reverse(keys(nt))}(reverse(values(nt)))
end

"""
structdiff(a::NamedTuple{an}, b::Union{NamedTuple{bn},Type{NamedTuple{bn}}}) where {an,bn}
structdiff(a::NamedTuple, b::Union{NamedTuple,Type{NamedTuple}})
Construct a copy of named tuple `a`, except with fields that exist in `b` removed.
`b` can be a named tuple, or a type of the form `NamedTuple{field_names}`.
"""
function structdiff(a::NamedTuple{an}, b::Union{NamedTuple{bn}, Type{NamedTuple{bn}}}) where {an, bn}
if @generated
names = diff_names(an, bn)
isempty(names) && return (;) # just a fast path
idx = Int[ fieldindex(a, names[n]) for n in 1:length(names) ]
types = Tuple{Any[ fieldtype(a, idx[n]) for n in 1:length(idx) ]...}
vals = Any[ :(getfield(a, $(idx[n]))) for n in 1:length(idx) ]
:( NamedTuple{$names,$types}(($(vals...),)) )
return :( NamedTuple{$names,$types}(($(vals...),)) )
else
names = diff_names(an, bn)
# N.B. this early return is necessary to get better type stability,
# and it also lets us avoid the cost of constructing the
# potentially type-unstable closure passed to the `map` below
isempty(names) && return (;)
types = Tuple{Any[ fieldtype(typeof(a), names[n]) for n in 1:length(names) ]...}
NamedTuple{names,types}(map(Fix1(getfield, a), names))
return NamedTuple{names,types}(map(n::Symbol->getfield(a, n), names))
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,11 @@ end
@test nfields_tfunc(Tuple{Int, Vararg{Int}}) === Int
@test nfields_tfunc(Tuple{Int, Integer}) === Const(2)
@test nfields_tfunc(Union{Tuple{Int, Float64}, Tuple{Int, Int}}) === Const(2)
@test nfields_tfunc(@NamedTuple{a::Int,b::Integer}) === Const(2)
@test nfields_tfunc(NamedTuple{(:a,:b),T} where T<:Tuple{Int,Integer}) === Const(2)
@test nfields_tfunc(NamedTuple{(:a,:b)}) === Const(2)
@test nfields_tfunc(NamedTuple{names,Tuple{Any,Any}} where names) === Const(2)
@test nfields_tfunc(Union{NamedTuple{(:a,:b)},NamedTuple{(:c,:d)}}) === Const(2)

using Core.Compiler: typeof_tfunc
@test typeof_tfunc(Tuple{Vararg{Int}}) == Type{Tuple{Vararg{Int,N}}} where N
Expand Down Expand Up @@ -2336,6 +2341,12 @@ end
# Equivalence of Const(T.instance) and T for singleton types
@test Const(nothing) ⊑ Nothing && Nothing ⊑ Const(nothing)

# `apply_type_tfunc` should always return an accurate result for the empty NamedTuple case
import Core: Const
import Core.Compiler: apply_type_tfunc
@test apply_type_tfunc(Const(NamedTuple), Const(()), Type{T} where T<:Tuple{}) === Const(typeof((;)))
@test apply_type_tfunc(Const(NamedTuple), Const(()), Type{T} where T<:Tuple) === Const(typeof((;)))

# Don't pessimize apply_type to anything worse than Type and yield Bottom for invalid Unions
@test Core.Compiler.return_type(Core.apply_type, Tuple{Type{Union}}) == Type{Union{}}
@test Core.Compiler.return_type(Core.apply_type, Tuple{Type{Union},Any}) == Type
Expand Down
44 changes: 44 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,47 @@ let interp = Core.Compiler.NativeInterpreter()
@test count(isinvoke(:*), ir.stmts.inst) == 0
@test count(iscall((ir, Core.Intrinsics.mul_int)), ir.stmts.inst) == 1
end

# inline_splatnew for abstract `NamedTuple`

@eval construct_splatnew(T, fields) = $(Expr(:splatnew, :T, :fields))
for tt = Any[(Int,Int), (Integer,Integer), (Any,Any)]
let src = code_typed1(tt) do a, b
construct_splatnew(NamedTuple{(:a,:b),typeof((a,b))}, (a,b))
end
@test count(issplatnew, src.code) == 0
@test count(isnew, src.code) == 1
end
end

# optimize away `NamedTuple`s used for handling `@nospecialize`d keyword arguments
# https://github.com/JuliaLang/julia/pull/47059
abstract type CallInfo end
struct NewInstruction
stmt::Any
type::Any
info::CallInfo
line::Int32
flag::UInt8
function NewInstruction(@nospecialize(stmt), @nospecialize(type), @nospecialize(info::CallInfo),
line::Int32, flag::UInt8)
return new(stmt, type, info, line, flag)
end
end
@nospecialize
function NewInstruction(newinst::NewInstruction;
stmt=newinst.stmt,
type=newinst.type,
info::CallInfo=newinst.info,
line::Int32=newinst.line,
flag::UInt8=newinst.flag)
return NewInstruction(stmt, type, info, line, flag)
end
@specialize
let src = code_typed1((NewInstruction,Any,Any,CallInfo)) do newinst, stmt, type, info
NewInstruction(newinst; stmt, type, info)
end
@test count(issplatnew, src.code) == 0
@test count(iscall((src,NamedTuple)), src.code) == 0
@test count(isnew, src.code) == 1
end
1 change: 1 addition & 0 deletions test/compiler/irutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code

# check if `x` is a statement with a given `head`
isnew(@nospecialize x) = isexpr(x, :new)
issplatnew(@nospecialize x) = isexpr(x, :splatnew)
isreturn(@nospecialize x) = isa(x, ReturnNode)

# check if `x` is a dynamic call of a given function
Expand Down

0 comments on commit a86b36b

Please sign in to comment.