From 495c20b6311587a2f0f78ed09b91daa14c85a3a1 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 29 Jul 2024 17:26:46 +0900 Subject: [PATCH] inference: model partially initialized structs with `PartialStruct` There is still room for improvement in the accuracy of `getfield` and `isdefined` for structs with uninitialized fields. This commit aims to enhance the accuracy of struct field defined-ness by propagating such struct as `PartialStruct` in cases where fields that might be uninitialized are confirmed to be defined. Specifically, the improvements are made in the following situations: 1. when a `:new` expression receives arguments greater than the minimum number of initialized fields. 2. when new information about the initialized fields of `x` can be obtained in the `then` branch of `if isdefined(x, :f)`. Combined with the existing optimizations, these improvements enable DCE in scenarios such as: ```julia julia> @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing); julia> @allocated broadcast_noescape1(Ref("x")) 16 # master 0 # this PR ``` One important point to note is that, as revealed in JuliaLang/julia#48999, fields and globals can revert to `undef` during precompilation. This commit does not affect globals. Furthermore, even for fields, the refinements made by 1. and 2. are propagated along with data-flow, and field defined-ness information is only used when fields are confirmed to be initialized. Therefore, the same issues as JuliaLang/julia#48999 will not occur by this commit. --- base/compiler/abstractinterpretation.jl | 73 +++++++++------ base/compiler/ssair/passes.jl | 7 +- base/compiler/tfuncs.jl | 39 ++++++-- .../compiler/EscapeAnalysis/EscapeAnalysis.jl | 10 +-- test/compiler/inference.jl | 89 +++++++++++++++++++ 5 files changed, 171 insertions(+), 47 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 3b658534b5d35e..4314c017ee5cae 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1974,26 +1974,33 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs return Conditional(aty.slot, thentype, elsetype) end elseif f === isdefined - uty = argtypes[2] a = ssa_def_slot(fargs[2], sv) - if isa(uty, Union) && isa(a, SlotNumber) - fld = argtypes[3] - thentype = Bottom - elsetype = Bottom - for ty in uniontypes(uty) - cnd = isdefined_tfunc(๐•ƒแตข, ty, fld) - if isa(cnd, Const) - if cnd.val::Bool - thentype = thentype โŠ” ty + if isa(a, SlotNumber) + argtype2 = argtypes[2] + if isa(argtype2, Union) + fld = argtypes[3] + thentype = Bottom + elsetype = Bottom + for ty in uniontypes(argtype2) + cnd = isdefined_tfunc(๐•ƒแตข, ty, fld) + if isa(cnd, Const) + if cnd.val::Bool + thentype = thentype โŠ” ty + else + elsetype = elsetype โŠ” ty + end else + thentype = thentype โŠ” ty elsetype = elsetype โŠ” ty end - else - thentype = thentype โŠ” ty - elsetype = elsetype โŠ” ty + end + return Conditional(a, thentype, elsetype) + else + thentype = form_partially_defined_struct(argtype2, argtypes[3]) + if thentype !== nothing + return Conditional(a, thentype, argtype2) end end - return Conditional(a, thentype, elsetype) end end end @@ -2001,6 +2008,18 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs return rt end +function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) + obj isa Const && return nothing # nothing to refine + name isa Const || return nothing + objt0 = widenconst(obj) + objt = unwrap_unionall(objt0) + isabstracttype(objt) && return nothing + fldidx = try_compute_fieldidx(objt, name.val) + fldidx === nothing && return nothing + fldidx โ‰ค datatype_min_ninitialized(objt) && return nothing + return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx]) +end + function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta) na = length(argtypes) if isvarargtype(argtypes[end]) @@ -2548,20 +2567,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V end ats[i] = at end - # For now, don't allow: - # - Const/PartialStruct of mutables (but still allow PartialStruct of mutables - # with `const` fields if anything refined) - # - partially initialized Const/PartialStruct - if fcount == nargs - if consistent === ALWAYS_TRUE && allconst - argvals = Vector{Any}(undef, nargs) - for j in 1:nargs - argvals[j] = (ats[j]::Const).val - end - rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs)) - elseif anyrefine - rt = PartialStruct(rt, ats) + if fcount == nargs && consistent === ALWAYS_TRUE && allconst + argvals = Vector{Any}(undef, nargs) + for j in 1:nargs + argvals[j] = (ats[j]::Const).val end + rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs)) + elseif anyrefine || nargs > datatype_min_ninitialized(rt) + # propagate partially initialized struct as `PartialStruct` when: + # - any refinement information is available (`anyrefine`), or when + # - `nargs` is greater than `n_initialized` derived from the struct type + # information alone + rt = PartialStruct(rt, ats) end else rt = refine_partial_type(rt) @@ -3068,7 +3085,7 @@ end @nospecializeinfer function widenreturn_partials(๐•ƒแตข::PartialsLattice, @nospecialize(rt), info::BestguessInfo) if isa(rt, PartialStruct) fields = copy(rt.fields) - local anyrefine = false + anyrefine = !isvarargtype(rt.fields[end]) && length(rt.fields) > datatype_min_ninitialized(rt.typ) ๐•ƒ = typeinf_lattice(info.interp) โŠ = strictpartialorder(๐•ƒ) for i in 1:length(fields) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 33cda9bf27d202..37d79e2bd7b0cc 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback intermediaries::SPCSet end function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue)) - isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id) + if !(def isa Expr) + push!(walker_callback.intermediaries, defssa.id) + if def isa PiNode + return LiftedValue(def.val) + end + end return nothing end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 28e883d83312c4..dbcc20579ab083 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -427,6 +427,15 @@ end if !ismutabletype(a1) || isconst(a1, idx) return Const(isdefined(arg1.val, idx)) end + elseif isa(arg1, PartialStruct) + nflds = length(arg1.fields) + if !isvarargtype(arg1.fields[end]) + if 1 โ‰ค idx โ‰ค nflds + return Const(true) + elseif !ismutabletype(a1) || isconst(a1, idx) + return Const(false) + end + end elseif !isvatuple(a1) fieldT = fieldtype(a1, idx) if isa(fieldT, DataType) && isbitstype(fieldT) @@ -980,27 +989,39 @@ end โŠ‘ = partialorder(๐•ƒ) # If we have s00 being a const, we can potentially refine our type-based analysis above - if isa(s00, Const) || isconstType(s00) - if !isa(s00, Const) - sv = (s00::DataType).parameters[1] - else + if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct) + if isa(s00, Const) sv = s00.val + sty = typeof(sv) + nflds = nfields(sv) + ismod = sv isa Module + elseif isa(s00, PartialStruct) + sty = s00.typ + nflds = fieldcount_noerror(sty) + ismod = false + else + sv = (s00::DataType).parameters[1] + sty = typeof(sv) + nflds = nfields(sv) + ismod = sv isa Module end if isa(name, Const) nval = name.val if !isa(nval, Symbol) - isa(sv, Module) && return false + ismod && return false isa(nval, Int) || return false end return isdefined_tfunc(๐•ƒ, s00, name) === Const(true) end - boundscheck && return false + # If bounds checking is disabled and all fields are assigned, # we may assume that we don't throw - isa(sv, Module) && return false + @assert !boundscheck + ismod && return false name โŠ‘ Int || name โŠ‘ Symbol || return false - typeof(sv).name.n_uninitialized == 0 && return true - for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv) + sty.name.n_uninitialized == 0 && return true + nflds === nothing && return false + for i = (datatype_min_ninitialized(sty)+1):nflds isdefined_tfunc(๐•ƒ, s00, Const(i)) === Const(true) || return false end return true diff --git a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl index d8ea8be21fe07b..31c21f72280140 100644 --- a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl +++ b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl @@ -2139,21 +2139,13 @@ end # ======================== # propagate escapes imposed on call arguments -@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) -let result = code_escapes() do - broadcast_noescape1(Ref("Hi")) - end - i = only(findall(isnew, result.ir.stmts.stmt)) - @test !has_return_escape(result.state[SSAValue(i)]) - @test_broken !has_thrown_escape(result.state[SSAValue(i)]) # TODO `getfield(RefValue{String}, :x)` isn't safe -end @noinline broadcast_noescape2(b) = broadcast(identity, b) let result = code_escapes() do broadcast_noescape2(Ref("Hi")) end i = only(findall(isnew, result.ir.stmts.stmt)) @test_broken !has_return_escape(result.state[SSAValue(i)]) # TODO interprocedural alias analysis - @test_broken !has_thrown_escape(result.state[SSAValue(i)]) # TODO `getfield(RefValue{String}, :x)` isn't safe + @test !has_thrown_escape(result.state[SSAValue(i)]) end @noinline allescape_argument(a) = (global GV = a) # obvious escape let result = code_escapes() do diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 8b6da828af54de..ccd3985f3333ca 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5866,3 +5866,92 @@ end bar54341(args...) = foo54341(4, args...) @test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int + +# `PartialStruct` for partially initialized structs: +struct PartiallyInitialized1 + a; b; c + PartiallyInitialized1(a) = (@nospecialize; new(a)) + PartiallyInitialized1(a, b) = (@nospecialize; new(a, b)) + PartiallyInitialized1(a, b, c) = (@nospecialize; new(a, b, c)) +end +mutable struct PartiallyInitialized2 + a; b; c + PartiallyInitialized2(a) = (@nospecialize; new(a)) + PartiallyInitialized2(a, b) = (@nospecialize; new(a, b)) + PartiallyInitialized2(a, b, c) = (@nospecialize; new(a, b, c)) +end + +# 1. isdefined modeling for partial struct +@test Base.infer_return_type((Any,Any)) do a, b + Val(isdefined(PartiallyInitialized1(a, b), :b)) +end == Val{true} +@test Base.infer_return_type((Any,Any,)) do a, b + Val(isdefined(PartiallyInitialized1(a, b), :c)) +end == Val{false} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + Val(isdefined(PartiallyInitialized1(a, b, c), :c)) +end == Val{true} +@test Base.infer_return_type((Any,Any)) do a, b + Val(isdefined(PartiallyInitialized2(a, b), :b)) +end == Val{true} +@test Base.infer_return_type((Any,Any,)) do a, b + Val(isdefined(PartiallyInitialized2(a, b), :c)) +end >: Val{false} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + s = PartiallyInitialized2(a, b) + s.c = c + Val(isdefined(s, :c)) +end >: Val{true} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + Val(isdefined(PartiallyInitialized2(a, b, c), :c)) +end == Val{true} +@test Base.infer_return_type((Vector{Int},)) do xs + Val(isdefined(tuple(1, xs...), 1)) +end == Val{true} +@test Base.infer_return_type((Vector{Int},)) do xs + Val(isdefined(tuple(1, xs...), 2)) +end == Val + +# 2. getfield modeling for partial struct +@test Base.infer_effects((Any,Any); optimize=false) do a, b + getfield(PartiallyInitialized1(a, b), :b) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f + getfield(PartiallyInitialized1(a, b), f, #=boundscheck=#false) +end |> !Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c + getfield(PartiallyInitialized1(a, b, c), :c) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f + getfield(PartiallyInitialized1(a, b, c), f, #=boundscheck=#false) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any); optimize=false) do a, b + getfield(PartiallyInitialized2(a, b), :b) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f + getfield(PartiallyInitialized2(a, b), f, #=boundscheck=#false) +end |> !Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c + getfield(PartiallyInitialized2(a, b, c), :c) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f + getfield(PartiallyInitialized2(a, b, c), f, #=boundscheck=#false) +end |> Core.Compiler.is_nothrow + +# isdefined-Conditionals +@test Base.infer_effects((Base.RefValue{Any},)) do x + if isdefined(x, :x) + return getfield(x, :x) + end +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Base.RefValue{Any},)) do x + if isassigned(x) + return x[] + end +end |> Core.Compiler.is_nothrow + +# End to end test case for the partially initialized struct with `PartialStruct` +@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) +@test fully_eliminated() do + broadcast_noescape1(Ref("x")) +end