diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index e451dfc88bc07..54e1edc46afed 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -362,33 +362,49 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any} return 0 end +function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams, throw_blocks::Union{Nothing,BitSet}) + thiscost = 0 + if stmt isa Expr + thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params, + params.unoptimize_throw_blocks && line in throw_blocks)::Int + elseif stmt isa GotoNode + # loops are generally always expensive + # but assume that forward jumps are already counted for from + # summing the cost of the not-taken branch + thiscost = stmt.label < line ? 40 : 0 + elseif stmt isa GotoIfNot + thiscost = stmt.dest < line ? 40 : 0 + end + return thiscost +end + function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams, cost_threshold::Integer=params.inline_cost_threshold) bodycost::Int = 0 - if params.unoptimize_throw_blocks - throw_blocks = find_throw_blocks(body) - end + throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(body) : nothing for line = 1:length(body) stmt = body[line] - if stmt isa Expr - thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params, - params.unoptimize_throw_blocks && line in throw_blocks)::Int - elseif stmt isa GotoNode - # loops are generally always expensive - # but assume that forward jumps are already counted for from - # summing the cost of the not-taken branch - thiscost = stmt.label < line ? 40 : 0 - elseif stmt isa GotoIfNot - thiscost = stmt.dest < line ? 
40 : 0 - else - continue - end + thiscost = statement_or_branch_cost(stmt, line, src, sptypes, slottypes, params, throw_blocks) bodycost = plus_saturate(bodycost, thiscost) bodycost > cost_threshold && return false end return true end +function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::CodeInfo, sptypes::Vector{Any}, params::OptimizationParams) + throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(body) : nothing + maxcost = 0 + for line = 1:length(body) + stmt = body[line] + thiscost = statement_or_branch_cost(stmt, line, src, sptypes, src.slottypes, params, throw_blocks) + cost[line] = thiscost + if thiscost > maxcost + maxcost = thiscost + end + end + return maxcost +end + function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = empty_slottypes) if e.head !== :call return false diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index 22c1d83f50edd..861c75c7d888b 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -327,7 +327,7 @@ end Base.show(io::IO, code::IRCode) = show_ir(io, code) -lineinfo_disabled(io::IO, linestart::String, lineidx::Int32) = "" +lineinfo_disabled(io::IO, linestart::String, idx::Int) = "" function DILineInfoPrinter(linetable::Vector, showtypes::Bool=false) context = LineInfoNode[] @@ -656,6 +656,11 @@ end # Show a single statement, code.code[idx], in the context of the whole CodeInfo. # Returns the updated value of bb_idx. +# line_info_preprinter(io::IO, indent::String, idx::Int) may print relevant info +# at the beginning of the line, and should at least print `indent`. It returns a +# string that will be printed after the final basic-block annotation. 
+# line_info_postprinter(io::IO, typ, used::Bool) prints the type-annotation at the end +# of the statement function show_ir_stmt(io::IO, code::CodeInfo, idx::Int, line_info_preprinter, line_info_postprinter, used::BitSet, cfg::CFG, bb_idx::Int) ds = get(io, :displaysize, (24, 80))::Tuple{Int,Int} cols = ds[2] @@ -681,7 +686,7 @@ function show_ir_stmt(io::IO, code::CodeInfo, idx::Int, line_info_preprinter, li if bb_idx > length(cfg.blocks) # If invariants are violated, print a special leader linestart = " "^(max_bb_idx_size + 2) # not inside a basic block bracket - inlining_indent = line_info_preprinter(io, linestart, code.codelocs[idx]) + inlining_indent = line_info_preprinter(io, linestart, idx) printstyled(io, "!!! ", "─"^max_bb_idx_size, color=:light_black) else bbrange = cfg.blocks[bb_idx].stmts @@ -689,7 +694,7 @@ function show_ir_stmt(io::IO, code::CodeInfo, idx::Int, line_info_preprinter, li # Print line info update linestart = idx == first(bbrange) ? " " : sprint(io -> printstyled(io, "│ ", color=:light_black), context=io) linestart *= " "^max_bb_idx_size - inlining_indent = line_info_preprinter(io, linestart, code.codelocs[idx]) + inlining_indent = line_info_preprinter(io, linestart, idx) if idx == first(bbrange) bb_idx_str = string(bb_idx) bb_pad = max_bb_idx_size - length(bb_idx_str) @@ -733,7 +738,13 @@ function show_ir_stmt(io::IO, code::CodeInfo, idx::Int, line_info_preprinter, li return bb_idx end -function show_ir(io::IO, code::CodeInfo, line_info_preprinter=DILineInfoPrinter(code.linetable), line_info_postprinter=default_expr_type_printer) +function statementidx_lineinfo_printer(f, code::CodeInfo) + printer = f(code.linetable) + return (io::IO, indent::String, idx::Int) -> printer(io, indent, idx > 0 ? 
code.codelocs[idx] : typemin(Int32)) +end +statementidx_lineinfo_printer(code::CodeInfo) = statementidx_lineinfo_printer(DILineInfoPrinter, code) + +function show_ir(io::IO, code::CodeInfo, line_info_preprinter=statementidx_lineinfo_printer(code), line_info_postprinter=default_expr_type_printer) ioctx = IOContext(io, :displaysize => displaysize(io)::Tuple{Int,Int}) stmts = code.code used = BitSet() @@ -748,7 +759,7 @@ function show_ir(io::IO, code::CodeInfo, line_info_preprinter=DILineInfoPrinter( end max_bb_idx_size = length(string(length(cfg.blocks))) - line_info_preprinter(io, " "^(max_bb_idx_size + 2), typemin(Int32)) + line_info_preprinter(io, " "^(max_bb_idx_size + 2), 0) nothing end diff --git a/base/reflection.jl b/base/reflection.jl index 70b149f85833f..fb4603290810b 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -1132,6 +1132,45 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple), interp=Core. return rt end +""" + print_statement_costs(io::IO, f, types) + +Print type-inferred and optimized code for `f` given argument types `types`, +prepending each line with its cost as estimated by the compiler's inlining engine. +""" +function print_statement_costs(io::IO, @nospecialize(f), @nospecialize(t); kwargs...) + if isa(f, Core.Builtin) + throw(ArgumentError("argument is not a generic function")) + end + tt = signature_type(f, t) + print_statement_costs(io, tt; kwargs...) 
+end
+
+function print_statement_costs(io::IO, @nospecialize(tt::Type);
+                               world = get_world_counter(),
+                               interp = Core.Compiler.NativeInterpreter(world))
+    matches = _methods_by_ftype(tt, -1, world)
+    if matches === false
+        error("signature does not correspond to a generic function")
+    end
+    params = Core.Compiler.OptimizationParams(interp)
+    cst = Int[] # per-statement cost buffer, reused across matching methods
+    for match in matches
+        meth = func_for_method_checked(match.method, tt, match.sparams)
+        (code, ty) = Core.Compiler.typeinf_code(interp, meth, match.spec_types, match.sparams, true)
+        code === nothing && error("inference not successful") # inference disabled?
+        empty!(cst)
+        resize!(cst, length(code.code))
+        maxcost = Core.Compiler.statement_costs!(cst, code.code, code, Any[match.sparams...], params)
+        nd = ndigits(maxcost) # width of the cost column printed before each statement
+        println(io, meth)
+        IRShow.show_ir(io, code, (io, linestart, idx) -> (print(io, idx > 0 ? lpad(cst[idx], nd+1) : " "^(nd+1), " "); return ""))
+        println(io) # blank separator must go to `io` (not stdout) so captured output is complete
+    end
+end
+
+print_statement_costs(args...; kwargs...) = print_statement_costs(stdout, args...; kwargs...)
+ """ which(f, types) diff --git a/base/show.jl b/base/show.jl index 9030ca0a6404b..d7bec58638512 100644 --- a/base/show.jl +++ b/base/show.jl @@ -2169,9 +2169,9 @@ module IRShow include("compiler/ssair/show.jl") const __debuginfo = Dict{Symbol, Any}( - # :full => src -> Base.IRShow.DILineInfoPrinter(src.linetable), # and add variable slot information - :source => src -> Base.IRShow.DILineInfoPrinter(src.linetable), - # :oneliner => src -> Base.IRShow.PartialLineInfoPrinter(src.linetable), + # :full => src -> Base.IRShow.statementidx_lineinfo_printer(src), # and add variable slot information + :source => src -> Base.IRShow.statementidx_lineinfo_printer(src), + # :oneliner => src -> Base.IRShow.statementidx_lineinfo_printer(Base.IRShow.PartialLineInfoPrinter, src), :none => src -> Base.IRShow.lineinfo_disabled, ) const default_debuginfo = Ref{Symbol}(:none) diff --git a/doc/src/devdocs/inference.md b/doc/src/devdocs/inference.md index 8cf3ec8a3af78..f97bb3f1594ba 100644 --- a/doc/src/devdocs/inference.md +++ b/doc/src/devdocs/inference.md @@ -97,49 +97,22 @@ dynamic dispatch, but a mere heuristic indicating that dynamic dispatch is extremely expensive. Each statement gets analyzed for its total cost in a function called -`statement_cost`. 
You can run this yourself by following the sketch below, -where `f` is your function and `tt` is the Tuple-type of the arguments: - -```jldoctest -# A demo on `fill(3.5, (2, 3))` -f = fill -tt = Tuple{Float64, Tuple{Int,Int}} -# Create the objects we need to interact with the compiler -opt_params = Core.Compiler.OptimizationParams() -mi = Base.method_instances(f, tt)[1] -ci = code_typed(f, tt)[1][1] -interp = Core.Compiler.NativeInterpreter() -opt = Core.Compiler.OptimizationState(mi, opt_params, interp) -# Calculate cost of each statement -cost(stmt::Expr) = Core.Compiler.statement_cost(stmt, -1, ci, opt.sptypes, opt.slottypes, opt.params) -cost(stmt) = 0 -cst = map(cost, ci.code) - -# output - -31-element Vector{Int64}: - 0 - 0 - 20 - 4 - 1 - 1 - 1 - 0 - 0 - 0 - ⋮ - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 +`statement_cost`. You can display the cost associated with each statement +as follows: +```jldoctest; filter=r"tuple.jl:\d+" +julia> Base.print_statement_costs(stdout, map, (typeof(sqrt), Tuple{Int},)) # map(sqrt, (2,)) +map(f, t::Tuple{Any}) in Base at tuple.jl:169 + 0 1 ─ %1 = Base.getfield(_3, 1, true)::Int64 + 1 │ %2 = Base.sitofp(Float64, %1)::Float64 + 2 │ %3 = Base.lt_float(%2, 0.0)::Bool + 0 └── goto #3 if not %3 + 0 2 ─ Base.Math.throw_complex_domainerror(:sqrt, %2)::Union{} + 0 └── unreachable + 20 3 ─ %7 = Base.Math.sqrt_llvm(%2)::Float64 + 0 └── goto #4 + 0 4 ─ goto #5 + 0 5 ─ %10 = Core.tuple(%7)::Tuple{Float64} + 0 └── return %10 ``` -The output is a `Vector{Int}` holding the estimated cost of each -statement in `ci.code`. Note that `ci` includes the consequences of -inlining callees, and consequently the costs do too. +The line costs are in the left column. This includes the consequences of inlining and other forms of optimization. 
diff --git a/test/reflection.jl b/test/reflection.jl index 6c43d87b8bd59..ea54b833aeef0 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -42,6 +42,13 @@ end test_code_reflections(test_ir_reflection, code_lowered) test_code_reflections(test_ir_reflection, code_typed) +io = IOBuffer() +Base.print_statement_costs(io, map, (typeof(sqrt), Tuple{Int})) +str = String(take!(io)) +@test occursin("map(f, t::Tuple{Any})", str) +@test occursin("sitofp", str) +@test occursin(r"20 .*sqrt_llvm.*::Float64", str) + end # module ReflectionTest # isbits, isbitstype