diff --git a/test/precompile.jl b/test/precompile.jl index fe77ed5..f803f87 100644 --- a/test/precompile.jl +++ b/test/precompile.jl @@ -154,3 +154,213 @@ precompile_test_harness("Inference caching") do load_path @test ExampleCompiler.emit_code_count[] == 0 broken=VERSION>=v"1.12.0-DEV.1268" end end + +# Test Base.precompile(ci::CodeInstance) for persisting non-native external CIs +@static if hasmethod(Base.precompile, Tuple{Core.CodeInstance}) +precompile_test_harness("Base.precompile(ci::CodeInstance) external CI caching") do load_path + write(joinpath(load_path, "ExampleCompilerPC.jl"), :(module ExampleCompilerPC + using CompilerCaching + + const CC = Core.Compiler + + mutable struct CompileResults + ir::Any + code::Any + executable::Any + CompileResults() = new(nothing, nothing, nothing) + end + + struct ExampleInterpreter <: CC.AbstractInterpreter + world::UInt + cache::CacheView + inf_cache::Vector{CC.InferenceResult} + end + ExampleInterpreter(cache::CacheView) = + ExampleInterpreter(cache.world, cache, CC.InferenceResult[]) + + CC.InferenceParams(::ExampleInterpreter) = CC.InferenceParams() + CC.OptimizationParams(::ExampleInterpreter) = CC.OptimizationParams() + CC.get_inference_cache(interp::ExampleInterpreter) = interp.inf_cache + @static if isdefined(Core.Compiler, :get_inference_world) + Core.Compiler.get_inference_world(interp::ExampleInterpreter) = interp.world + else + Core.Compiler.get_world_counter(interp::ExampleInterpreter) = interp.world + end + CC.lock_mi_inference(::ExampleInterpreter, ::Core.MethodInstance) = nothing + CC.unlock_mi_inference(::ExampleInterpreter, ::Core.MethodInstance) = nothing + @setup_caching ExampleInterpreter.cache + + emit_code_count = Ref(0) + + function emit_ir(cache, mi) + interp = ExampleInterpreter(cache) + typeinf!(cache, interp, mi) + end + + function emit_code(cache, mi, ir) + emit_code_count[] += 1 + :code_result + end + + emit_executable(cache, mi, code) = code + + function precompile(f, tt) + world = Base.get_world_counter() + mi = method_instance(f, tt; world) + cache = CacheView{CompileResults}(:ExampleCompilerPC, world) + + ci = get(cache, mi, nothing) + if ci !== nothing + res = results(cache, ci) + if res.executable !== nothing + return res.executable + end + end + + ir = emit_ir(cache, mi) + + ci = get(cache, mi, nothing) + @assert ci !== nothing "typeinf! should have created CI" + res = results(cache, ci) + res.ir = ir + + res.code = emit_code(cache, mi, res.ir) + res.executable = emit_executable(cache, mi, res.code) + + # Register CI for precompilation caching so non-native external CIs + # survive serialization (bypasses the owner filter in queue_external_cis) + Base.precompile(ci) + + @assert res.executable === :code_result + return res.executable + end + + end # module + ) |> string) + Base.compilecache(Base.PkgId("ExampleCompilerPC"), stderr, stdout) + + write(joinpath(load_path, "ExampleUserPC.jl"), :(module ExampleUserPC + import ExampleCompilerPC + + function square(x) + return x * x + end + + # Internal method + ExampleCompilerPC.precompile(square, (Float64,)) + + # External method (identity from Base) + ExampleCompilerPC.precompile(identity, (Int,)) + + end # module + ) |> string) + Base.compilecache(Base.PkgId("ExampleUserPC"), stderr, stdout) + + @eval let + using CompilerCaching + import ExampleCompilerPC + @test ExampleCompilerPC.emit_code_count[] == 0 + + cache = CacheView{ExampleCompilerPC.CompileResults}(:ExampleCompilerPC, Base.get_world_counter()) + + identity_mi = method_instance(identity, (Int,)) + @test !haskey(cache, identity_mi) + + using ExampleUserPC + @test ExampleCompilerPC.emit_code_count[] == 0 + + cache = CacheView{ExampleCompilerPC.CompileResults}(:ExampleCompilerPC, Base.get_world_counter()) + + # Internal method CI survived (baseline) + square_mi = method_instance(ExampleUserPC.square, (Float64,)) + @test haskey(cache, square_mi) + + # External method CI survived (enabled by Base.precompile(ci)) + @test haskey(cache, identity_mi) + + # No recompilation needed + ExampleCompilerPC.precompile(ExampleUserPC.square, (Float64,)) + @test ExampleCompilerPC.emit_code_count[] == 0 + ExampleCompilerPC.precompile(identity, (Int,)) + @test ExampleCompilerPC.emit_code_count[] == 0 + + # Results integrity + ci = get(cache, identity_mi) + res = results(cache, ci) + @test res.executable === :code_result + end +end +# Test foreign IR precompilation (non-CodeInfo source in method.source) +precompile_test_harness("Foreign IR precompilation") do load_path + write(joinpath(load_path, "ForeignIRPkg.jl"), :(module ForeignIRPkg + using CompilerCaching + + const CC = Core.Compiler + + Base.Experimental.@MethodTable foreign_mt + + mutable struct ForeignResults + val::Any + ForeignResults() = new(nothing) + end + + # Define a foreign method with Symbol IR (not CodeInfo or String) + function my_kernel end + add_method(foreign_mt, my_kernel, (), :my_kernel_ir) + + function my_other_kernel end + add_method(foreign_mt, my_other_kernel, (Float64,), :my_other_kernel_ir) + + # Compile and cache during precompilation + const world = Base.get_world_counter() + const cache = CacheView{ForeignResults}(:ForeignIRPkg, world) + + function do_compile(f, tt) + mi = method_instance(f, tt; world, method_table=foreign_mt) + ci = create_ci(cache, mi) + cache[mi] = ci + results(cache, ci).val = 42 + Base.precompile(ci) + return ci + end + + const ci1 = do_compile(my_kernel, ()) + const ci2 = do_compile(my_other_kernel, (Float64,)) + + end # module + ) |> string) + Base.compilecache(Base.PkgId("ForeignIRPkg"), stderr, stdout) + + @eval let + using CompilerCaching + using ForeignIRPkg + + cache = CacheView{ForeignIRPkg.ForeignResults}(:ForeignIRPkg, Base.get_world_counter()) + + # Check that the foreign methods survived with their custom IR + mi1 = method_instance(ForeignIRPkg.my_kernel, (); + world=Base.get_world_counter(), + method_table=ForeignIRPkg.foreign_mt) + @test mi1.def.source === :my_kernel_ir + + mi2 = method_instance(ForeignIRPkg.my_other_kernel, (Float64,); + world=Base.get_world_counter(), + method_table=ForeignIRPkg.foreign_mt) + @test mi2.def.source === :my_other_kernel_ir + + # Check that CIs survived precompilation + @test haskey(cache, mi1) + @test haskey(cache, mi2) + + # Check results integrity + ci1 = get(cache, mi1) + res1 = results(cache, ci1) + @test res1.val == 42 + + ci2 = get(cache, mi2) + res2 = results(cache, ci2) + @test res2.val == 42 + end +end + +end # @static if hasmethod(Base.precompile, ...)