Skip to content

Commit 1bf2ef9

Browse files
aviateskvtjnash
andauthored
allow apply_type_tfunc to handle argtypes with Union (#56617)
This is an alternative to #56532 and can resolve #31909. Currently `apply_type_tfunc` is unable to handle `Union`-argtypes with any precision. With this change, `apply_type_tfunc` now performs union-splitting on `Union`-argtypes and returns the merged result of the splits. While this can improve inference precision, we might need to be cautious about potential inference time bloat. --------- Co-authored-by: Jameson Nash <[email protected]>
1 parent e624440 commit 1bf2ef9

File tree

2 files changed

+76
-31
lines changed

2 files changed

+76
-31
lines changed

Compiler/src/tfuncs.jl

+42-22
Original file line numberDiff line numberDiff line change
@@ -1350,14 +1350,14 @@ end
13501350
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
13511351
T === Bottom && return Bottom
13521352
PT = Const(Pair)
1353-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
1353+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
13541354
end
13551355
@nospecs function replacefield!_tfunc(𝕃::AbstractLattice, o, f, x, v, success_order=Symbol, failure_order=Symbol)
13561356
o′ = widenconst(o)
13571357
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
13581358
T === Bottom && return Bottom
13591359
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
1360-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
1360+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
13611361
end
13621362
@nospecs function setfieldonce!_tfunc(𝕃::AbstractLattice, o, f, v, success_order=Symbol, failure_order=Symbol)
13631363
setfield!_tfunc(𝕃, o, f, v) === Bottom && return Bottom
@@ -1713,8 +1713,12 @@ end
17131713
const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, :_L, :_M,
17141714
:_N, :_O, :_P, :_Q, :_R, :_S, :_T, :_U, :_V, :_W, :_X, :_Y, :_Z]
17151715

1716-
# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}})
1717-
@nospecs function apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...)
1716+
function apply_type_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any};
1717+
max_union_splitting::Int=InferenceParams().max_union_splitting)
1718+
if isempty(argtypes)
1719+
return Bottom
1720+
end
1721+
headtypetype = argtypes[1]
17181722
headtypetype = widenslotwrapper(headtypetype)
17191723
if isa(headtypetype, Const)
17201724
headtype = headtypetype.val
@@ -1723,15 +1727,15 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17231727
else
17241728
return Any
17251729
end
1726-
if !isempty(args) && isvarargtype(args[end])
1730+
largs = length(argtypes)
1731+
if largs > 1 && isvarargtype(argtypes[end])
17271732
return isvarargtype(headtype) ? TypeofVararg : Type
17281733
end
1729-
largs = length(args)
17301734
if headtype === Union
1731-
largs == 0 && return Const(Bottom)
1735+
largs == 1 && return Const(Bottom)
17321736
hasnonType = false
1733-
for i = 1:largs
1734-
ai = args[i]
1737+
for i = 2:largs
1738+
ai = argtypes[i]
17351739
if isa(ai, Const)
17361740
if !isa(ai.val, Type)
17371741
if isa(ai.val, TypeVar)
@@ -1750,14 +1754,14 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17501754
end
17511755
end
17521756
end
1753-
if largs == 1 # Union{T} --> T
1754-
return tmeet(widenconst(args[1]), Union{Type,TypeVar})
1757+
if largs == 2 # Union{T} --> T
1758+
return tmeet(widenconst(argtypes[2]), Union{Type,TypeVar})
17551759
end
17561760
hasnonType && return Type
17571761
ty = Union{}
17581762
allconst = true
1759-
for i = 1:largs
1760-
ai = args[i]
1763+
for i = 2:largs
1764+
ai = argtypes[i]
17611765
if isType(ai)
17621766
aty = ai.parameters[1]
17631767
allconst &= hasuniquerep(aty)
@@ -1768,6 +1772,18 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17681772
end
17691773
return allconst ? Const(ty) : Type{ty}
17701774
end
1775+
if 1 < unionsplitcost(𝕃, argtypes) max_union_splitting
1776+
rt = Bottom
1777+
for split_argtypes = switchtupleunion(𝕃, argtypes)
1778+
this_rt = widenconst(_apply_type_tfunc(𝕃, headtype, split_argtypes))
1779+
rt = Union{rt, this_rt}
1780+
end
1781+
return rt
1782+
end
1783+
return _apply_type_tfunc(𝕃, headtype, argtypes)
1784+
end
1785+
@nospecs function _apply_type_tfunc(𝕃::AbstractLattice, headtype, argtypes::Vector{Any})
1786+
largs = length(argtypes)
17711787
istuple = headtype === Tuple
17721788
if !istuple && !isa(headtype, UnionAll) && !isvarargtype(headtype)
17731789
return Union{}
@@ -1781,20 +1797,20 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
17811797
# first push the tailing vars from headtype into outervars
17821798
outer_start, ua = 0, headtype
17831799
while isa(ua, UnionAll)
1784-
if (outer_start += 1) > largs
1800+
if (outer_start += 1) > largs - 1
17851801
push!(outervars, ua.var)
17861802
end
17871803
ua = ua.body
17881804
end
1789-
if largs > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
1805+
if largs - 1 > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
17901806
return Bottom # too many arguments
17911807
end
1792-
outer_start = outer_start - largs + 1
1808+
outer_start = outer_start - largs + 2
17931809

17941810
varnamectr = 1
17951811
ua = headtype
1796-
for i = 1:largs
1797-
ai = widenslotwrapper(args[i])
1812+
for i = 2:largs
1813+
ai = widenslotwrapper(argtypes[i])
17981814
if isType(ai)
17991815
aip1 = ai.parameters[1]
18001816
canconst &= !has_free_typevars(aip1)
@@ -1868,7 +1884,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
18681884
# If the names are known, keep the upper bound, but otherwise widen to Tuple.
18691885
# This is a widening heuristic to avoid keeping type information
18701886
# that's unlikely to be useful.
1871-
if !(uw.parameters[1] isa Tuple || (i == 2 && tparams[1] isa Tuple))
1887+
if !(uw.parameters[1] isa Tuple || (i == 3 && tparams[1] isa Tuple))
18721888
ub = Any
18731889
end
18741890
else
@@ -1910,7 +1926,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
19101926
# throwing errors.
19111927
appl = headtype
19121928
if isa(appl, UnionAll)
1913-
for _ = 1:largs
1929+
for _ = 2:largs
19141930
appl = appl::UnionAll
19151931
push!(outervars, appl.var)
19161932
appl = appl.body
@@ -1930,6 +1946,8 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
19301946
end
19311947
return ans
19321948
end
1949+
@nospecs apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...) =
1950+
apply_type_tfunc(𝕃, Any[i == 0 ? headtypetype : args[i] for i in 0:length(args)])
19331951
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)
19341952

19351953
# convert the dispatch tuple type argtype to the real (concrete) type of
@@ -2016,15 +2034,15 @@ end
20162034
T = _memoryref_elemtype(mem)
20172035
T === Bottom && return Bottom
20182036
PT = Const(Pair)
2019-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
2037+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
20202038
end
20212039
@nospecs function memoryrefreplace!_tfunc(𝕃::AbstractLattice, mem, x, v, success_order, failure_order, boundscheck)
20222040
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
20232041
hasintersect(widenconst(failure_order), Symbol) || return Bottom
20242042
T = _memoryref_elemtype(mem)
20252043
T === Bottom && return Bottom
20262044
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
2027-
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
2045+
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
20282046
end
20292047
@nospecs function memoryrefsetonce!_tfunc(𝕃::AbstractLattice, mem, v, success_order, failure_order, boundscheck)
20302048
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
@@ -2666,6 +2684,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
26662684
end
26672685
end
26682686
return current_scope_tfunc(interp, sv)
2687+
elseif f === Core.apply_type
2688+
return apply_type_tfunc(𝕃ᵢ, argtypes; max_union_splitting=InferenceParams(interp).max_union_splitting)
26692689
end
26702690
fidx = find_tfunc(f)
26712691
if fidx === nothing

Compiler/test/inference.jl

+34-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
include("irutils.jl")
44

55
# tests for Compiler correctness and precision
6-
import .Compiler: Const, Conditional, , ReturnNode, GotoIfNot
6+
using .Compiler: Conditional,
77
isdispatchelem(@nospecialize x) = !isa(x, Type) || Compiler.isdispatchelem(x)
88

99
using Random, Core.IR
@@ -1721,7 +1721,7 @@ g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD")
17211721
f_pure_add() = (1 + 1 == 2) ? true : "FAIL"
17221722
@test @inferred f_pure_add()
17231723

1724-
import Core: Const
1724+
using Core: Const
17251725
mutable struct ARef{T}
17261726
@atomic x::T
17271727
end
@@ -1762,7 +1762,7 @@ let getfield_tfunc(@nospecialize xs...) =
17621762
@test getfield_tfunc(ARef{Int},Const(:x),Bool,Bool) === Union{}
17631763
end
17641764

1765-
import .Compiler: Const
1765+
using Core: Const
17661766
mutable struct XY{X,Y}
17671767
x::X
17681768
y::Y
@@ -2765,10 +2765,10 @@ end |> only === Int
27652765

27662766
# `apply_type_tfunc` accuracy for constrained type construction
27672767
# https://github.com/JuliaLang/julia/issues/47089
2768-
import Core: Const
2769-
import .Compiler: apply_type_tfunc
27702768
struct Issue47089{A<:Number,B<:Number} end
2771-
let 𝕃 = Compiler.fallback_lattice
2769+
let apply_type_tfunc = Compiler.apply_type_tfunc
2770+
𝕃 = Compiler.fallback_lattice
2771+
Const = Core.Const
27722772
A = Type{<:Integer}
27732773
@test apply_type_tfunc(𝕃, Const(Issue47089), A, A) <: (Type{Issue47089{A,B}} where {A<:Integer, B<:Integer})
27742774
@test apply_type_tfunc(𝕃, Const(Issue47089), Const(Int), Const(Int), Const(Int)) === Union{}
@@ -4554,7 +4554,8 @@ end |> only == Tuple{Int,Int}
45544554
end |> only == Int
45554555

45564556
# form PartialStruct for mutables with `const` field
4557-
import .Compiler: Const,
4557+
using Core: Const
4558+
using .Compiler:
45584559
mutable struct PartialMutable{S,T}
45594560
const s::S
45604561
t::T
@@ -5700,7 +5701,8 @@ let x = 1, _Any = Any
57005701
end
57015702

57025703
# Issue #51927
5703-
let 𝕃 = Compiler.fallback_lattice
5704+
let apply_type_tfunc = Compiler.apply_type_tfunc
5705+
𝕃 = Compiler.fallback_lattice
57045706
@test apply_type_tfunc(𝕃, Const(Tuple{Vararg{Any,N}} where N), Int) == Type{NTuple{_A, Any}} where _A
57055707
end
57065708

@@ -6074,6 +6076,29 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
60746076
end
60756077
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}
60766078

6079+
# `apply_type_tfunc` with `Union` in its arguments
6080+
let apply_type_tfunc = Compiler.apply_type_tfunc
6081+
𝕃 = Compiler.fallback_lattice
6082+
Const = Core.Const
6083+
@test apply_type_tfunc(𝕃, Any[Const(Vector), Union{Type{Int},Type{Nothing}}]) == Union{Type{Vector{Int}},Type{Vector{Nothing}}}
6084+
end
6085+
6086+
@test Base.infer_return_type((Bool,Int,)) do b, y
6087+
x = b ? 1 : missing
6088+
inner = y -> x + y
6089+
return inner(y)
6090+
end == Union{Int,Missing}
6091+
6092+
function issue31909(ys)
6093+
x = if @noinline rand(Bool)
6094+
1
6095+
else
6096+
missing
6097+
end
6098+
map(y -> x + y, ys)
6099+
end
6100+
@test Base.infer_return_type(issue31909, (Vector{Int},)) == Union{Vector{Int},Vector{Missing}}
6101+
60776102
global setglobal!_refine::Int
60786103
@test Base.infer_return_type((Integer,)) do x
60796104
setglobal!(@__MODULE__, :setglobal!_refine, x)
@@ -6098,4 +6123,4 @@ function func_swapglobal!_must_throw(x)
60986123
swapglobal!(@__MODULE__, :swapglobal!_must_throw, x)
60996124
end
61006125
@test Base.infer_return_type(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) === Union{}
6101-
@test !Base.Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )
6126+
@test !Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )

0 commit comments

Comments
 (0)