diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9d4e266024..a89e25f686 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -95,6 +95,7 @@ jobs: {test_type: 'integration_testing', label: 'array'}, {test_type: 'integration_testing', label: 'bijectors'}, {test_type: 'integration_testing', label: 'diff_tests'}, + {test_type: 'integration_testing', label: 'diffeq'}, {test_type: 'integration_testing', label: 'dispatch_doctor'}, {test_type: 'integration_testing', label: 'distributions'}, {test_type: 'integration_testing', label: 'dynamicppl'}, diff --git a/test/integration_testing/diffeq/Project.toml b/test/integration_testing/diffeq/Project.toml new file mode 100644 index 0000000000..98e015ff48 --- /dev/null +++ b/test/integration_testing/diffeq/Project.toml @@ -0,0 +1,6 @@ +[deps] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" \ No newline at end of file diff --git a/test/integration_testing/diffeq/diffeq.jl b/test/integration_testing/diffeq/diffeq.jl new file mode 100644 index 0000000000..00ae9e1081 --- /dev/null +++ b/test/integration_testing/diffeq/diffeq.jl @@ -0,0 +1,53 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) + +using OrdinaryDiffEq, SciMLSensitivity, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule + +# Helper function for Mooncake gradient computation +mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2] + +# Define the ODE function from the original issue +odef(du, u, p, t) = du .= u .* p + +# Define the sensitivity loss function from the original issue +struct senseloss0{T} + sense::T +end + +function (f::senseloss0)(u0p) + prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2]) + sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1)) +end + +@testset "diffeq" begin + rng = StableRNG(123456) + + # Test parameters from the original issue + u0p = [2.0, 3.0] + + @testset "senseloss0 with InterpolatingAdjoint" begin + sense_func = senseloss0(InterpolatingAdjoint()) + + # First test that the function works + @testset "Function evaluation" begin + result = sense_func(u0p) + @test result isa Real + @test isfinite(result) + end + + # Test Mooncake gradient computation + @testset "mooncake_gradient computation" begin + dup_mc = mooncake_gradient(sense_func, u0p) + @test dup_mc isa Vector + @test length(dup_mc) == 2 + @test all(isfinite, dup_mc) + end + + # Test with Mooncake's test_rule + @testset "test_rule evaluation" begin + test_rule(rng, sense_func, u0p; is_primitive=false, unsafe_perturb=true) + end + end +end \ No newline at end of file