diff --git a/Project.toml b/Project.toml index 950b690d43..31deb7bf64 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -33,6 +34,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [extensions] MooncakeAllocCheckExt = "AllocCheck" MooncakeCUDAExt = "CUDA" +MooncakeDifferentiationInterfaceExt = "DifferentiationInterface" MooncakeDynamicExpressionsExt = "DynamicExpressions" MooncakeFluxExt = "Flux" MooncakeFunctionWrappersExt = "FunctionWrappers" @@ -51,6 +53,7 @@ CUDA = "5" ChainRules = "1.71.0" ChainRulesCore = "1" DiffTests = "0.1" +DifferentiationInterface = "0.7" DispatchDoctor = "0.4.26" DynamicExpressions = "2" ExprTools = "0.1" diff --git a/ext/MooncakeDifferentiationInterfaceExt.jl b/ext/MooncakeDifferentiationInterfaceExt.jl new file mode 100644 index 0000000000..600ee38ff5 --- /dev/null +++ b/ext/MooncakeDifferentiationInterfaceExt.jl @@ -0,0 +1,147 @@ +module MooncakeDifferentiationInterfaceExt + +using Mooncake: + Mooncake, @is_primitive, MinimalCtx, ForwardMode, Dual, primal, tangent, NoTangent +import DifferentiationInterface as DI + +# Mark shuffled_gradient as forward-mode primitive to avoid expensive type inference hang. +# This prevents build_frule from trying to derive rules for the complex gradient closure. 
+@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient),Vararg}
+@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient!),Vararg}
+
+# Pair primal and tangent into Duals: element-wise for arrays, one Dual otherwise
+_make_dual_array(x::AbstractArray, dx::AbstractArray) = Dual.(x, dx)
+_make_dual_array(x, dx) = Dual(x, dx)
+
+# Split Dual-valued data (an array of Duals or one Dual) into primal / tangent parts
+_extract_primals(arr::AbstractArray{<:Dual}) = primal.(arr)
+_extract_primals(d::Dual) = primal(d)
+_extract_tangents(arr::AbstractArray{<:Dual}) = tangent.(arr)
+_extract_tangents(d::Dual) = tangent(d)
+
+# frule for shuffled_gradient without prep
+# shuffled_gradient(x, f, backend, rewrap, contexts...) -> gradient(f, backend, x, contexts...)
+function Mooncake.frule!!(
+    ::Dual{typeof(DI.shuffled_gradient)},
+    x_dual::Dual,
+    f_dual::Dual,
+    backend_dual::Dual,
+    rewrap_dual::Dual,
+    context_duals::Vararg{Dual},
+)
+    # Only the tangent of `x` is used; every other argument contributes its primal only
+    x = primal(x_dual)
+    dx = tangent(x_dual)
+    f = primal(f_dual)
+    backend = primal(backend_dual)
+    rewrap = primal(rewrap_dual)
+    contexts = map(primal, context_duals)
+
+    # Create Dual inputs: each element is Dual(x[i], dx[i])
+    # This allows the Hvp to be computed via forward-over-reverse
+    x_with_duals = _make_dual_array(x, dx)
+
+    # Call gradient with Dual inputs
+    # Since Dual{Float64,Float64} is self-tangent, reverse mode handles it correctly
+    grad_duals = DI.shuffled_gradient(x_with_duals, f, backend, rewrap, contexts...)
+
+    # Extract primal (gradient) and tangent (Hvp) from the Dual outputs
+    grad_primal = _extract_primals(grad_duals)
+    grad_tangent = _extract_tangents(grad_duals)
+
+    return Dual(grad_primal, grad_tangent)
+end
+
+# frule for shuffled_gradient with prep (same scheme, `prep` is forwarded untouched)
+function Mooncake.frule!!(
+    ::Dual{typeof(DI.shuffled_gradient)},
+    x_dual::Dual,
+    f_dual::Dual,
+    prep_dual::Dual,
+    backend_dual::Dual,
+    rewrap_dual::Dual,
+    context_duals::Vararg{Dual},
+)
+    x = primal(x_dual)
+    dx = tangent(x_dual)
+    f = primal(f_dual)
+    prep = primal(prep_dual)
+    backend = primal(backend_dual)
+    rewrap = primal(rewrap_dual)
+    contexts = map(primal, context_duals)
+
+    x_with_duals = _make_dual_array(x, dx)
+    grad_duals = DI.shuffled_gradient(x_with_duals, f, prep, backend, rewrap, contexts...)
+
+    grad_primal = _extract_primals(grad_duals)
+    grad_tangent = _extract_tangents(grad_duals)
+
+    return Dual(grad_primal, grad_tangent)
+end
+
+# frule for shuffled_gradient! (in-place version)
+function Mooncake.frule!!(
+    ::Dual{typeof(DI.shuffled_gradient!)},
+    grad_dual::Dual,
+    x_dual::Dual,
+    f_dual::Dual,
+    backend_dual::Dual,
+    rewrap_dual::Dual,
+    context_duals::Vararg{Dual},
+)
+    grad = primal(grad_dual)
+    dgrad = tangent(grad_dual)  # Tangent storage for gradient (where Hvp goes)
+    x = primal(x_dual)
+    dx = tangent(x_dual)
+    f = primal(f_dual)
+    backend = primal(backend_dual)
+    rewrap = primal(rewrap_dual)
+    contexts = map(primal, context_duals)
+
+    x_with_duals = _make_dual_array(x, dx)
+    # Dual buffer for the in-place gradient; tangents start at zero, not uninitialised
+    grad_duals = _make_dual_array(grad, zero(grad))
+    DI.shuffled_gradient!(grad_duals, x_with_duals, f, backend, rewrap, contexts...)
+
+    # Copy primal (gradient) back to grad
+    grad .= _extract_primals(grad_duals)
+    # Copy tangent (Hvp) back to dgrad
+    dgrad .= _extract_tangents(grad_duals)
+
+    return Dual(nothing, NoTangent())
+end
+
+# frule for shuffled_gradient! with prep
+function Mooncake.frule!!(
+    ::Dual{typeof(DI.shuffled_gradient!)},
+    grad_dual::Dual,
+    x_dual::Dual,
+    f_dual::Dual,
+    prep_dual::Dual,
+    backend_dual::Dual,
+    rewrap_dual::Dual,
+    context_duals::Vararg{Dual},
+)
+    grad = primal(grad_dual)
+    dgrad = tangent(grad_dual)  # Tangent storage for gradient (where Hvp goes)
+    x = primal(x_dual)
+    dx = tangent(x_dual)
+    f = primal(f_dual)
+    prep = primal(prep_dual)
+    backend = primal(backend_dual)
+    rewrap = primal(rewrap_dual)
+    contexts = map(primal, context_duals)
+
+    x_with_duals = _make_dual_array(x, dx)
+    grad_duals = _make_dual_array(grad, zero(grad))  # zero tangents, never uninitialised
+    DI.shuffled_gradient!(grad_duals, x_with_duals, f, prep, backend, rewrap, contexts...)
+
+    # Copy primal (gradient) back to grad
+    grad .= _extract_primals(grad_duals)
+    # Copy tangent (Hvp) back to dgrad
+    dgrad .= _extract_tangents(grad_duals)
+
+    return Dual(nothing, NoTangent())
+end
+
+end
diff --git a/src/Mooncake.jl b/src/Mooncake.jl
index 87ef7903f9..b7f7318331 100644
--- a/src/Mooncake.jl
+++ b/src/Mooncake.jl
@@ -151,6 +151,7 @@ else
     include(joinpath("rrules", "array_legacy.jl"))
 end
 include(joinpath("rrules", "performance_patches.jl"))
+include(joinpath("rrules", "dual_arithmetic.jl"))
 
 include("interface.jl")
 include("config.jl")
diff --git a/src/dual.jl b/src/dual.jl
index 65cf53532a..34e563bb08 100644
--- a/src/dual.jl
+++ b/src/dual.jl
@@ -55,3 +55,140 @@ verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)
 function Dual(x::Type{P}, dx::NoTangent) where {P}
     return Dual{@isdefined(P) ? 
Type{P} : typeof(x),NoTangent}(x, dx)
 end
+
+# Dual of numeric types is self-tangent
+@inline tangent_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T}
+
+@inline zero_tangent_internal(
+    x::Dual{P,T}, ::MaybeCache
+) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))
+
+@inline function randn_tangent_internal(
+    rng::AbstractRNG, x::Dual{P,T}, ::MaybeCache
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(randn(rng, P), randn(rng, T))
+end
+
+@inline function increment!!(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
+end
+
+@inline set_to_zero_internal!!(
+    ::SetToZeroCache, x::Dual{P,T}
+) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))
+
+@inline function increment_internal!!(
+    ::IncCache, x::Dual{P,T}, y::Dual{P,T}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
+end
+
+Base.one(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(one(P), zero(T))
+Base.one(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = one(typeof(x))
+
+# Addition: a plain float or integer operand is a constant (contributes no tangent)
+function Base.:+(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) + y, tangent(x))
+end
+function Base.:+(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x + primal(y), tangent(y))
+end
+function Base.:+(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
+end
+function Base.:+(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) + y, tangent(x))
+end
+function Base.:+(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x + primal(y), tangent(y))
+end
+
+# Subtraction
+Base.:-(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(-primal(x), -tangent(x))
+function Base.:-(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) - y, tangent(x))
+end
+function Base.:-(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x - primal(y), -tangent(y))
+end
+function Base.:-(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) - primal(y), tangent(x) - tangent(y))
+end
+function Base.:-(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) - y, tangent(x))
+end
+function Base.:-(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x - primal(y), -tangent(y))
+end
+
+# Multiplication (product rule)
+function Base.:*(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) * y, tangent(x) * y)
+end
+function Base.:*(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x * primal(y), x * tangent(y))
+end
+function Base.:*(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) * primal(y), primal(x) * tangent(y) + tangent(x) * primal(y))
+end
+function Base.:*(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) * y, tangent(x) * y)
+end
+function Base.:*(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x * primal(y), x * tangent(y))
+end
+
+# Division (quotient rule)
+function Base.:/(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(primal(x) / y, tangent(x) / y)
+end
+function Base.:/(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(x / primal(y), -x * tangent(y) / primal(y)^2)
+end
+function Base.:/(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(
+        primal(x) / primal(y),
+        (tangent(x) * primal(y) - primal(x) * tangent(y)) / primal(y)^2,
+    )
+end
+
+# Power (power rule in the base; the exponent is treated as a constant)
+function Base.:^(x::Dual{P,T}, n::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
+    slope = iszero(n) ? zero(P) : n * primal(x)^(n - 1) # x^0 has slope 0, also at x == 0
+    return Dual(primal(x)^n, slope * tangent(x))
+end
+function Base.:^(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
+    slope = iszero(y) ? zero(P) : y * primal(x)^(y - 1) # avoids 0 * Inf == NaN at x == 0
+    return Dual(primal(x)^y, slope * tangent(x))
+end
+
+# Comparison (use primal for comparisons)
+Base.:<(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) < y
+Base.:<(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x < primal(y)
+function Base.:<(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return primal(x) < primal(y)
+end
+Base.:>(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) > y
+Base.:>(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x > primal(y)
+function Base.:>(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return primal(x) > primal(y)
+end
+Base.:<=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) <= y
+Base.:<=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x <= primal(y)
+function Base.:<=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return primal(x) <= primal(y)
+end
+Base.:>=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) >= y
+Base.:>=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x >= primal(y)
+function Base.:>=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return primal(x) >= primal(y)
+end
+
+# Conversion and promotion (a plain float lifts to a Dual with zero tangent)
+Base.convert(::Type{Dual{P,T}}, x::P) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(x, zero(T))
+function Base.promote_rule(::Type{Dual{P,T}}, ::Type{P}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual{P,T}
+end
+
+LinearAlgebra.transpose(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x
+LinearAlgebra.adjoint(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x
diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl
index e5b46319d0..3d303e6902 100644
--- a/src/fwds_rvs_data.jl
+++ b/src/fwds_rvs_data.jl
@@ -159,6 +159,7 @@ fdata_type(x) = throw(error("$x is not a type. 
Perhaps you meant typeof(x)?"))
 @foldable fdata_type(::Type{Union{}}) = Union{}
 
 fdata_type(::Type{T}) where {T<:IEEEFloat} = NoFData
+fdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = NoFData # as for floats
 
 function fdata_type(::Type{PossiblyUninitTangent{T}}) where {T}
     Tfields = fdata_type(T)
@@ -437,6 +438,7 @@ rdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?"))
 @foldable rdata_type(::Type{Union{}}) = Union{}
 
 rdata_type(::Type{T}) where {T<:IEEEFloat} = T
+rdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T} # own rdata
 
 function rdata_type(::Type{PossiblyUninitTangent{T}}) where {T}
     return PossiblyUninitTangent{rdata_type(T)}
@@ -587,6 +589,7 @@ Given value `p`, return the zero element associated to its reverse data type.
 zero_rdata(p)
 
 zero_rdata(p::IEEEFloat) = zero(p)
+zero_rdata(p::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))
 
 @generated function zero_rdata(p::P) where {P}
     Rs = rdata_field_types_exprs(P)
@@ -654,6 +657,9 @@ obtained from `P` alone.
 end
 
 @foldable can_produce_zero_rdata_from_type(::Type{<:IEEEFloat}) = true
+@foldable can_produce_zero_rdata_from_type(
+    ::Type{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat} = true # the zero depends on P and T only
 
 @foldable can_produce_zero_rdata_from_type(::Type{<:Type}) = true
 
@@ -737,6 +743,9 @@ function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple}
 end
 
 zero_rdata_from_type(::Type{P}) where {P<:IEEEFloat} = zero(P)
+function zero_rdata_from_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return Dual(zero(P), zero(T)) # same value as `zero_rdata` gives for a Dual instance
+end
 
 zero_rdata_from_type(::Type{<:Type}) = NoRData()
 
@@ -951,6 +960,7 @@ Reconstruct the tangent `t` for which `fdata(t) == f` and `rdata(t) == r`.
 """
 tangent(::NoFData, ::NoRData) = NoTangent()
 tangent(::NoFData, r::IEEEFloat) = r
+tangent(::NoFData, r::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = r # rdata is the tangent
 tangent(f::Array, ::NoRData) = f
 
 # Tuples
diff --git a/src/interface.jl b/src/interface.jl
index 1e4f7729ca..9786371fb3 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -106,7 +106,8 @@ function __value_and_gradient!!(rule::R, fx::Vararg{CoDual,N}) where {R,N}
     __verify_sig(rule, fx_fwds)
     out, pb!! = rule(fx_fwds...)
     y = primal(out)
-    y isa IEEEFloat || throw_val_and_grad_ret_type_error(y)
+    (y isa IEEEFloat || y isa Dual{<:IEEEFloat,<:IEEEFloat}) || # float-valued Duals pass too
+        throw_val_and_grad_ret_type_error(y)
     return y, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(one(y)))
 end
 
@@ -194,6 +195,14 @@ struct Cache{Trule,Ty_cache,Ttangents<:Tuple}
     tangents::Ttangents
 end
 
+tangent_type(::Type{<:Cache}) = NoTangent # a Cache carries no tangent information
+
+@inline zero_tangent(x::Cache) = NoTangent()
+
+@inline zero_tangent_internal(::Cache, ::MaybeCache) = NoTangent()
+
+@inline randn_tangent_internal(::AbstractRNG, ::Cache, ::MaybeCache) = NoTangent()
+
 """
     __exclude_unsupported_output(y)
     __exclude_func_with_unsupported_output(fx)
@@ -535,7 +544,9 @@ The API guarantees that tangents are initialized at zero before the first autodi
     rule = build_rrule(fx...; kwargs...)
     tangents = map(zero_tangent, fx)
     y, rvs!! = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...)
-    primal(y) isa IEEEFloat || throw_val_and_grad_ret_type_error(primal(y))
+    _y = primal(y) # the output must be a float or a float-valued Dual
+    (_y isa IEEEFloat || _y isa Dual{<:IEEEFloat,<:IEEEFloat}) ||
+        throw_val_and_grad_ret_type_error(_y)
     rvs!!(zero_tangent(primal(y))) # run reverse-pass to reset stacks + state
     return Cache(rule, nothing, tangents)
 end
@@ -608,3 +619,34 @@ derivative of `primal(f)` at the primal values in `x` in the direction of the ta
 in `f` and `x`.
 """
 value_and_derivative!!(rule::R, fx::Vararg{Dual,N}) where {R,N} = rule(fx...)
+
+@zero_derivative MinimalCtx Tuple{typeof(prepare_pullback_cache),Vararg} ForwardMode
+@zero_derivative MinimalCtx Tuple{typeof(prepare_gradient_cache),Vararg} ForwardMode
+@zero_derivative MinimalCtx Tuple{typeof(prepare_derivative_cache),Vararg} ForwardMode
+
+@is_primitive MinimalCtx ForwardMode Tuple{typeof(value_and_gradient!!),Cache,Vararg}
+
+# Forward-mode rule for `value_and_gradient!!(cache, f, x...)`. The primal output is the
+# tuple `(y, grads)` with one gradient per element of `(f, x...)`, so a single `Dual`
+# wrapping that tuple is returned.
+# NOTE(review): the tangent of `grads` is zero, i.e. second-order information is not
+# propagated through this rule -- confirm that callers do not rely on it.
+function frule!!(
+    ::Dual{typeof(value_and_gradient!!)},
+    cache_dual::Dual{<:Cache},
+    f_dual::Dual,
+    x_duals::Vararg{Dual},
+)
+    # Treat `f` and `x...` uniformly: `grads` has one entry for each of them.
+    fx_duals = (f_dual, x_duals...)
+    y, grads = value_and_gradient!!(primal(cache_dual), map(primal, fx_duals)...)
+
+    # Directional derivative of `y`: pair every gradient with the tangent of the argument
+    # it belongs to (`grads[1]` is the gradient w.r.t. `f`, not w.r.t. the first `x`).
+    # `_dot` is Mooncake's inner product on tangents; `oftype` restores the type of `y`.
+    dy = oftype(y, mapreduce(_dot, +, grads, map(tangent, fx_duals)))
+
+    dgrads = map(zero_tangent, grads)
+
+    return Dual((y, grads), (dy, dgrads))
+end
diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl
index c0ce3ae9f5..150038ec03 100644
--- a/src/interpreter/forward_mode.jl
+++ b/src/interpreter/forward_mode.jl
@@ -211,6 +211,38 @@ end
 const ATTACH_AFTER = true
 const ATTACH_BEFORE = false
 
+@inline contains_bottom_type(T) = _contains_bottom_type(T, Base.IdSet{Any}())
+
+function _contains_bottom_type(T, seen::Base.IdSet{Any})
+    T === Union{} && return true
+    if T isa Union
+        return _contains_bottom_type(T.a, seen) || _contains_bottom_type(T.b, seen)
+    elseif T isa TypeVar
+        if T in seen
+            return false
+        end
+        push!(seen, T)
+        return _contains_bottom_type(T.ub, seen)
+    elseif T isa UnionAll
+        if T in seen
+            return false
+        end
+        push!(seen, T)
+        return _contains_bottom_type(T.body, seen)
+    elseif T isa DataType
+        if T in seen
+            return false
+        end
+        push!(seen, T)
+        for p in T.parameters
+            _contains_bottom_type(p, seen) && return true
+        end
+        return false
+    else
+        return false
+    end
+end
+
 modify_fwd_ad_stmts!(::Nothing, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing
 
 modify_fwd_ad_stmts!(::GotoNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing
@@ -333,6 +365,15 @@ function modify_fwd_ad_stmts!(
     sig_types = map(raw_args) do x
         return CC.widenconst(get_forward_primal_type(info.primal_ir, x))
     end
+    if any(contains_bottom_type, sig_types)
+        sig_strings = join(map(x -> sprint(show, x), sig_types), ", ")
+        raw_strings = join(map(x -> sprint(show, x), raw_args), ", ")
+        @debug "forward-mode bottom argument" sig_types = sig_strings raw_args =
+            raw_strings stmt
+        filtered = [pair for pair in zip(sig_types, raw_args) if pair[1] !== Union{}]
+        sig_types = map(first, filtered)
+        raw_args = map(last, filtered)
+    end
     sig = Tuple{sig_types...}
     mi = isexpr(stmt, :invoke) ? get_mi(stmt.args[1]) : missing
     args = map(__inc, raw_args)
diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl
index 22e4d679d0..85c25c4930 100644
--- a/src/rrules/avoiding_non_differentiable_code.jl
+++ b/src/rrules/avoiding_non_differentiable_code.jl
@@ -107,6 +107,15 @@ end
     }
 )
 
+# Avoid differentiating Mooncake's rule construction in forward mode
+# This prevents forward-over-reverse from descending into kw-wrapper exceptions and caches.
+@zero_derivative MinimalCtx Tuple{typeof(build_rrule),Vararg} ForwardMode
+@zero_derivative MinimalCtx Tuple{typeof(Core.kwcall),NamedTuple,typeof(build_rrule),Vararg} ForwardMode
+
+# Avoid differentiating tangent and cache constructors in forward mode
+@zero_derivative MinimalCtx Tuple{typeof(zero_tangent),Vararg} ForwardMode
+@zero_derivative MinimalCtx Tuple{typeof(zero_tangent_internal),Vararg} ForwardMode
+
 function hand_written_rule_test_cases(rng_ctor, ::Val{:avoiding_non_differentiable_code})
     _x = Ref(5.0)
     _dx = Ref(4.0)
diff --git a/src/rrules/dual_arithmetic.jl b/src/rrules/dual_arithmetic.jl
new file mode 100644
index 0000000000..10ce53949a
--- /dev/null
+++ b/src/rrules/dual_arithmetic.jl
@@ -0,0 +1,219 @@
+# Reverse-mode rules for `+`, `-`, `*`, `/` and `^` on `Dual{P,T}` with `P,T<:IEEEFloat`.
+# NOTE(review): the pullbacks below return `NoRData()` for plain `P<:IEEEFloat` arguments,
+# whereas `rdata_type` of an IEEEFloat is the float type itself -- confirm this is intended.
+@inline function _dual_add_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return NoRData(), dy, dy
+end
+
+@inline function _dual_sub_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return NoRData(), dy, Dual(-primal(dy), -tangent(dy))
+end
+
+@inline function _dual_neg_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
+    return NoRData(), Dual(-primal(dy), -tangent(dy))
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(+),Dual{P,T},Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) + primal(y)
+    return CoDual(z, NoFData()), _dual_add_pullback
+end
+
+@is_primitive MinimalCtx Tuple{typeof(+),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{P}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) + primal(y)
+    pb!! = dy -> (NoRData(), dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(+),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) + primal(y)
+    pb!! = dy -> (NoRData(), NoRData(), dy)
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(+),Dual{P,T},Integer
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) + primal(y)
+    pb!! = dy -> (NoRData(), dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(+),Integer,Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(+)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) + primal(y)
+    pb!! = dy -> (NoRData(), NoRData(), dy)
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(-),Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = -primal(x)
+    return CoDual(z, NoFData()), _dual_neg_pullback
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(-),Dual{P,T},Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) - primal(y)
+    return CoDual(z, NoFData()), _dual_sub_pullback
+end
+
+@is_primitive MinimalCtx Tuple{typeof(-),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{P}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) - primal(y)
+    pb!! = dy -> (NoRData(), dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(-),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{P}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) - primal(y)
+    pb!! = dy -> (NoRData(), NoRData(), Dual(-primal(dy), -tangent(dy)))
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(-),Dual{P,T},Integer
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) - primal(y)
+    pb!! = dy -> (NoRData(), dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(-),Integer,Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(-)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    z = primal(x) - primal(y)
+    pb!! = dy -> (NoRData(), NoRData(), Dual(-primal(dy), -tangent(dy)))
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(*),Dual{P,T},Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px * py
+    function mul_dual_dual_pb!!(dy::Dual{P,T})
+        dx = py * dy
+        dy_out = px * dy
+        return NoRData(), dx, dy_out
+    end
+    return CoDual(z, NoFData()), mul_dual_dual_pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(*),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{P}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px * py
+    pb!! = dy -> (NoRData(), py * dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(*),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px * py
+    pb!! = dy -> (NoRData(), NoRData(), px * dy)
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(/),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(/)}, x::CoDual{Dual{P,T}}, y::CoDual{P}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px / py
+    pb!! = dy -> (NoRData(), dy / py, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(*),Integer,Dual{P,T}
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(*)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px * py
+    pb!! = dy -> (NoRData(), NoRData(), px * dy)
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{
+    typeof(*),Dual{P,T},Integer
+} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px * py
+    pb!! = dy -> (NoRData(), py * dy, NoRData())
+    return CoDual(z, NoFData()), pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(^),Dual{P,T},Int} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(^)}, x::CoDual{Dual{P,T}}, n::CoDual{Int}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px = primal(x)
+    pn = primal(n)
+    z = px^pn
+    function pow_dual_int_pb!!(dy::Dual{P,T})
+        dx = iszero(pn) ? Dual(zero(P), zero(T)) : pn * px^(pn - 1) * dy # x^0: no NaN at 0
+        return NoRData(), dx, NoRData()
+    end
+    return CoDual(z, NoFData()), pow_dual_int_pb!!
+end
+
+@is_primitive MinimalCtx Tuple{typeof(^),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat}
+function rrule!!(
+    ::CoDual{typeof(^)}, x::CoDual{Dual{P,T}}, y::CoDual{P}
+) where {P<:IEEEFloat,T<:IEEEFloat}
+    px, py = primal(x), primal(y)
+    z = px^py
+    function pow_dual_float_pb!!(dy::Dual{P,T})
+        dx = iszero(py) ? Dual(zero(P), zero(T)) : py * px^(py - one(P)) * dy
+        return NoRData(), dx, NoRData()
+    end
+    return CoDual(z, NoFData()), pow_dual_float_pb!!
+end
diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl
index ea74e84b6d..7b81715a44 100644
--- a/test/ext/differentiation_interface/differentiation_interface.jl
+++ b/test/ext/differentiation_interface/differentiation_interface.jl
@@ -3,10 +3,188 @@ Pkg.activate(@__DIR__)
 Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
 
 using DifferentiationInterface, DifferentiationInterfaceTest
+import DifferentiationInterface as DI
 using Mooncake: Mooncake
+using Test
 
 test_differentiation(
     [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())];
     excluded=SECOND_ORDER,
     logging=true,
 )
+
+# Test Hessian computation using forward-over-reverse with DITest scenarios.
+test_differentiation( + [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; + excluded=vcat(FIRST_ORDER, [:hvp, :second_derivative]), + logging=true, +) + +@testset "Mooncake Hessian tests" begin + backend = SecondOrder( + AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing) + ) + + # Sum: Hessian is zero + @testset "sum" begin + @test DI.hessian(sum, backend, [2.0]) == [0.0] + end + + # Rosenbrock 2D at [1.2, 1.2] + @testset "Rosenbrock" begin + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + H = DI.hessian(rosen, backend, [1.2, 1.2]) + @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) + end + + # Test higher integer powers (fixed by adding frule for ^(Float, Int)) + @testset "higher powers" begin + @test DI.hessian(x -> x[1]^4, backend, [2.0]) ≈ [48.0] + @test DI.hessian(x -> x[1]^6, backend, [2.0]) ≈ [480.0] + end + + @testset "https://github.com/chalk-lab/Mooncake.jl/issues/632" begin + function gams_objective(x) + return ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + x[1] * + x[1] + + x[10] * + x[10] + ) * + ( + x[1] * + x[1] + + x[10] * + x[10] + ) - + 4 * + x[1] + ) + + 3 + ) + + ( + x[2] * + x[2] + + x[10] * + x[10] + ) * + ( + x[2] * + x[2] + + x[10] * + x[10] + ) + ) - + 4 * + x[2] + ) + + 3 + ) + + ( + x[3] * + x[3] + + x[10] * + x[10] + ) * + ( + x[3] * + x[3] + + x[10] * + x[10] + ) + ) - + 4 * + x[3] + ) + + 3 + ) + + ( + x[4] * + x[4] + + x[10] * + x[10] + ) * ( + x[4] * + x[4] + + x[10] * + x[10] + ) + ) - + 4 * x[4] + ) + 3 + ) + + ( + x[5] * x[5] + + x[10] * x[10] + ) * ( + x[5] * x[5] + + x[10] * x[10] + ) + ) - 4 * x[5] + ) + 3 + ) + + (x[6] * x[6] + x[10] * x[10]) * + (x[6] * x[6] + x[10] * x[10]) + ) - 4 * x[6] + ) + 3 + ) + + (x[7] * x[7] + x[10] * x[10]) * + (x[7] * x[7] + x[10] * x[10]) + ) - 4 * x[7] + ) + 3 + ) + + (x[8] * x[8] + x[10] * x[10]) * + (x[8] * x[8] + x[10] * x[10]) + ) - 4 * x[8] + ) 
+ 3 + ) + (x[9] * x[9] + x[10] * x[10]) * (x[9] * x[9] + x[10] * x[10]) + ) - 4 * x[9] + ) + 3 + ) + end + x0 = [0.0; fill(1.0, 9)] + H = DI.hessian(gams_objective, backend, x0) + + # Expected Hessian at x0: + # - H[1,1] = 4 (since x₁=0) + # - H[i,i] = 16 for i ∈ 2:9 (since xᵢ=1, x₁₀=1) + # - H[10,10] = 140 (sum of contributions from all 9 terms) + # - H[i,10] = H[10,i] = 8xᵢx₁₀ = 0 for i=1, 8 for i∈2:9 + H_expected = zeros(10, 10) + H_expected[1, 1] = 4.0 + for i in 2:9 + H_expected[i, i] = 16.0 + H_expected[i, 10] = 8.0 + H_expected[10, i] = 8.0 + end + H_expected[10, 10] = 140.0 + + @test H ≈ H_expected + end +end