Skip to content

Commit dcd8170

Browse files
committed
allow apply_type_tfunc to handle argtypes with Union
This is an alternative to #56532 and can resolve #31909. Currently `apply_type_tfunc` is unable to handle `Union`-argtypes with any precision. With this change, `apply_type_tfunc` now performs union-splitting on `Union`-argtypes and returns the merged result of the splits. While this can improve inference precision, we might need to be cautious about potential inference time bloat.
1 parent 4ed8814 commit dcd8170

File tree

2 files changed

+73
-29
lines changed

2 files changed

+73
-29
lines changed

Compiler/src/tfuncs.jl

+42-22
Original file line numberDiff line numberDiff line change
@@ -1350,14 +1350,14 @@ end
13501350
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
13511351
T === Bottom && return Bottom
13521352
PT = Const(Pair)
1353-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
1353+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
13541354
end
13551355
@nospecs function replacefield!_tfunc(𝕃::AbstractLattice, o, f, x, v, success_order=Symbol, failure_order=Symbol)
13561356
o′ = widenconst(o)
13571357
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
13581358
T === Bottom && return Bottom
13591359
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
1360-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
1360+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
13611361
end
13621362
@nospecs function setfieldonce!_tfunc(𝕃::AbstractLattice, o, f, v, success_order=Symbol, failure_order=Symbol)
13631363
setfield!_tfunc(𝕃, o, f, v) === Bottom && return Bottom
@@ -1713,8 +1713,12 @@ end
17131713
const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, :_L, :_M,
17141714
:_N, :_O, :_P, :_Q, :_R, :_S, :_T, :_U, :_V, :_W, :_X, :_Y, :_Z]
17151715

1716-
# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}})
1717-
@nospecs function apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...)
1716+
function apply_type_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any};
1717+
max_union_splitting::Int=InferenceParams().max_union_splitting)
1718+
if isempty(argtypes)
1719+
return Bottom
1720+
end
1721+
headtypetype = argtypes[1]
17181722
headtypetype = widenslotwrapper(headtypetype)
17191723
if isa(headtypetype, Const)
17201724
headtype = headtypetype.val
@@ -1723,15 +1727,15 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17231727
else
17241728
return Any
17251729
end
1726-
if !isempty(args) && isvarargtype(args[end])
1730+
largs = length(argtypes)
1731+
if largs > 1 && isvarargtype(argtypes[end])
17271732
return isvarargtype(headtype) ? TypeofVararg : Type
17281733
end
1729-
largs = length(args)
17301734
if headtype === Union
1731-
largs == 0 && return Const(Bottom)
1735+
largs == 1 && return Const(Bottom)
17321736
hasnonType = false
1733-
for i = 1:largs
1734-
ai = args[i]
1737+
for i = 2:largs
1738+
ai = argtypes[i]
17351739
if isa(ai, Const)
17361740
if !isa(ai.val, Type)
17371741
if isa(ai.val, TypeVar)
@@ -1750,14 +1754,14 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17501754
end
17511755
end
17521756
end
1753-
if largs == 1 # Union{T} --> T
1754-
return tmeet(widenconst(args[1]), Union{Type,TypeVar})
1757+
if largs == 2 # Union{T} --> T
1758+
return tmeet(widenconst(argtypes[2]), Union{Type,TypeVar})
17551759
end
17561760
hasnonType && return Type
17571761
ty = Union{}
17581762
allconst = true
1759-
for i = 1:largs
1760-
ai = args[i]
1763+
for i = 2:largs
1764+
ai = argtypes[i]
17611765
if isType(ai)
17621766
aty = ai.parameters[1]
17631767
allconst &= hasuniquerep(aty)
@@ -1768,6 +1772,18 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17681772
end
17691773
return allconst ? Const(ty) : Type{ty}
17701774
end
1775+
if 1 < unionsplitcost(𝕃, argtypes) ≤ max_union_splitting
1776+
⊔ = join(𝕃)
1777+
rt = Bottom
1778+
for split_argtypes = switchtupleunion(𝕃, argtypes)
1779+
rt = rt ⊔ _apply_type_tfunc(𝕃, headtype, split_argtypes)
1780+
end
1781+
return rt
1782+
end
1783+
return _apply_type_tfunc(𝕃, headtype, argtypes)
1784+
end
1785+
@nospecs function _apply_type_tfunc(𝕃::AbstractLattice, headtype, argtypes::Vector{Any})
1786+
largs = length(argtypes)
17711787
istuple = headtype === Tuple
17721788
if !istuple && !isa(headtype, UnionAll) && !isvarargtype(headtype)
17731789
return Union{}
@@ -1781,20 +1797,20 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17811797
# first push the tailing vars from headtype into outervars
17821798
outer_start, ua = 0, headtype
17831799
while isa(ua, UnionAll)
1784-
if (outer_start += 1) > largs
1800+
if (outer_start += 1) > largs - 1
17851801
push!(outervars, ua.var)
17861802
end
17871803
ua = ua.body
17881804
end
1789-
if largs > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
1805+
if largs - 1 > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
17901806
return Bottom # too many arguments
17911807
end
1792-
outer_start = outer_start - largs + 1
1808+
outer_start = outer_start - largs + 2
17931809

17941810
varnamectr = 1
17951811
ua = headtype
1796-
for i = 1:largs
1797-
ai = widenslotwrapper(args[i])
1812+
for i = 2:largs
1813+
ai = widenslotwrapper(argtypes[i])
17981814
if isType(ai)
17991815
aip1 = ai.parameters[1]
18001816
canconst &= !has_free_typevars(aip1)
@@ -1868,7 +1884,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
18681884
# If the names are known, keep the upper bound, but otherwise widen to Tuple.
18691885
# This is a widening heuristic to avoid keeping type information
18701886
# that's unlikely to be useful.
1871-
if !(uw.parameters[1] isa Tuple || (i == 2 && tparams[1] isa Tuple))
1887+
if !(uw.parameters[1] isa Tuple || (i == 3 && tparams[1] isa Tuple))
18721888
ub = Any
18731889
end
18741890
else
@@ -1910,7 +1926,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
19101926
# throwing errors.
19111927
appl = headtype
19121928
if isa(appl, UnionAll)
1913-
for _ = 1:largs
1929+
for _ = 2:largs
19141930
appl = appl::UnionAll
19151931
push!(outervars, appl.var)
19161932
appl = appl.body
@@ -1930,6 +1946,8 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
19301946
end
19311947
return ans
19321948
end
1949+
@nospecs apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...) =
1950+
apply_type_tfunc(𝕃, pushfirst!(collect(Any, args), headtypetype))
19331951
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)
19341952

19351953
# convert the dispatch tuple type argtype to the real (concrete) type of
@@ -2016,15 +2034,15 @@ end
20162034
T = _memoryref_elemtype(mem)
20172035
T === Bottom && return Bottom
20182036
PT = Const(Pair)
2019-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
2037+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
20202038
end
20212039
@nospecs function memoryrefreplace!_tfunc(𝕃::AbstractLattice, mem, x, v, success_order, failure_order, boundscheck)
20222040
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
20232041
hasintersect(widenconst(failure_order), Symbol) || return Bottom
20242042
T = _memoryref_elemtype(mem)
20252043
T === Bottom && return Bottom
20262044
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
2027-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
2045+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
20282046
end
20292047
@nospecs function memoryrefsetonce!_tfunc(𝕃::AbstractLattice, mem, v, success_order, failure_order, boundscheck)
20302048
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
@@ -2668,6 +2686,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
26682686
end
26692687
end
26702688
return current_scope_tfunc(interp, sv)
2689+
elseif f === Core.apply_type
2690+
return apply_type_tfunc(𝕃ᵢ, argtypes; max_union_splitting=InferenceParams(interp).max_union_splitting)
26712691
end
26722692
fidx = find_tfunc(f)
26732693
if fidx === nothing

Compiler/test/inference.jl

+31-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
# tests for Core.Compiler correctness and precision
4-
import Core.Compiler: Const, Conditional, ⊑, ReturnNode, GotoIfNot
4+
using Core.Compiler: Conditional, ⊑
55
isdispatchelem(@nospecialize x) = !isa(x, Type) || Core.Compiler.isdispatchelem(x)
66

77
using Random, Core.IR
@@ -1721,7 +1721,7 @@ g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD")
17211721
f_pure_add() = (1 + 1 == 2) ? true : "FAIL"
17221722
@test @inferred f_pure_add()
17231723

1724-
import Core: Const
1724+
using Core: Const
17251725
mutable struct ARef{T}
17261726
@atomic x::T
17271727
end
@@ -1762,7 +1762,7 @@ let getfield_tfunc(@nospecialize xs...) =
17621762
@test getfield_tfunc(ARef{Int},Const(:x),Bool,Bool) === Union{}
17631763
end
17641764

1765-
import Core.Compiler: Const
1765+
using Core: Const
17661766
mutable struct XY{X,Y}
17671767
x::X
17681768
y::Y
@@ -2767,10 +2767,10 @@ end |> only === Int
27672767

27682768
# `apply_type_tfunc` accuracy for constrained type construction
27692769
# https://github.com/JuliaLang/julia/issues/47089
2770-
import Core: Const
2771-
import Core.Compiler: apply_type_tfunc
27722770
struct Issue47089{A<:Number,B<:Number} end
2773-
let 𝕃 = Core.Compiler.fallback_lattice
2771+
let apply_type_tfunc = Core.Compiler.apply_type_tfunc
2772+
𝕃 = Core.Compiler.fallback_lattice
2773+
Const = Core.Const
27742774
A = Type{<:Integer}
27752775
@test apply_type_tfunc(𝕃, Const(Issue47089), A, A) <: (Type{Issue47089{A,B}} where {A<:Integer, B<:Integer})
27762776
@test apply_type_tfunc(𝕃, Const(Issue47089), Const(Int), Const(Int), Const(Int)) === Union{}
@@ -4556,7 +4556,8 @@ end |> only == Tuple{Int,Int}
45564556
end |> only == Int
45574557

45584558
# form PartialStruct for mutables with `const` field
4559-
import Core.Compiler: Const, ⊑
4559+
using Core: Const
4560+
using Core.Compiler: ⊑
45604561
mutable struct PartialMutable{S,T}
45614562
const s::S
45624563
t::T
@@ -6088,3 +6089,26 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
60886089
types[index]
60896090
end
60906091
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}
6092+
6093+
# `apply_type_tfunc` with `Union` in its arguments
6094+
let apply_type_tfunc = Base.Compiler.apply_type_tfunc
6095+
𝕃 = Base.Compiler.fallback_lattice
6096+
Const = Core.Const
6097+
@test apply_type_tfunc(𝕃, Any[Const(Vector), Union{Type{Int},Type{Nothing}}]) == Union{Type{Vector{Int}},Type{Vector{Nothing}}}
6098+
end
6099+
6100+
@test Base.infer_return_type((Bool,Int,)) do b, y
6101+
x = b ? 1 : missing
6102+
inner = y -> x + y
6103+
return inner(y)
6104+
end == Union{Int,Missing}
6105+
6106+
function issue31909(ys)
6107+
x = if @noinline rand(Bool)
6108+
1
6109+
else
6110+
missing
6111+
end
6112+
map(y -> x + y, ys)
6113+
end
6114+
@test Base.infer_return_type(issue31909, (Vector{Int},)) == Union{Vector{Int},Vector{Missing}}

0 commit comments

Comments
 (0)