diff --git a/Project.toml b/Project.toml index 3c5b0db0cb..59334e079e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.150" +version = "0.4.151" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index af989bc248..bb7fb941f6 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1746,7 +1746,14 @@ DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug _copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} - sig = Tuple{map(_typeof ∘ primal, args)...} + + # `Base._stable_typeof` is used here, rather than `typeof` or `Mooncake._typeof`. Its + # precise behaviour (equivalent to `typeof` for everything except `Type`s, for which it + # returns `Type{P}` rather than `typeof(P)`) is needed to ensure that this signature + # matches the types that `rule` sees when `rule(args...)` is called below. If you get + # this wrong, an assertion is violated, causing a hard-to-debug error (see issue 660). + sig = Tuple{map(Base._stable_typeof ∘ primal, args)...} + rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing interp = get_interpreter(ReverseMode) diff --git a/test/interpreter/reverse_mode.jl b/test/interpreter/reverse_mode.jl index 8b930f2a4b..b1986ffdab 100644 --- a/test/interpreter/reverse_mode.jl +++ b/test/interpreter/reverse_mode.jl @@ -14,6 +14,9 @@ f(a, x) = dot(a.data, x) unstable_tester(x::Ref{Any}) = sin(x[]) +# used for regression test for issue 660 +struct MakeAUnionAll{T} end + end @testset "s2s_reverse_mode_ad" begin @@ -353,6 +356,11 @@ end f(x) = sin(cos(x)) rule = Mooncake.build_rrule(f, 0.0) @benchmark Mooncake.value_and_gradient!!($rule, $f, $(Ref(0.0))[]) + + # 660 -- ensure that the correct signature is used to construct DynamicDerivedRules + rule = Mooncake.DynamicDerivedRule(false) + args = (zero_fcodual(identity), zero_fcodual((v=S2SGlobals.MakeAUnionAll,))) + @test rule(args...) isa Tuple{CoDual,Any} end @testset "literal Strings do not appear in shared data" begin f() = "hello"