Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,213 @@ precompile_test_harness("Inference caching") do load_path
@test ExampleCompiler.emit_code_count[] == 0 broken=VERSION>=v"1.12.0-DEV.1268"
end
end

# Test Base.precompile(ci::CodeInstance) for persisting non-native external CIs
@static if hasmethod(Base.precompile, Tuple{Core.CodeInstance})
precompile_test_harness("Base.precompile(ci::CodeInstance) external CI caching") do load_path
write(joinpath(load_path, "ExampleCompilerPC.jl"), :(module ExampleCompilerPC
using CompilerCaching

const CC = Core.Compiler

mutable struct CompileResults
ir::Any
code::Any
executable::Any
CompileResults() = new(nothing, nothing, nothing)
end

struct ExampleInterpreter <: CC.AbstractInterpreter
world::UInt
cache::CacheView
inf_cache::Vector{CC.InferenceResult}
end
ExampleInterpreter(cache::CacheView) =
ExampleInterpreter(cache.world, cache, CC.InferenceResult[])

CC.InferenceParams(::ExampleInterpreter) = CC.InferenceParams()
CC.OptimizationParams(::ExampleInterpreter) = CC.OptimizationParams()
CC.get_inference_cache(interp::ExampleInterpreter) = interp.inf_cache
@static if isdefined(Core.Compiler, :get_inference_world)
Core.Compiler.get_inference_world(interp::ExampleInterpreter) = interp.world
else
Core.Compiler.get_world_counter(interp::ExampleInterpreter) = interp.world
end
CC.lock_mi_inference(::ExampleInterpreter, ::Core.MethodInstance) = nothing
CC.unlock_mi_inference(::ExampleInterpreter, ::Core.MethodInstance) = nothing
@setup_caching ExampleInterpreter.cache

emit_code_count = Ref(0)

function emit_ir(cache, mi)
interp = ExampleInterpreter(cache)
typeinf!(cache, interp, mi)
end

function emit_code(cache, mi, ir)
emit_code_count[] += 1
:code_result
end

emit_executable(cache, mi, code) = code

function precompile(f, tt)
world = Base.get_world_counter()
mi = method_instance(f, tt; world)
cache = CacheView{CompileResults}(:ExampleCompilerPC, world)

ci = get(cache, mi, nothing)
if ci !== nothing
res = results(cache, ci)
if res.executable !== nothing
return res.executable
end
end

ir = emit_ir(cache, mi)

ci = get(cache, mi, nothing)
@assert ci !== nothing "typeinf! should have created CI"
res = results(cache, ci)
res.ir = ir

res.code = emit_code(cache, mi, res.ir)
res.executable = emit_executable(cache, mi, res.code)

# Register CI for precompilation caching so non-native external CIs
# survive serialization (bypasses the owner filter in queue_external_cis)
Base.precompile(ci)

@assert res.executable === :code_result
return res.executable
end

end # module
) |> string)
Base.compilecache(Base.PkgId("ExampleCompilerPC"), stderr, stdout)

write(joinpath(load_path, "ExampleUserPC.jl"), :(module ExampleUserPC
import ExampleCompilerPC

function square(x)
return x * x
end

# Internal method
ExampleCompilerPC.precompile(square, (Float64,))

# External method (identity from Base)
ExampleCompilerPC.precompile(identity, (Int,))

end # module
) |> string)
Base.compilecache(Base.PkgId("ExampleUserPC"), stderr, stdout)

@eval let
using CompilerCaching
import ExampleCompilerPC
@test ExampleCompilerPC.emit_code_count[] == 0

cache = CacheView{ExampleCompilerPC.CompileResults}(:ExampleCompilerPC, Base.get_world_counter())

identity_mi = method_instance(identity, (Int,))
@test !haskey(cache, identity_mi)

using ExampleUserPC
@test ExampleCompilerPC.emit_code_count[] == 0

cache = CacheView{ExampleCompilerPC.CompileResults}(:ExampleCompilerPC, Base.get_world_counter())

# Internal method CI survived (baseline)
square_mi = method_instance(ExampleUserPC.square, (Float64,))
@test haskey(cache, square_mi)

# External method CI survived (enabled by Base.precompile(ci))
@test haskey(cache, identity_mi)

# No recompilation needed
ExampleCompilerPC.precompile(ExampleUserPC.square, (Float64,))
@test ExampleCompilerPC.emit_code_count[] == 0
ExampleCompilerPC.precompile(identity, (Int,))
@test ExampleCompilerPC.emit_code_count[] == 0

# Results integrity
ci = get(cache, identity_mi)
res = results(cache, ci)
@test res.executable === :code_result
end
end
# Test foreign IR precompilation (non-CodeInfo source in method.source)
precompile_test_harness("Foreign IR precompilation") do load_path
write(joinpath(load_path, "ForeignIRPkg.jl"), :(module ForeignIRPkg
using CompilerCaching

const CC = Core.Compiler

Base.Experimental.@MethodTable foreign_mt

mutable struct ForeignResults
val::Any
ForeignResults() = new(nothing)
end

# Define a foreign method with Symbol IR (not CodeInfo or String)
function my_kernel end
add_method(foreign_mt, my_kernel, (), :my_kernel_ir)

function my_other_kernel end
add_method(foreign_mt, my_other_kernel, (Float64,), :my_other_kernel_ir)

# Compile and cache during precompilation
const world = Base.get_world_counter()
const cache = CacheView{ForeignResults}(:ForeignIRPkg, world)

function do_compile(f, tt)
mi = method_instance(f, tt; world, method_table=foreign_mt)
ci = create_ci(cache, mi)
cache[mi] = ci
results(cache, ci).val = 42
Base.precompile(ci)
return ci
end

const ci1 = do_compile(my_kernel, ())
const ci2 = do_compile(my_other_kernel, (Float64,))

end # module
) |> string)
Base.compilecache(Base.PkgId("ForeignIRPkg"), stderr, stdout)

@eval let
using CompilerCaching
using ForeignIRPkg

cache = CacheView{ForeignIRPkg.ForeignResults}(:ForeignIRPkg, Base.get_world_counter())

# Check that the foreign methods survived with their custom IR
mi1 = method_instance(ForeignIRPkg.my_kernel, ();
world=Base.get_world_counter(),
method_table=ForeignIRPkg.foreign_mt)
@test mi1.def.source === :my_kernel_ir

mi2 = method_instance(ForeignIRPkg.my_other_kernel, (Float64,);
world=Base.get_world_counter(),
method_table=ForeignIRPkg.foreign_mt)
@test mi2.def.source === :my_other_kernel_ir

# Check that CIs survived precompilation
@test haskey(cache, mi1)
@test haskey(cache, mi2)

# Check results integrity
ci1 = get(cache, mi1)
res1 = results(cache, ci1)
@test res1.val == 42

ci2 = get(cache, mi2)
res2 = results(cache, ci2)
@test res2.val == 42
end
end

end # @static if hasmethod(Base.precompile, ...)
Loading