Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reflection: add Base.infer_return_type utility #52247

Merged
merged 1 commit into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,11 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
if contains_is(unwrap_unionall(atype).parameters, Union{})
return Union{} # don't ask: it does weird and unnecessary things, if it occurs during bootstrap
end
mi = specialize_method(method, atype, sparams)::MethodInstance
return typeinf_type(interp, specialize_method(method, atype, sparams))
end
typeinf_type(interp::AbstractInterpreter, match::MethodMatch) =
typeinf_type(interp, specialize_method(match))
function typeinf_type(interp::AbstractInterpreter, mi::MethodInstance)
start_time = ccall(:jl_typeinf_timing_begin, UInt64, ())
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
Expand Down Expand Up @@ -1120,8 +1124,7 @@ function _return_type(interp::AbstractInterpreter, t::DataType)
rt = widenconst(rt)
else
for match in _methods_by_ftype(t, -1, get_world_counter(interp))::Vector
match = match::MethodMatch
ty = typeinf_type(interp, match.method, match.spec_types, match.sparams)
ty = typeinf_type(interp, match::MethodMatch)
ty === nothing && return Any
rt = tmerge(rt, ty)
rt === Any && break
Expand Down
85 changes: 79 additions & 6 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1710,6 +1710,8 @@ check_generated_context(world::UInt) =
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")

# TODO rename `Base.return_types` to `Base.infer_return_types`

"""
Base.return_types(
f, types=default_tt(f);
Expand Down Expand Up @@ -1741,9 +1743,9 @@ julia> Base.return_types(sum, Tuple{Vector{Int}})
julia> methods(sum, (Union{Vector{Int},UnitRange{Int}},))
# 2 methods for generic function "sum" from Base:
[1] sum(r::AbstractRange{<:Real})
@ range.jl:1396
@ range.jl:1399
[2] sum(a::AbstractArray; dims, kw...)
@ reducedim.jl:996
@ reducedim.jl:1010

julia> Base.return_types(sum, (Union{Vector{Int},UnitRange{Int}},))
2-element Vector{Any}:
Expand Down Expand Up @@ -1771,13 +1773,84 @@ function return_types(@nospecialize(f), @nospecialize(types=default_tt(f));
tt = signature_type(f, types)
matches = _methods_by_ftype(tt, #=lim=#-1, world)::Vector
for match in matches
match = match::Core.MethodMatch
ty = Core.Compiler.typeinf_type(interp, match.method, match.spec_types, match.sparams)
ty = Core.Compiler.typeinf_type(interp, match::Core.MethodMatch)
push!(rts, something(ty, Any))
end
return rts
end

"""
Base.infer_return_type(
f, types=default_tt(f);
world::UInt=get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world)) -> rt::Type

Returns an inferred return type of the function call specified by `f` and `types`.

# Arguments
- `f`: The function to analyze.
- `types` (optional): The argument types of the function. Defaults to the default tuple type of `f`.
- `world` (optional): The world counter to use for the analysis. Defaults to the current world counter.
- `interp` (optional): The abstract interpreter to use for the analysis. Defaults to a new `Core.Compiler.NativeInterpreter` with the specified `world`.

# Returns
- `rt::Type`: An inferred return type of the function call specified by the given call signature.

!!! note
Note that, different from [`Base.return_types`](@ref), this doesn't give you the list
return types of every possible method matching with the given `f` and `types`.
It returns a single return type, taking into account all potential outcomes of
any function call entailed by the given signature type.

# Example

```julia
julia> checksym(::Symbol) = :symbol;

julia> checksym(x::Any) = x;

julia> Base.infer_return_type(checksym, (Union{Symbol,String},))
Union{String, Symbol}

julia> Base.return_types(checksym, (Union{Symbol,String},))
2-element Vector{Any}:
Symbol
Union{String, Symbol}
```

It's important to note the difference here: `Base.return_types` gives back inferred results
for each method that matches the given signature `checksum(::Union{Symbol,String})`.
On the other hand `Base.infer_return_type` returns one collective result that sums up all those possibilities.

!!! warning
The `Base.infer_return_type` function should not be used from generated functions;
doing so will result in an error.
"""
function infer_return_type(@nospecialize(f), @nospecialize(types=default_tt(f));
world::UInt=get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
check_generated_context(world)
if isa(f, Core.OpaqueClosure)
return last(only(code_typed_opaque_closure(f)))
end
if isa(f, Core.Builtin)
return _builtin_return_type(interp, f, types)
end
tt = signature_type(f, types)
matches = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if matches === nothing
# unanalyzable call, i.e. the interpreter world might be newer than the world where
# the `f` is defined, return the unknown return type
return Any
end
rt = Union{}
for match in matches.matches
ty = Core.Compiler.typeinf_type(interp, match::Core.MethodMatch)
rt = Core.Compiler.tmerge(rt, something(ty, Any))
end
return rt
end

"""
Base.infer_exception_types(
f, types=default_tt(f);
Expand Down Expand Up @@ -1880,7 +1953,7 @@ Returns the type of exception potentially thrown by the function call specified
!!! note
Note that, different from [`Base.infer_exception_types`](@ref), this doesn't give you the list
exception types for every possible matching method with the given `f` and `types`.
It provides a single exception type, taking into account all potential outcomes of
It returns a single exception type, taking into account all potential outcomes of
any function call entailed by the given signature type.

# Example
Expand Down Expand Up @@ -1964,7 +2037,7 @@ Returns the possible computation effects of the function call specified by `f` a
!!! note
Note that, different from [`Base.return_types`](@ref), this doesn't give you the list
effect analysis results for every possible matching method with the given `f` and `types`.
It provides a single effect, taking into account all potential outcomes of any function
It returns a single effect, taking into account all potential outcomes of any function
call entailed by the given signature type.

# Example
Expand Down
23 changes: 20 additions & 3 deletions test/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,22 @@ ambig_effects_test(a::Int, b) = 1
ambig_effects_test(a, b::Int) = 1
ambig_effects_test(a, b) = 1

@testset "infer_effects" begin
@testset "Base.infer_return_type[s]" begin
# generic function case
@test only(Base.return_types(issue41694, (Int,))) == Base.infer_return_type(issue41694, (Int,)) == Int
# case when it's not fully covered
@test only(Base.return_types(issue41694, (Integer,))) == Base.infer_return_type(issue41694, (Integer,)) == Int
# MethodError case
@test isempty(Base.return_types(issue41694, (Float64,)))
@test Base.infer_return_type(issue41694, (Float64,)) == Union{}
# builtin case
@test only(Base.return_types(typeof, (Any,))) == Base.infer_return_type(typeof, (Any,)) == DataType
@test only(Base.return_types(===, (Any,Any))) == Base.infer_return_type(===, (Any,Any)) == Bool
@test only(Base.return_types(setfield!, ())) == Base.infer_return_type(setfield!, ()) == Union{}
@test only(Base.return_types(Core.Intrinsics.mul_int, ())) == Base.infer_return_type(Core.Intrinsics.mul_int, ()) == Union{}
end

@testset "Base.infer_effects" begin
# generic functions
@test Base.infer_effects(issue41694, (Int,)) |> Core.Compiler.is_terminates
@test Base.infer_effects((Int,)) do x
Expand All @@ -1047,7 +1062,7 @@ ambig_effects_test(a, b) = 1
@test (Base.infer_effects(Core.Intrinsics.mul_int, ()); true) # `intrinsic_effects` shouldn't throw on empty `argtypes`
end

@testset "infer_exception_type[s]" begin
@testset "Base.infer_exception_type[s]" begin
# generic functions
@test Base.infer_exception_type(issue41694, (Int,)) == only(Base.infer_exception_types(issue41694, (Int,))) == ErrorException
@test Base.infer_exception_type((Int,)) do x
Expand Down Expand Up @@ -1119,7 +1134,9 @@ end
return :(x)
end
end
@test only(Base.return_types(generated_only_simple, (Real,))) == Core.Compiler.return_type(generated_only_simple, Tuple{Real}) == Any
@test only(Base.return_types(generated_only_simple, (Real,))) ==
Base.infer_return_type(generated_only_simple, (Real,)) ==
Core.Compiler.return_type(generated_only_simple, Tuple{Real}) == Any
let (src, rt) = only(code_typed(generated_only_simple, (Real,)))
@test src isa Method
@test rt == Any
Expand Down