27 commits
f23b8f6  add fixes (sunxd3, Nov 2, 2025)
b1984f0  add missing imports (sunxd3, Nov 2, 2025)
021e185  add claude related files (sunxd3, Nov 9, 2025)
aebb345  update agent memory (sunxd3, Nov 9, 2025)
150721a  try fixing CI error (sunxd3, Nov 10, 2025)
1861d59  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Nov 10, 2025)
b0b69e8  add forward over reverse test (sunxd3, Nov 10, 2025)
6533eb9  add more proper tests (sunxd3, Nov 10, 2025)
727481a  enable all test cases (sunxd3, Nov 11, 2025)
b028e78  remove claude related files (sunxd3, Nov 14, 2025)
324b121  remove repeated tests (sunxd3, Nov 14, 2025)
2f6a2e5  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Nov 14, 2025)
2acded6  format (sunxd3, Nov 14, 2025)
933efc4  fix format (sunxd3, Nov 14, 2025)
7fc7007  Revert "fix format" (sunxd3, Nov 14, 2025)
4dd6362  disable second order DI tests (sunxd3, Nov 18, 2025)
ca5f6fa  us DITests Hessian test (sunxd3, Nov 19, 2025)
666fa4b  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Nov 19, 2025)
c383c05  bring back old tests (sunxd3, Nov 19, 2025)
9bb592c  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Nov 21, 2025)
961314a  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Dec 1, 2025)
8f96ece  pushing the envelope (sunxd3, Dec 1, 2025)
4f3be58  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Dec 2, 2025)
1dcd242  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Dec 3, 2025)
146457c  Merge branch 'main' into sunxd/fix_hessian_mwe (sunxd3, Dec 4, 2025)
726ce5d  add DI extension, seems DI wrap reversemode as a closure (sunxd3, Dec 4, 2025)
c5bb54b  fix and add Tim Holy's example (sunxd3, Dec 5, 2025)
3 changes: 3 additions & 0 deletions Project.toml
@@ -18,6 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[weakdeps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -33,6 +34,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
[extensions]
MooncakeAllocCheckExt = "AllocCheck"
MooncakeCUDAExt = "CUDA"
MooncakeDifferentiationInterfaceExt = "DifferentiationInterface"
MooncakeDynamicExpressionsExt = "DynamicExpressions"
MooncakeFluxExt = "Flux"
MooncakeFunctionWrappersExt = "FunctionWrappers"
@@ -51,6 +53,7 @@ CUDA = "5"
ChainRules = "1.71.0"
ChainRulesCore = "1"
DiffTests = "0.1"
DifferentiationInterface = "0.7"
DispatchDoctor = "0.4.26"
DynamicExpressions = "2"
ExprTools = "0.1"
147 changes: 147 additions & 0 deletions ext/MooncakeDifferentiationInterfaceExt.jl
@@ -0,0 +1,147 @@
module MooncakeDifferentiationInterfaceExt
Member: Let's move this file to the DI repo.
Collaborator: Until @sunxd3 explains why it is necessary, I'm not convinced this file should be anywhere.


using Mooncake:
Mooncake, @is_primitive, MinimalCtx, ForwardMode, Dual, primal, tangent, NoTangent
import DifferentiationInterface as DI

# Mark shuffled_gradient as a forward-mode primitive to avoid an expensive type-inference hang.
# This prevents build_frule from trying to derive rules for the complex gradient closure.
Member: Can you explain clearly, with a small example, why Mooncake/Julia struggle with shuffled_gradient?

@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient),Vararg}
@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient!),Vararg}

# Helper to create Dual array from primal and tangent arrays
_make_dual_array(x::AbstractArray, dx::AbstractArray) = Dual.(x, dx)
_make_dual_array(x, dx) = Dual(x, dx)

# Helper to extract primal and tangent from Dual array
_extract_primals(arr::AbstractArray{<:Dual}) = primal.(arr)
_extract_primals(d::Dual) = primal(d)
_extract_tangents(arr::AbstractArray{<:Dual}) = tangent.(arr)
_extract_tangents(d::Dual) = tangent(d)

# frule for shuffled_gradient without prep
# shuffled_gradient(x, f, backend, rewrap, contexts...) -> gradient(f, backend, x, contexts...)
function Mooncake.frule!!(
::Dual{typeof(DI.shuffled_gradient)},
x_dual::Dual,
f_dual::Dual,
backend_dual::Dual,
rewrap_dual::Dual,
context_duals::Vararg{Dual},
)
# Extract primals and tangents
x = primal(x_dual)
dx = tangent(x_dual)
f = primal(f_dual)
backend = primal(backend_dual)
rewrap = primal(rewrap_dual)
contexts = map(d -> primal(d), context_duals)

# Create Dual inputs: each element is Dual(x[i], dx[i])
# This allows the Hvp to be computed via forward-over-reverse
x_with_duals = _make_dual_array(x, dx)

# Call gradient with Dual inputs
# Since Dual{Float64,Float64} is self-tangent, reverse mode handles it correctly
grad_duals = DI.shuffled_gradient(x_with_duals, f, backend, rewrap, contexts...)

# Extract primal (gradient) and tangent (Hvp) from the Dual outputs
grad_primal = _extract_primals(grad_duals)
grad_tangent = _extract_tangents(grad_duals)

return Dual(grad_primal, grad_tangent)
end
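
# Why forward-over-reverse works here: seeding the inputs as x + eps*dx means the
# reverse-mode gradient, evaluated on Duals, returns grad_f(x + eps*dx); its tangent,
# d/d(eps) grad_f(x + eps*dx) at eps = 0, equals H(x) * dx, i.e. one Hessian-vector product.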

# frule for shuffled_gradient with prep
function Mooncake.frule!!(
::Dual{typeof(DI.shuffled_gradient)},
x_dual::Dual,
f_dual::Dual,
prep_dual::Dual,
backend_dual::Dual,
rewrap_dual::Dual,
context_duals::Vararg{Dual},
)
x = primal(x_dual)
dx = tangent(x_dual)
f = primal(f_dual)
prep = primal(prep_dual)
backend = primal(backend_dual)
rewrap = primal(rewrap_dual)
contexts = map(d -> primal(d), context_duals)

x_with_duals = _make_dual_array(x, dx)
grad_duals = DI.shuffled_gradient(x_with_duals, f, prep, backend, rewrap, contexts...)

grad_primal = _extract_primals(grad_duals)
grad_tangent = _extract_tangents(grad_duals)

return Dual(grad_primal, grad_tangent)
end

# frule for shuffled_gradient! (in-place version)
function Mooncake.frule!!(
::Dual{typeof(DI.shuffled_gradient!)},
grad_dual::Dual,
x_dual::Dual,
f_dual::Dual,
backend_dual::Dual,
rewrap_dual::Dual,
context_duals::Vararg{Dual},
)
grad = primal(grad_dual)
dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes)
x = primal(x_dual)
dx = tangent(x_dual)
f = primal(f_dual)
backend = primal(backend_dual)
rewrap = primal(rewrap_dual)
contexts = map(d -> primal(d), context_duals)

x_with_duals = _make_dual_array(x, dx)
# Allocate Dual buffer for in-place gradient
grad_duals = _make_dual_array(grad, similar(grad))
DI.shuffled_gradient!(grad_duals, x_with_duals, f, backend, rewrap, contexts...)

# Copy primal (gradient) back to grad
grad .= _extract_primals(grad_duals)
# Copy tangent (Hvp) back to dgrad
dgrad .= _extract_tangents(grad_duals)

return Dual(nothing, NoTangent())
end

# frule for shuffled_gradient! with prep
function Mooncake.frule!!(
::Dual{typeof(DI.shuffled_gradient!)},
grad_dual::Dual,
x_dual::Dual,
f_dual::Dual,
prep_dual::Dual,
backend_dual::Dual,
rewrap_dual::Dual,
context_duals::Vararg{Dual},
)
grad = primal(grad_dual)
dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes)
x = primal(x_dual)
dx = tangent(x_dual)
f = primal(f_dual)
prep = primal(prep_dual)
backend = primal(backend_dual)
rewrap = primal(rewrap_dual)
contexts = map(d -> primal(d), context_duals)

x_with_duals = _make_dual_array(x, dx)
grad_duals = _make_dual_array(grad, similar(grad))
DI.shuffled_gradient!(grad_duals, x_with_duals, f, prep, backend, rewrap, contexts...)

# Copy primal (gradient) back to grad
grad .= _extract_primals(grad_duals)
# Copy tangent (Hvp) back to dgrad
dgrad .= _extract_tangents(grad_duals)

return Dual(nothing, NoTangent())
end

end
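
For context, this is the call path these frules serve: DifferentiationInterface wraps its reverse-mode gradient as shuffled_gradient, and Mooncake's forward mode pushes Duals through it to obtain Hessian columns. A minimal usage sketch follows; the test function and the exact second-order backend spelling are illustrative assumptions, not part of this diff.

# Hypothetical sketch: Hessian via DI's forward-over-reverse with Mooncake.
using DifferentiationInterface
import Mooncake

f(x) = sum(abs2, x) + prod(x)                # illustrative objective

backend = SecondOrder(
    AutoMooncake(; config=nothing),          # outer: Mooncake forward mode (Duals)
    AutoMooncake(; config=nothing),          # inner: reverse-mode gradient
)

H = hessian(f, backend, [1.0, 2.0, 3.0])     # columns are HVPs from the frules above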
1 change: 1 addition & 0 deletions src/Mooncake.jl
@@ -151,6 +151,7 @@ else
include(joinpath("rrules", "array_legacy.jl"))
end
include(joinpath("rrules", "performance_patches.jl"))
include(joinpath("rrules", "dual_arithmetic.jl"))

include("interface.jl")
include("config.jl")
137 changes: 137 additions & 0 deletions src/dual.jl
@@ -55,3 +55,140 @@ verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)
function Dual(x::Type{P}, dx::NoTangent) where {P}
return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx)
end

# A Dual of IEEE floats is its own tangent type (self-tangent)
Member: Can you explain when tangent_type(::Type{Dual}) will be called during forward-over-reverse? It is not self-evident why this is needed.

@inline tangent_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T}

@inline zero_tangent_internal(
x::Dual{P,T}, ::MaybeCache
) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))

@inline function randn_tangent_internal(
rng::AbstractRNG, x::Dual{P,T}, ::MaybeCache
) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(randn(rng, P), randn(rng, T))
end

@inline function increment!!(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
end

@inline set_to_zero_internal!!(
::SetToZeroCache, x::Dual{P,T}
) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))

@inline function increment_internal!!(
::IncCache, x::Dual{P,T}, y::Dual{P,T}
) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
end

Base.one(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(one(P), zero(T))
function Base.one(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(one(primal(x)), zero(tangent(x)))
end

# Arithmetic operations
function Base.:+(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) + y, tangent(x))
end
function Base.:+(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x + primal(y), tangent(y))
end
function Base.:+(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) + primal(y), tangent(x) + tangent(y))
end
function Base.:+(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) + y, tangent(x))
end
function Base.:+(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x + primal(y), tangent(y))
end

# Subtraction
Base.:-(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(-primal(x), -tangent(x))
function Base.:-(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) - y, tangent(x))
end
function Base.:-(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x - primal(y), -tangent(y))
end
function Base.:-(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) - primal(y), tangent(x) - tangent(y))
end
function Base.:-(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) - y, tangent(x))
end
function Base.:-(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x - primal(y), -tangent(y))
end

# Multiplication (product rule)
function Base.:*(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) * y, tangent(x) * y)
end
function Base.:*(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x * primal(y), x * tangent(y))
end
function Base.:*(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) * primal(y), primal(x) * tangent(y) + tangent(x) * primal(y))
end
function Base.:*(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) * y, tangent(x) * y)
end
function Base.:*(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x * primal(y), x * tangent(y))
end

# Division (quotient rule)
function Base.:/(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x) / y, tangent(x) / y)
end
function Base.:/(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(x / primal(y), -x * tangent(y) / primal(y)^2)
end
function Base.:/(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(
primal(x) / primal(y),
(tangent(x) * primal(y) - primal(x) * tangent(y)) / primal(y)^2,
)
end

# Power (chain rule)
function Base.:^(x::Dual{P,T}, n::Integer) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x)^n, n * primal(x)^(n - 1) * tangent(x))
end
function Base.:^(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(primal(x)^y, y * primal(x)^(y - 1) * tangent(x))
end

# Comparison (use primal for comparisons)
Base.:<(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) < y
Base.:<(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x < primal(y)
function Base.:<(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return primal(x) < primal(y)
end
Base.:>(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) > y
Base.:>(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x > primal(y)
function Base.:>(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return primal(x) > primal(y)
end
Base.:<=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) <= y
Base.:<=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x <= primal(y)
function Base.:<=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return primal(x) <= primal(y)
end
Base.:>=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) >= y
Base.:>=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x >= primal(y)
function Base.:>=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat}
return primal(x) >= primal(y)
end

# Conversion and promotion
Base.convert(::Type{Dual{P,T}}, x::P) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(x, zero(T))
function Base.promote_rule(::Type{Dual{P,T}}, ::Type{P}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual{P,T}
end

LinearAlgebra.transpose(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x
LinearAlgebra.adjoint(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x
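
A quick worked check of the rules above (a hypothetical snippet, not part of the diff): for f(x) = x^2 * (x + 1) we have f'(x) = 3x^2 + 2x, so seeding a unit tangent at x = 2 should yield primal 12 and tangent 16.

# Exercises the power, integer-addition, and product rules defined above.
using Mooncake: Dual, primal, tangent

f(x) = x^2 * (x + 1)         # f'(x) = 3x^2 + 2x

d = f(Dual(2.0, 1.0))        # unit tangent propagates df/dx
primal(d)                    # 12.0 == f(2)
tangent(d)                   # 16.0 == f'(2)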
10 changes: 10 additions & 0 deletions src/fwds_rvs_data.jl
@@ -159,6 +159,7 @@ fdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?"))
@foldable fdata_type(::Type{Union{}}) = Union{}

fdata_type(::Type{T}) where {T<:IEEEFloat} = NoFData
fdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = NoFData

function fdata_type(::Type{PossiblyUninitTangent{T}}) where {T}
Tfields = fdata_type(T)
@@ -437,6 +438,7 @@ rdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?"))
@foldable rdata_type(::Type{Union{}}) = Union{}

rdata_type(::Type{T}) where {T<:IEEEFloat} = T
rdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T}

function rdata_type(::Type{PossiblyUninitTangent{T}}) where {T}
return PossiblyUninitTangent{rdata_type(T)}
@@ -587,6 +589,7 @@ Given value `p`, return the zero element associated to its reverse data type.
zero_rdata(p)

zero_rdata(p::IEEEFloat) = zero(p)
zero_rdata(p::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T))

@generated function zero_rdata(p::P) where {P}
Rs = rdata_field_types_exprs(P)
@@ -654,6 +657,9 @@ obtained from `P` alone.
end

@foldable can_produce_zero_rdata_from_type(::Type{<:IEEEFloat}) = true
@foldable can_produce_zero_rdata_from_type(
::Type{Dual{P,T}}
) where {P<:IEEEFloat,T<:IEEEFloat} = true

@foldable can_produce_zero_rdata_from_type(::Type{<:Type}) = true

@@ -737,6 +743,9 @@ function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple}
end

zero_rdata_from_type(::Type{P}) where {P<:IEEEFloat} = zero(P)
function zero_rdata_from_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat}
return Dual(zero(P), zero(T))
end

zero_rdata_from_type(::Type{<:Type}) = NoRData()

@@ -951,6 +960,7 @@ Reconstruct the tangent `t` for which `fdata(t) == f` and `rdata(t) == r`.
"""
tangent(::NoFData, ::NoRData) = NoTangent()
tangent(::NoFData, r::IEEEFloat) = r
tangent(::NoFData, r::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = r
tangent(f::Array, ::NoRData) = f

# Tuples
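
These fwds_rvs_data.jl additions make a Dual of IEEE floats behave like a plain float in the fdata/rdata split: it carries no forward data, and the whole tangent travels as reverse data. A small illustrative check (hypothetical, relying on Mooncake internals keeping their current names):

# Hypothetical REPL check against Mooncake internals as modified in this PR.
using Mooncake: Mooncake, Dual

Mooncake.fdata_type(Dual{Float64,Float64})   # NoFData: nothing to pass forwards
Mooncake.rdata_type(Dual{Float64,Float64})   # Dual{Float64,Float64}: tangent is pure rdata
Mooncake.zero_rdata(Dual(1.0, 2.0))          # Dual(0.0, 0.0)
Mooncake.tangent(Mooncake.NoFData(), Dual(0.0, 1.0))  # rebuilds the tangent from rdata alone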