diff --git a/Project.toml b/Project.toml index db72e8a83..a80c730a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.9" +version = "0.15.10" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 28c7e8506..bea5c2433 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -15,7 +15,33 @@ using Bijectors: find_alpha, ChainRulesCore # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -# TODO: This needs a corresponding frule!! as well for it to work on forward-mode Mooncake. +function Mooncake.frule!!( + ::Mooncake.Dual{typeof(find_alpha)}, + x::Mooncake.Dual{P}, + y::Mooncake.Dual{P}, + z::Mooncake.Dual{I}, +) where {P<:Base.IEEEFloat,I<:Integer} + # Require that the integer is non-differentiable. + if tangent_type(I) != Mooncake.NoTangent + msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." + throw(ArgumentError(msg)) + end + # Convert Mooncake.NoTangent to ChainRulesCore.NoTangent for the integer argument + out, tangent_out = ChainRulesCore.frule( + ( + ChainRulesCore.NoTangent(), + Mooncake.tangent(x), + Mooncake.tangent(y), + ChainRulesCore.NoTangent(), + ), + find_alpha, + Mooncake.primal(x), + Mooncake.primal(y), + Mooncake.primal(z), + ) + return Mooncake.Dual(out, tangent_out) +end + function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 5c9385023..aafffc31e 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -31,8 +31,7 @@ end if @isdefined Mooncake rng = Xoshiro(123456) - # TODO: Enable Mooncake.ForwardMode as well. - @testset "$mode" for mode in (Mooncake.ReverseMode,) + @testset "$mode" for mode in (Mooncake.ReverseMode, Mooncake.ForwardMode) Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha,