Skip to content

Commit

Permalink
REPLCompletions: replace get_type by the proper inference
Browse files Browse the repository at this point in the history
This PR generalizes the idea from #49199 and uses inference to analyze
the types of REPL expression. This approach offers several advantages
over the current `get_[value|type]`-based implementation:
- The need for various special cases is eliminated, as lowering normalizes
  expressions, and inference handles all language features.
- Constant propagation allows us to obtain accurate completions for complex
  expressions safely (see #36437).

Analysis on arbitrary REPL expressions can be done by the following steps:
- Lower a given expression
- Form a top-level `MethodInstance` from the lowered expression
- Run inference on the top-level `MethodInstance`

This PR implements `REPLInterpreter`, a custom `AbstractInterpreter` that:
- aggressively resolve global bindings to enable reasonable completions
  for lines like `Mod.a.|` (where `|` is the cursor position)
- aggressively concrete evaluates `:inconsistent` calls to provide
  reasonable completions for cases like `Ref(Some(42))[].|`
- does not optimize the inferred code, as `REPLInterpreter` is only used
  to obtain the type or constant information of given expressions

Aggressive binding resolution presents challenges for `REPLInterpreter`'s
cache validation (since #40399 hasn't been resolved yet). To avoid cache
validation issue, `REPLInterpreter` only allows aggressive binding
resolution for top-level frame representing REPL input code
(`repl_frame`) and for child `getproperty` frames that are
constant propagated from the `repl_frame`. This works, since
1.) these frames are never cached, and
2.) their results are only observed by the non-cached `repl_frame`

`REPLInterpreter` also aggressively concrete evaluate `:inconsistent`
calls within `repl_frame`, allowing it to get get accurate type
information about complex expressions that otherwise can not be constant
folded, in a safe way, i.e. it still doesn't evaluate effectful
expressions like `pop!(xs)`. Similarly to the aggressive binding
resolution, aggressive concrete evaluation doesn't present any cache
validation issues because `repl_frame` is never cached.

Also note that the code cache for `REPLInterpreter` is separated from the
native code cache, ensuring that code caches produced by `REPLInterpreter`,
where bindings are aggressively resolved and the code is not really
optimized, do not affect the native code execution. A hack has
also been added to avoid serializing `CodeInstances`s produced by
`REPLInterpreter` during precompilation to workaround #48453.

closes #36437
replaces #49199
  • Loading branch information
aviatesk committed Apr 2, 2023
1 parent a20a3d0 commit e2932cf
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 139 deletions.
291 changes: 174 additions & 117 deletions stdlib/REPL/src/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module REPLCompletions

export completions, shell_completions, bslash_completions, completion_text

using Core: CodeInfo, MethodInstance, CodeInstance, Const
const CC = Core.Compiler
using Base.Meta
using Base: propertynames, something

Expand Down Expand Up @@ -151,21 +153,21 @@ function complete_symbol(sym::String, @nospecialize(ffunc), context_module::Modu

ex = Meta.parse(lookup_name, raise=false, depwarn=false)

b, found = get_value(ex, context_module)
if found
val = b
if isa(b, Module)
mod = b
res = repl_eval_ex(ex, context_module)
res === nothing && return Completion[]
if res isa Const
val = res.val
if isa(val, Module)
mod = val
lookup_module = true
else
lookup_module = false
t = typeof(b)
t = typeof(val)
end
else # If the value is not found using get_value, the expression contain an advanced expression
else
lookup_module = false
t, found = get_type(ex, context_module)
t = CC.widenconst(res)
end
found || return Completion[]
end

suggestions = Completion[]
Expand Down Expand Up @@ -404,133 +406,182 @@ function find_start_brace(s::AbstractString; c_start='(', c_end=')')
return (startind:lastindex(s), method_name_end)
end

# Returns the value in a expression if sym is defined in current namespace fn.
# This method is used to iterate to the value of a expression like:
# :(REPL.REPLCompletions.whitespace_chars) a `dump` of this expression
# will show it consist of Expr, QuoteNode's and Symbol's which all needs to
# be handled differently to iterate down to get the value of whitespace_chars.
function get_value(sym::Expr, fn)
if sym.head === :quote || sym.head === :inert
return sym.args[1], true
end
sym.head !== :. && return (nothing, false)
for ex in sym.args
ex, found = get_value(ex, fn)::Tuple{Any, Bool}
!found && return (nothing, false)
fn, found = get_value(ex, fn)::Tuple{Any, Bool}
!found && return (nothing, false)
end
return (fn, true)
struct REPLInterpreterCache
dict::IdDict{MethodInstance,CodeInstance}
end
get_value(sym::Symbol, fn) = isdefined(fn, sym) ? (getfield(fn, sym), true) : (nothing, false)
get_value(sym::QuoteNode, fn) = (sym.value, true)
get_value(sym::GlobalRef, fn) = get_value(sym.name, sym.mod)
get_value(sym, fn) = (sym, true)

# Return the type of a getfield call expression
function get_type_getfield(ex::Expr, fn::Module)
length(ex.args) == 3 || return Any, false # should never happen, but just for safety
fld, found = get_value(ex.args[3], fn)
fld isa Symbol || return Any, false
obj = ex.args[2]
objt, found = get_type(obj, fn)
found || return Any, false
objt isa DataType || return Any, false
hasfield(objt, fld) || return Any, false
return fieldtype(objt, fld), true
REPLInterpreterCache() = REPLInterpreterCache(IdDict{MethodInstance,CodeInstance}())
const REPL_INTERPRETER_CACHE = REPLInterpreterCache()

function get_code_cache()
# XXX Avoid storing analysis results into the cache that persists across precompilation,
# as [sys|pkg]image currently doesn't support serializing externally created `CodeInstance`.
# Otherwise, `CodeInstance`s created by `REPLInterpreter``, that are much less optimized
# that those produced by `NativeInterpreter`, will leak into the native code cache,
# potentially causing runtime slowdown.
# (see https://github.com/JuliaLang/julia/issues/48453).
if (@ccall jl_generating_output()::Cint) == 1
return REPLInterpreterCache()
else
return REPL_INTERPRETER_CACHE
end
end

# Determines the return type with the Compiler of a function call using the type information of the arguments.
function get_type_call(expr::Expr, fn::Module)
f_name = expr.args[1]
f, found = get_type(f_name, fn)
found || return (Any, false) # If the function f is not found return Any.
args = Any[]
for i in 2:length(expr.args) # Find the type of the function arguments
typ, found = get_type(expr.args[i], fn)
found ? push!(args, typ) : push!(args, Any)
struct REPLInterpreter <: CC.AbstractInterpreter
repl_frame::CC.InferenceResult
world::UInt
inf_params::CC.InferenceParams
opt_params::CC.OptimizationParams
inf_cache::Vector{CC.InferenceResult}
code_cache::REPLInterpreterCache
function REPLInterpreter(repl_frame::CC.InferenceResult;
world::UInt = Base.get_world_counter(),
inf_params::CC.InferenceParams = CC.InferenceParams(),
opt_params::CC.OptimizationParams = CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
code_cache::REPLInterpreterCache = get_code_cache())
return new(repl_frame, world, inf_params, opt_params, inf_cache, code_cache)
end
world = Base.get_world_counter()
return_type = Core.Compiler.return_type(Tuple{f, args...}, world)
return (return_type, true)
end

# Returns the return type. example: get_type(:(Base.strip("", ' ')), Main) returns (SubString{String}, true)
function try_get_type(sym::Expr, fn::Module)
val, found = get_value(sym, fn)
found && return Core.Typeof(val), found
if sym.head === :call
# getfield call is special cased as the evaluation of getfield provides good type information,
# is inexpensive and it is also performed in the complete_symbol function.
a1 = sym.args[1]
if a1 === :getfield || a1 === GlobalRef(Core, :getfield)
return get_type_getfield(sym, fn)
CC.InferenceParams(interp::REPLInterpreter) = interp.inf_params
CC.OptimizationParams(interp::REPLInterpreter) = interp.opt_params
CC.get_world_counter(interp::REPLInterpreter) = interp.world
CC.get_inference_cache(interp::REPLInterpreter) = interp.inf_cache
CC.code_cache(interp::REPLInterpreter) = CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
CC.get(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default)
CC.getindex(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi)
CC.haskey(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi)
CC.setindex!(wvc::CC.WorldView{REPLInterpreterCache}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi)

# REPLInterpreter is only used for type analysis, so it should disable optimization entirely
CC.may_optimize(::REPLInterpreter) = false

# REPLInterpreter analyzes a top-level frame, so better to not bail out from it
CC.bail_out_toplevel_call(::REPLInterpreter, ::CC.InferenceLoopState, ::CC.InferenceState) = false

# `REPLInterpreter` aggressively resolves global bindings to enable reasonable completions
# for lines like `Mod.a.|` (where `|` is the cursor position).
# Aggressive binding resolution poses challenges for the inference cache validation
# (until https://github.com/JuliaLang/julia/issues/40399 is implemented).
# To avoid the cache validation issues, `REPLInterpreter` only allows aggressive binding
# resolution for top-level frame representing REPL input code (`repl_frame`) and for child
# `getproperty` frames that are constant propagated from the `repl_frame`. This works, since
# a.) these frames are never cached, and
# b.) their results are only observed by the non-cached `repl_frame`.
#
# `REPLInterpreter` also aggressively concrete evaluate `:inconsistent` calls within
# `repl_frame` to provide reasonable completions for lines like `Ref(Some(42))[].|`.
# Aggressive concrete evaluation allows us to get accurate type information about complex
# expressions that otherwise can not be constant folded, in a safe way, i.e. it still
# doesn't evaluate effectful expressions like `pop!(xs)`.
# Similarly to the aggressive binding resolution, aggressive concrete evaluation doesn't
# present any cache validation issues because `repl_frame` is never cached.

is_repl_frame(interp::REPLInterpreter, sv::CC.InferenceState) = interp.repl_frame === sv.result

# aggressive global binding resolution within `repl_frame`
function CC.abstract_eval_globalref(interp::REPLInterpreter, g::GlobalRef,
sv::CC.InferenceState)
if is_repl_frame(interp, sv)
if CC.isdefined_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
return get_type_call(sym, fn)
elseif sym.head === :thunk
thk = sym.args[1]
rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::Core.CodeInfo, fn)
rt !== Any && return (rt, true)
elseif sym.head === :ref
# some simple cases of `expand`
return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn)
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn)
elseif sym.head === :toplevel || sym.head === :block
isempty(sym.args) && return (nothing, true)
return try_get_type(sym.args[end], fn)
elseif sym.head === :escape || sym.head === :var"hygienic-scope"
return try_get_type(sym.args[1], fn)
return Union{}
end
return (Any, false)
return @invoke CC.abstract_eval_globalref(interp::CC.AbstractInterpreter, g::GlobalRef,
sv::CC.InferenceState)
end

try_get_type(other, fn::Module) = get_type(other, fn)
function is_repl_frame_getproperty(interp::REPLInterpreter, sv::CC.InferenceState)
def = sv.linfo.def
def isa Method || return false
def.name === :getproperty || return false
sv.cached && return false
return is_repl_frame(interp, sv.parent)
end

function get_type(sym::Expr, fn::Module)
# try to analyze nests of calls. if this fails, try using the expanded form.
val, found = try_get_type(sym, fn)
found && return val, found
# https://github.com/JuliaLang/julia/issues/27184
if isexpr(sym, :macrocall)
_, found = get_type(first(sym.args), fn)
found || return Any, false
end
newsym = try
macroexpand(fn, sym; recursive=false)
catch e
# user code failed in macroexpand (ignore it)
return Any, false
end
val, found = try_get_type(newsym, fn)
if !found
newsym = try
Meta.lower(fn, sym)
catch e
# user code failed in lowering (ignore it)
return Any, false
# aggressive global binding resolution for `getproperty(::Module, ::Symbol)` calls within `repl_frame`
function CC.builtin_tfunction(interp::REPLInterpreter, @nospecialize(f),
argtypes::Vector{Any}, sv::CC.InferenceState)
if f === Core.getglobal && is_repl_frame_getproperty(interp, sv)
if length(argtypes) == 2
a1, a2 = argtypes
if isa(a1, Const) && isa(a2, Const)
a1val, a2val = a1.val, a2.val
if isa(a1val, Module) && isa(a2val, Symbol)
g = GlobalRef(a1val, a2val)
if CC.isdefined_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
return Union{}
end
end
end
val, found = try_get_type(newsym, fn)
end
return val, found
return @invoke CC.builtin_tfunction(interp::CC.AbstractInterpreter, f::Any,
argtypes::Vector{Any}, sv::CC.InferenceState)
end

function get_type(sym, fn::Module)
val, found = get_value(sym, fn)
return found ? Core.Typeof(val) : Any, found
# aggressive concrete evaluation for `:inconsistent` frames within `repl_frame`
function CC.concrete_eval_eligible(interp::REPLInterpreter, @nospecialize(f),
result::CC.MethodCallResult, arginfo::CC.ArgInfo,
sv::CC.InferenceState)
if is_repl_frame(interp, sv)
neweffects = CC.Effects(result.effects; consistent=CC.ALWAYS_TRUE)
result = CC.MethodCallResult(result.rt, result.edgecycle, result.edgelimited,
result.edge, neweffects)
end
return @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, f::Any,
result::CC.MethodCallResult, arginfo::CC.ArgInfo,
sv::CC.InferenceState)
end

function resolve_toplevel_symbols!(mod::Module, src::Core.CodeInfo)
newsrc = copy(src)
@ccall jl_resolve_globals_in_ir(
#=jl_array_t *stmts=# newsrc.code::Any,
#=jl_module_t *m=# mod::Any,
#=jl_svec_t *sparam_vals=# Core.svec()::Any,
#=int binding_effects=# 0::Int)::Cvoid
return newsrc
end

function get_type(T, found::Bool, default_any::Bool)
return found ? T :
default_any ? Any : throw(ArgumentError("argument not found"))
# lower `ex` and run type inference on the resulting top-level expression
function repl_eval_ex(@nospecialize(ex), context_module::Module)
lwr = try
Meta.lower(context_module, ex)
catch # macro expansion failed, etc.
return nothing
end
if lwr isa Symbol
return isdefined(context_module, lwr) ? Const(getfield(context_module, lwr)) : nothing
end
lwr isa Expr || return Const(lwr) # `ex` is literal
isexpr(lwr, :thunk) || return nothing # lowered to `Expr(:error, ...)` or similar
src = lwr.args[1]::Core.CodeInfo

# construct top-level `MethodInstance`
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ());
mi.specTypes = Tuple{}

mi.def = context_module
src = resolve_toplevel_symbols!(context_module, src)
@atomic mi.uninferred = src

result = CC.InferenceResult(mi)
interp = REPLInterpreter(result)
frame = CC.InferenceState(result, src, #=cache=#:no, interp)::CC.InferenceState

CC.typeinf(interp, frame)

return frame.result.result
end

# Method completion on function call expression that look like :(max(1))
MAX_METHOD_COMPLETIONS::Int = 40
function _complete_methods(ex_org::Expr, context_module::Module, shift::Bool)
funct, found = get_type(ex_org.args[1], context_module)::Tuple{Any,Bool}
!found && return 2, funct, [], Set{Symbol}()

funct = repl_eval_ex(ex_org.args[1], context_module)
funct === nothing && return 2, nothing, [], Set{Symbol}()
funct = CC.widenconst(funct)
args_ex, kwargs_ex, kwargs_flag = complete_methods_args(ex_org, context_module, true, true)
return kwargs_flag, funct, args_ex, kwargs_ex
end
Expand Down Expand Up @@ -635,7 +686,14 @@ function detect_args_kwargs(funargs::Vector{Any}, context_module::Module, defaul
# argument types
push!(args_ex, Any)
else
push!(args_ex, get_type(get_type(ex, context_module)..., default_any))
argt = repl_eval_ex(ex, context_module)
if argt !== nothing
push!(args_ex, CC.widenconst(argt))
elseif default_any
push!(args_ex, Any)
else
throw(ArgumentError("argument not found"))
end
end
end
end
Expand Down Expand Up @@ -709,7 +767,6 @@ function close_path_completion(str, startpos, r, paths, pos)
return lastindex(str) <= pos || str[nextind(str, pos)] != '"'
end


function bslash_completions(string::String, pos::Int)
slashpos = something(findprev(isequal('\\'), string, pos), 0)
if (something(findprev(in(bslash_separators), string, pos), 0) < slashpos &&
Expand Down
Loading

0 comments on commit e2932cf

Please sign in to comment.