diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7dc92a393f..cf30799ccd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -89,6 +89,8 @@ jobs: test_group: [ {test_type: 'ext', label: 'differentiation_interface'}, {test_type: 'ext', label: 'differentiation_interface_second_order'}, + {test_type: 'ext', label: 'differentiation_interface_second_order_hvp'}, + {test_type: 'ext', label: 'differentiation_interface_second_order_derivative'}, {test_type: 'ext', label: 'dynamic_expressions'}, {test_type: 'ext', label: 'flux'}, {test_type: 'ext', label: 'function_wrappers'}, diff --git a/src/rules/misty_closures.jl b/src/rules/misty_closures.jl index 20a3b9e004..a9f5708130 100644 --- a/src/rules/misty_closures.jl +++ b/src/rules/misty_closures.jl @@ -32,9 +32,13 @@ end # reject our intentionally-older interpreter. # function _dual_mc(p::MistyClosure) - mc_world = UInt(p.oc.world) - interp = MooncakeInterpreter(DefaultCtx, ForwardMode; world=mc_world) - return build_frule(interp, p; skip_world_age_check=true) + return @static if VERSION ≤ v"1.12-" + build_frule(get_interpreter(ForwardMode), p) + else + mc_world = UInt(p.oc.world) + interp = MooncakeInterpreter(DefaultCtx, ForwardMode; world=mc_world) + build_frule(interp, p; skip_world_age_check=true) + end end tangent_type(::Type{<:MistyClosure}) = MistyClosureTangent diff --git a/test/ext/differentiation_interface_second_order_derivative/Project.toml b/test/ext/differentiation_interface_second_order_derivative/Project.toml new file mode 100644 index 0000000000..7639dd345f --- /dev/null +++ b/test/ext/differentiation_interface_second_order_derivative/Project.toml @@ -0,0 +1,5 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/ext/differentiation_interface_second_order_derivative/differentiation_interface_second_order_derivative.jl b/test/ext/differentiation_interface_second_order_derivative/differentiation_interface_second_order_derivative.jl new file mode 100644 index 0000000000..02fd7165b7 --- /dev/null +++ b/test/ext/differentiation_interface_second_order_derivative/differentiation_interface_second_order_derivative.jl @@ -0,0 +1,17 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) + +using DifferentiationInterface, DifferentiationInterfaceTest +using Mooncake: Mooncake + +function DifferentiationInterface.inner_preparation_behavior(::AutoMooncakeForward) + DifferentiationInterface.PrepareInnerSimple() +end + +# Test second-order differentiation (forward-over-reverse) +test_differentiation( + [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; + excluded=[FIRST_ORDER..., :hvp, :hessian], # testing only :second_derivative + logging=true, +) diff --git a/test/ext/differentiation_interface_second_order_hvp/Project.toml b/test/ext/differentiation_interface_second_order_hvp/Project.toml new file mode 100644 index 0000000000..7639dd345f --- /dev/null +++ b/test/ext/differentiation_interface_second_order_hvp/Project.toml @@ -0,0 +1,5 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/ext/differentiation_interface_second_order_hvp/differentiation_interface_second_order_hvp.jl b/test/ext/differentiation_interface_second_order_hvp/differentiation_interface_second_order_hvp.jl new file mode 100644 index 0000000000..4f03869fab --- /dev/null +++ b/test/ext/differentiation_interface_second_order_hvp/differentiation_interface_second_order_hvp.jl @@ -0,0 +1,17 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) + +using DifferentiationInterface, DifferentiationInterfaceTest +using Mooncake: Mooncake + +function DifferentiationInterface.inner_preparation_behavior(::AutoMooncakeForward) + DifferentiationInterface.PrepareInnerSimple() +end + +# Test second-order differentiation (forward-over-reverse) +test_differentiation( + [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; + excluded=[FIRST_ORDER..., :hessian, :second_derivative], # testing only hvp + logging=true, +)