From 81e8a89d22b86b40c22146df162e342ab41413fa Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 26 Aug 2025 12:57:07 +0100 Subject: [PATCH 1/4] Fix + tests --- src/interpreter/reverse_mode.jl | 2 +- test/interpreter/reverse_mode.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index af989bc248..dcf6c11460 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1746,7 +1746,7 @@ DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug _copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} - sig = Tuple{map(_typeof ∘ primal, args)...} + sig = Tuple{map(typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing interp = get_interpreter(ReverseMode) diff --git a/test/interpreter/reverse_mode.jl b/test/interpreter/reverse_mode.jl index 8b930f2a4b..b1986ffdab 100644 --- a/test/interpreter/reverse_mode.jl +++ b/test/interpreter/reverse_mode.jl @@ -14,6 +14,9 @@ f(a, x) = dot(a.data, x) unstable_tester(x::Ref{Any}) = sin(x[]) +# used for regression test for issue 660 +struct MakeAUnionAll{T} end + end @testset "s2s_reverse_mode_ad" begin @@ -353,6 +356,11 @@ end f(x) = sin(cos(x)) rule = Mooncake.build_rrule(f, 0.0) @benchmark Mooncake.value_and_gradient!!($rule, $f, $(Ref(0.0))[]) + + # 660 -- ensure that the correct signature is used to construct DynamicDerivedRules + rule = Mooncake.DynamicDerivedRule(false) + args = (zero_fcodual(identity), zero_fcodual((v=S2SGlobals.MakeAUnionAll,))) + @test rule(args...) isa Tuple{CoDual,Any} end @testset "literal Strings do not appear in shared data" begin f() = "hello" From 00150558b7df9f98a9a330f2e83798cd35d861bb Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 26 Aug 2025 12:59:40 +0100 Subject: [PATCH 2/4] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c5b0db0cb..59334e079e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.150" +version = "0.4.151" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From cc6c590ec8257a6c2aa1f5b2dd78024d5fef9e16 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 26 Aug 2025 16:55:19 +0100 Subject: [PATCH 3/4] Use _stable_typeof directly --- src/interpreter/reverse_mode.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index dcf6c11460..7d297515e3 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1746,7 +1746,7 @@ DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug _copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} - sig = Tuple{map(typeof ∘ primal, args)...} + sig = Tuple{map(Base._stable_typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing interp = get_interpreter(ReverseMode) From 928e95a24c15daafb25e3ede85c1d7c9513e18bd Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Wed, 27 Aug 2025 08:21:28 +0100 Subject: [PATCH 4/4] Comment on the use of _stable_typeof --- src/interpreter/reverse_mode.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index 7d297515e3..bb7fb941f6 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1746,7 +1746,14 @@ DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug _copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} + + # `Base._stable_typeof` is used here, rather than `typeof` or `Mooncake._typeof`. Its + # precise behaviour (equivalent to `typeof` for everything except `Type`s, for which it + # returns `Type{P}` rather than `typeof(P)`) is needed to ensure that this signature + # matches the types that `rule` sees when `rule(args...)` is called below. If you get + # this wrong, an assertion is violated, causing a hard-to-debug error (see issue 660). sig = Tuple{map(Base._stable_typeof ∘ primal, args)...} + rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing interp = get_interpreter(ReverseMode)