diff --git a/Compiler/src/tfuncs.jl b/Compiler/src/tfuncs.jl index 87dad13c50a30d..ff0143640c4629 100644 --- a/Compiler/src/tfuncs.jl +++ b/Compiler/src/tfuncs.jl @@ -1350,14 +1350,14 @@ end T = _fieldtype_tfunc(๐•ƒ, oโ€ฒ, f, isconcretetype(oโ€ฒ)) T === Bottom && return Bottom PT = Const(Pair) - return instanceof_tfunc(apply_type_tfunc(๐•ƒ, PT, T, T), true)[1] + return instanceof_tfunc(apply_type_tfunc(๐•ƒ, Any[PT, T, T]), true)[1] end @nospecs function replacefield!_tfunc(๐•ƒ::AbstractLattice, o, f, x, v, success_order=Symbol, failure_order=Symbol) oโ€ฒ = widenconst(o) T = _fieldtype_tfunc(๐•ƒ, oโ€ฒ, f, isconcretetype(oโ€ฒ)) T === Bottom && return Bottom PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T) - return instanceof_tfunc(apply_type_tfunc(๐•ƒ, PT, T), true)[1] + return instanceof_tfunc(apply_type_tfunc(๐•ƒ, Any[PT, T]), true)[1] end @nospecs function setfieldonce!_tfunc(๐•ƒ::AbstractLattice, o, f, v, success_order=Symbol, failure_order=Symbol) setfield!_tfunc(๐•ƒ, o, f, v) === Bottom && return Bottom @@ -1713,8 +1713,12 @@ end const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, :_L, :_M, :_N, :_O, :_P, :_Q, :_R, :_S, :_T, :_U, :_V, :_W, :_X, :_Y, :_Z] -# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}}) -@nospecs function apply_type_tfunc(๐•ƒ::AbstractLattice, headtypetype, args...) +function apply_type_tfunc(๐•ƒ::AbstractLattice, argtypes::Vector{Any}; + max_union_splitting::Int=InferenceParams().max_union_splitting) + if isempty(argtypes) + return Bottom + end + headtypetype = argtypes[1] headtypetype = widenslotwrapper(headtypetype) if isa(headtypetype, Const) headtype = headtypetype.val @@ -1723,15 +1727,15 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, else return Any end - if !isempty(args) && isvarargtype(args[end]) + largs = length(argtypes) + if largs > 1 && isvarargtype(argtypes[end]) return isvarargtype(headtype) ? TypeofVararg : Type end - largs = length(args) if headtype === Union - largs == 0 && return Const(Bottom) + largs == 1 && return Const(Bottom) hasnonType = false - for i = 1:largs - ai = args[i] + for i = 2:largs + ai = argtypes[i] if isa(ai, Const) if !isa(ai.val, Type) if isa(ai.val, TypeVar) @@ -1750,14 +1754,14 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, end end end - if largs == 1 # Union{T} --> T - return tmeet(widenconst(args[1]), Union{Type,TypeVar}) + if largs == 2 # Union{T} --> T + return tmeet(widenconst(argtypes[2]), Union{Type,TypeVar}) end hasnonType && return Type ty = Union{} allconst = true - for i = 1:largs - ai = args[i] + for i = 2:largs + ai = argtypes[i] if isType(ai) aty = ai.parameters[1] allconst &= hasuniquerep(aty) @@ -1768,6 +1772,18 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, end return allconst ? Const(ty) : Type{ty} end + if 1 < unionsplitcost(๐•ƒ, argtypes) โ‰ค max_union_splitting + โŠ” = join(๐•ƒ) + rt = Bottom + for split_argtypes = switchtupleunion(๐•ƒ, argtypes) + rt = rt โŠ” _apply_type_tfunc(๐•ƒ, headtype, split_argtypes) + end + return rt + end + return _apply_type_tfunc(๐•ƒ, headtype, argtypes) +end +@nospecs function _apply_type_tfunc(๐•ƒ::AbstractLattice, headtype, argtypes::Vector{Any}) + largs = length(argtypes) istuple = headtype === Tuple if !istuple && !isa(headtype, UnionAll) && !isvarargtype(headtype) return Union{} @@ -1781,20 +1797,20 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, # first push the tailing vars from headtype into outervars outer_start, ua = 0, headtype while isa(ua, UnionAll) - if (outer_start += 1) > largs + if (outer_start += 1) > largs - 1 push!(outervars, ua.var) end ua = ua.body end - if largs > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple + if largs - 1 > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple return Bottom # too many arguments end - outer_start = outer_start - largs + 1 + outer_start = outer_start - largs + 2 varnamectr = 1 ua = headtype - for i = 1:largs - ai = widenslotwrapper(args[i]) + for i = 2:largs + ai = widenslotwrapper(argtypes[i]) if isType(ai) aip1 = ai.parameters[1] canconst &= !has_free_typevars(aip1) @@ -1868,7 +1884,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, # If the names are known, keep the upper bound, but otherwise widen to Tuple. # This is a widening heuristic to avoid keeping type information # that's unlikely to be useful. - if !(uw.parameters[1] isa Tuple || (i == 2 && tparams[1] isa Tuple)) + if !(uw.parameters[1] isa Tuple || (i == 3 && tparams[1] isa Tuple)) ub = Any end else @@ -1910,7 +1926,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, # throwing errors. appl = headtype if isa(appl, UnionAll) - for _ = 1:largs + for _ = 2:largs appl = appl::UnionAll push!(outervars, appl.var) appl = appl.body @@ -1930,6 +1946,8 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, end return ans end +@nospecs apply_type_tfunc(๐•ƒ::AbstractLattice, headtypetype, args...) = + apply_type_tfunc(๐•ƒ, pushfirst!(collect(Any, args), headtypetype)) add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10) # convert the dispatch tuple type argtype to the real (concrete) type of @@ -2016,7 +2034,7 @@ end T = _memoryref_elemtype(mem) T === Bottom && return Bottom PT = Const(Pair) - return instanceof_tfunc(apply_type_tfunc(๐•ƒ, PT, T, T), true)[1] + return instanceof_tfunc(apply_type_tfunc(๐•ƒ, Any[PT, T, T]), true)[1] end @nospecs function memoryrefreplace!_tfunc(๐•ƒ::AbstractLattice, mem, x, v, success_order, failure_order, boundscheck) memoryrefset!_tfunc(๐•ƒ, mem, v, success_order, boundscheck) === Bottom && return Bottom @@ -2024,7 +2042,7 @@ end T = _memoryref_elemtype(mem) T === Bottom && return Bottom PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T) - return instanceof_tfunc(apply_type_tfunc(๐•ƒ, PT, T), true)[1] + return instanceof_tfunc(apply_type_tfunc(๐•ƒ, Any[PT, T]), true)[1] end @nospecs function memoryrefsetonce!_tfunc(๐•ƒ::AbstractLattice, mem, v, success_order, failure_order, boundscheck) memoryrefset!_tfunc(๐•ƒ, mem, v, success_order, boundscheck) === Bottom && return Bottom @@ -2668,6 +2686,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp end end return current_scope_tfunc(interp, sv) + elseif f === Core.apply_type + return apply_type_tfunc(๐•ƒแตข, argtypes; max_union_splitting=InferenceParams(interp).max_union_splitting) end fidx = find_tfunc(f) if fidx === nothing diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index b8c869d7375102..94e15315e4df59 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -1,7 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license # tests for Core.Compiler correctness and precision -import Core.Compiler: Const, Conditional, โŠ‘, ReturnNode, GotoIfNot +using Core.Compiler: Conditional, โŠ‘ isdispatchelem(@nospecialize x) = !isa(x, Type) || Core.Compiler.isdispatchelem(x) using Random, Core.IR @@ -1721,7 +1721,7 @@ g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD") f_pure_add() = (1 + 1 == 2) ? true : "FAIL" @test @inferred f_pure_add() -import Core: Const +using Core: Const mutable struct ARef{T} @atomic x::T end @@ -1762,7 +1762,7 @@ let getfield_tfunc(@nospecialize xs...) = @test getfield_tfunc(ARef{Int},Const(:x),Bool,Bool) === Union{} end -import Core.Compiler: Const +using Core: Const mutable struct XY{X,Y} x::X y::Y @@ -2767,10 +2767,10 @@ end |> only === Int # `apply_type_tfunc` accuracy for constrained type construction # https://github.com/JuliaLang/julia/issues/47089 -import Core: Const -import Core.Compiler: apply_type_tfunc struct Issue47089{A<:Number,B<:Number} end -let ๐•ƒ = Core.Compiler.fallback_lattice +let apply_type_tfunc = Core.Compiler.apply_type_tfunc + ๐•ƒ = Core.Compiler.fallback_lattice + Const = Core.Const A = Type{<:Integer} @test apply_type_tfunc(๐•ƒ, Const(Issue47089), A, A) <: (Type{Issue47089{A,B}} where {A<:Integer, B<:Integer}) @test apply_type_tfunc(๐•ƒ, Const(Issue47089), Const(Int), Const(Int), Const(Int)) === Union{} @@ -4556,7 +4556,8 @@ end |> only == Tuple{Int,Int} end |> only == Int # form PartialStruct for mutables with `const` field -import Core.Compiler: Const, โŠ‘ +using Core: Const +using Core.Compiler: โŠ‘ mutable struct PartialMutable{S,T} const s::S t::T @@ -5715,7 +5716,8 @@ let x = 1, _Any = Any end # Issue #51927 -let ๐•ƒ = Core.Compiler.fallback_lattice +let apply_type_tfunc = Core.Compiler.apply_type_tfunc + ๐•ƒ = Core.Compiler.fallback_lattice @test apply_type_tfunc(๐•ƒ, Const(Tuple{Vararg{Any,N}} where N), Int) == Type{NTuple{_A, Any}} where _A end @@ -6089,6 +6091,29 @@ function issue56387(nt::NamedTuple, field::Symbol=:a) end @test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int} +# `apply_type_tfunc` with `Union` in its arguments +let apply_type_tfunc = Base.Compiler.apply_type_tfunc + ๐•ƒ = Base.Compiler.fallback_lattice + Const = Core.Const + @test apply_type_tfunc(๐•ƒ, Any[Const(Vector), Union{Type{Int},Type{Nothing}}]) == Union{Type{Vector{Int}},Type{Vector{Nothing}}} +end + +@test Base.infer_return_type((Bool,Int,)) do b, y + x = b ? 1 : missing + inner = y -> x + y + return inner(y) +end == Union{Int,Missing} + +function issue31909(ys) + x = if @noinline rand(Bool) + 1 + else + missing + end + map(y -> x + y, ys) +end +@test Base.infer_return_type(issue31909, (Vector{Int},)) == Union{Vector{Int},Vector{Missing}} + global setglobal!_refine::Int @test Base.infer_return_type((Integer,)) do x setglobal!(@__MODULE__, :setglobal!_refine, x)