Skip to content

Commit

Permalink
allow apply_type_tfunc to handle argtypes with Union
Browse files Browse the repository at this point in the history
This is an alternative to #56532 and can resolve #31909.
Currently `apply_type_tfunc` is unable to handle `Union`-argtypes with
any precision. With this change, `apply_type_tfunc` now performs
union-splitting on `Union`-argtypes and returns the merged result of
the splits.
While this can improve inference precision, we might need to be cautious
about potential inference time bloat.
  • Loading branch information
aviatesk committed Nov 21, 2024
1 parent 859c25a commit 77065b5
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 30 deletions.
64 changes: 42 additions & 22 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1350,14 +1350,14 @@ end
T = _fieldtype_tfunc(𝕃, oβ€², f, isconcretetype(oβ€²))
T === Bottom && return Bottom
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
end
@nospecs function replacefield!_tfunc(𝕃::AbstractLattice, o, f, x, v, success_order=Symbol, failure_order=Symbol)
oβ€² = widenconst(o)
T = _fieldtype_tfunc(𝕃, oβ€², f, isconcretetype(oβ€²))
T === Bottom && return Bottom
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
end
@nospecs function setfieldonce!_tfunc(𝕃::AbstractLattice, o, f, v, success_order=Symbol, failure_order=Symbol)
setfield!_tfunc(𝕃, o, f, v) === Bottom && return Bottom
Expand Down Expand Up @@ -1713,8 +1713,12 @@ end
const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, :_L, :_M,
:_N, :_O, :_P, :_Q, :_R, :_S, :_T, :_U, :_V, :_W, :_X, :_Y, :_Z]

# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}})
@nospecs function apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...)
function apply_type_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any};
max_union_splitting::Int=InferenceParams().max_union_splitting)
if isempty(argtypes)
return Bottom
end
headtypetype = argtypes[1]
headtypetype = widenslotwrapper(headtypetype)
if isa(headtypetype, Const)
headtype = headtypetype.val
Expand All @@ -1723,15 +1727,15 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
else
return Any
end
if !isempty(args) && isvarargtype(args[end])
largs = length(argtypes)
if largs > 1 && isvarargtype(argtypes[end])
return isvarargtype(headtype) ? TypeofVararg : Type
end
largs = length(args)
if headtype === Union
largs == 0 && return Const(Bottom)
largs == 1 && return Const(Bottom)
hasnonType = false
for i = 1:largs
ai = args[i]
for i = 2:largs
ai = argtypes[i]
if isa(ai, Const)
if !isa(ai.val, Type)
if isa(ai.val, TypeVar)
Expand All @@ -1750,14 +1754,14 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
end
end
if largs == 1 # Union{T} --> T
return tmeet(widenconst(args[1]), Union{Type,TypeVar})
if largs == 2 # Union{T} --> T
return tmeet(widenconst(argtypes[2]), Union{Type,TypeVar})
end
hasnonType && return Type
ty = Union{}
allconst = true
for i = 1:largs
ai = args[i]
for i = 2:largs
ai = argtypes[i]
if isType(ai)
aty = ai.parameters[1]
allconst &= hasuniquerep(aty)
Expand All @@ -1768,6 +1772,18 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
return allconst ? Const(ty) : Type{ty}
end
if 1 < unionsplitcost(𝕃, argtypes) ≀ max_union_splitting
βŠ” = join(𝕃)
rt = Bottom
for split_argtypes = switchtupleunion(𝕃, argtypes)
rt = rt βŠ” _apply_type_tfunc(𝕃, headtype, split_argtypes)
end
return rt
end
return _apply_type_tfunc(𝕃, headtype, argtypes)
end
@nospecs function _apply_type_tfunc(𝕃::AbstractLattice, headtype, argtypes::Vector{Any})
largs = length(argtypes)
istuple = headtype === Tuple
if !istuple && !isa(headtype, UnionAll) && !isvarargtype(headtype)
return Union{}
Expand All @@ -1781,20 +1797,20 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# first push the tailing vars from headtype into outervars
outer_start, ua = 0, headtype
while isa(ua, UnionAll)
if (outer_start += 1) > largs
if (outer_start += 1) > largs - 1
push!(outervars, ua.var)
end
ua = ua.body
end
if largs > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
if largs - 1 > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
return Bottom # too many arguments
end
outer_start = outer_start - largs + 1
outer_start = outer_start - largs + 2

varnamectr = 1
ua = headtype
for i = 1:largs
ai = widenslotwrapper(args[i])
for i = 2:largs
ai = widenslotwrapper(argtypes[i])
if isType(ai)
aip1 = ai.parameters[1]
canconst &= !has_free_typevars(aip1)
Expand Down Expand Up @@ -1868,7 +1884,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# If the names are known, keep the upper bound, but otherwise widen to Tuple.
# This is a widening heuristic to avoid keeping type information
# that's unlikely to be useful.
if !(uw.parameters[1] isa Tuple || (i == 2 && tparams[1] isa Tuple))
if !(uw.parameters[1] isa Tuple || (i == 3 && tparams[1] isa Tuple))
ub = Any
end
else
Expand Down Expand Up @@ -1910,7 +1926,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# throwing errors.
appl = headtype
if isa(appl, UnionAll)
for _ = 1:largs
for _ = 2:largs
appl = appl::UnionAll
push!(outervars, appl.var)
appl = appl.body
Expand All @@ -1930,6 +1946,8 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
return ans
end
@nospecs apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...) =
apply_type_tfunc(𝕃, pushfirst!(collect(Any, args), headtypetype))
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)

# convert the dispatch tuple type argtype to the real (concrete) type of
Expand Down Expand Up @@ -2016,15 +2034,15 @@ end
T = _memoryref_elemtype(mem)
T === Bottom && return Bottom
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
end
@nospecs function memoryrefreplace!_tfunc(𝕃::AbstractLattice, mem, x, v, success_order, failure_order, boundscheck)
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
hasintersect(widenconst(failure_order), Symbol) || return Bottom
T = _memoryref_elemtype(mem)
T === Bottom && return Bottom
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
end
@nospecs function memoryrefsetonce!_tfunc(𝕃::AbstractLattice, mem, v, success_order, failure_order, boundscheck)
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
Expand Down Expand Up @@ -2668,6 +2686,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
end
end
return current_scope_tfunc(interp, sv)
elseif f === Core.apply_type
return apply_type_tfunc(𝕃ᡒ, argtypes; max_union_splitting=InferenceParams(interp).max_union_splitting)
end
fidx = find_tfunc(f)
if fidx === nothing
Expand Down
41 changes: 33 additions & 8 deletions Compiler/test/inference.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# tests for Core.Compiler correctness and precision
import Core.Compiler: Const, Conditional, βŠ‘, ReturnNode, GotoIfNot
using Core.Compiler: Conditional, βŠ‘
isdispatchelem(@nospecialize x) = !isa(x, Type) || Core.Compiler.isdispatchelem(x)

using Random, Core.IR
Expand Down Expand Up @@ -1721,7 +1721,7 @@ g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD")
f_pure_add() = (1 + 1 == 2) ? true : "FAIL"
@test @inferred f_pure_add()

import Core: Const
using Core: Const
mutable struct ARef{T}
@atomic x::T
end
Expand Down Expand Up @@ -1762,7 +1762,7 @@ let getfield_tfunc(@nospecialize xs...) =
@test getfield_tfunc(ARef{Int},Const(:x),Bool,Bool) === Union{}
end

import Core.Compiler: Const
using Core: Const
mutable struct XY{X,Y}
x::X
y::Y
Expand Down Expand Up @@ -2767,10 +2767,10 @@ end |> only === Int

# `apply_type_tfunc` accuracy for constrained type construction
# https://github.com/JuliaLang/julia/issues/47089
import Core: Const
import Core.Compiler: apply_type_tfunc
struct Issue47089{A<:Number,B<:Number} end
let 𝕃 = Core.Compiler.fallback_lattice
let apply_type_tfunc = Core.Compiler.apply_type_tfunc
𝕃 = Core.Compiler.fallback_lattice
Const = Core.Const
A = Type{<:Integer}
@test apply_type_tfunc(𝕃, Const(Issue47089), A, A) <: (Type{Issue47089{A,B}} where {A<:Integer, B<:Integer})
@test apply_type_tfunc(𝕃, Const(Issue47089), Const(Int), Const(Int), Const(Int)) === Union{}
Expand Down Expand Up @@ -4556,7 +4556,8 @@ end |> only == Tuple{Int,Int}
end |> only == Int

# form PartialStruct for mutables with `const` field
import Core.Compiler: Const, βŠ‘
using Core: Const
using Core.Compiler: βŠ‘
mutable struct PartialMutable{S,T}
const s::S
t::T
Expand Down Expand Up @@ -5715,7 +5716,8 @@ let x = 1, _Any = Any
end

# Issue #51927
let 𝕃 = Core.Compiler.fallback_lattice
let apply_type_tfunc = Core.Compiler.apply_type_tfunc
𝕃 = Core.Compiler.fallback_lattice
@test apply_type_tfunc(𝕃, Const(Tuple{Vararg{Any,N}} where N), Int) == Type{NTuple{_A, Any}} where _A
end

Expand Down Expand Up @@ -6089,6 +6091,29 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
end
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}

# `apply_type_tfunc` with `Union` in its arguments
let apply_type_tfunc = Base.Compiler.apply_type_tfunc
𝕃 = Base.Compiler.fallback_lattice
Const = Core.Const
@test apply_type_tfunc(𝕃, Any[Const(Vector), Union{Type{Int},Type{Nothing}}]) == Union{Type{Vector{Int}},Type{Vector{Nothing}}}
end

@test Base.infer_return_type((Bool,Int,)) do b, y
x = b ? 1 : missing
inner = y -> x + y
return inner(y)
end == Union{Int,Missing}

function issue31909(ys)
x = if @noinline rand(Bool)
1
else
missing
end
map(y -> x + y, ys)
end
@test Base.infer_return_type(issue31909, (Vector{Int},)) == Union{Vector{Int},Vector{Missing}}

global setglobal!_refine::Int
@test Base.infer_return_type((Integer,)) do x
setglobal!(@__MODULE__, :setglobal!_refine, x)
Expand Down

0 comments on commit 77065b5

Please sign in to comment.