-
Notifications
You must be signed in to change notification settings - Fork 29
Fix Hessian #833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Hessian #833
Changes from all commits
f23b8f6
b1984f0
021e185
aebb345
150721a
1861d59
b0b69e8
6533eb9
727481a
b028e78
324b121
2f6a2e5
2acded6
933efc4
7fc7007
4dd6362
ca5f6fa
666fa4b
c383c05
9bb592c
961314a
8f96ece
4f3be58
1dcd242
146457c
726ce5d
c5bb54b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| module MooncakeDifferentiationInterfaceExt | ||
|
|
||
| using Mooncake: | ||
| Mooncake, @is_primitive, MinimalCtx, ForwardMode, Dual, primal, tangent, NoTangent | ||
| import DifferentiationInterface as DI | ||
|
|
||
| # Mark shuffled_gradient as forward-mode primitive to avoid expensive type inference hang. | ||
| # This prevents build_frule from trying to derive rules for the complex gradient closure. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain clearly with a small example, why Mooncake/Julia struggle with |
||
| @is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient),Vararg} | ||
| @is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient!),Vararg} | ||
|
|
||
| # Helper to create Dual array from primal and tangent arrays | ||
| _make_dual_array(x::AbstractArray, dx::AbstractArray) = Dual.(x, dx) | ||
| _make_dual_array(x, dx) = Dual(x, dx) | ||
|
|
||
| # Helper to extract primal and tangent from Dual array | ||
| _extract_primals(arr::AbstractArray{<:Dual}) = primal.(arr) | ||
| _extract_primals(d::Dual) = primal(d) | ||
| _extract_tangents(arr::AbstractArray{<:Dual}) = tangent.(arr) | ||
| _extract_tangents(d::Dual) = tangent(d) | ||
|
|
||
| # frule for shuffled_gradient without prep | ||
| # shuffled_gradient(x, f, backend, rewrap, contexts...) -> gradient(f, backend, x, contexts...) | ||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(DI.shuffled_gradient)}, | ||
| x_dual::Dual, | ||
| f_dual::Dual, | ||
| backend_dual::Dual, | ||
| rewrap_dual::Dual, | ||
| context_duals::Vararg{Dual}, | ||
| ) | ||
| # Extract primals and tangents | ||
| x = primal(x_dual) | ||
| dx = tangent(x_dual) | ||
| f = primal(f_dual) | ||
| backend = primal(backend_dual) | ||
| rewrap = primal(rewrap_dual) | ||
| contexts = map(d -> primal(d), context_duals) | ||
|
|
||
| # Create Dual inputs: each element is Dual(x[i], dx[i]) | ||
| # This allows the Hvp to be computed via forward-over-reverse | ||
| x_with_duals = _make_dual_array(x, dx) | ||
|
|
||
| # Call gradient with Dual inputs | ||
| # Since Dual{Float64,Float64} is self-tangent, reverse mode handles it correctly | ||
| grad_duals = DI.shuffled_gradient(x_with_duals, f, backend, rewrap, contexts...) | ||
|
|
||
| # Extract primal (gradient) and tangent (Hvp) from the Dual outputs | ||
| grad_primal = _extract_primals(grad_duals) | ||
| grad_tangent = _extract_tangents(grad_duals) | ||
|
|
||
| return Dual(grad_primal, grad_tangent) | ||
| end | ||
|
|
||
| # frule for shuffled_gradient with prep | ||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(DI.shuffled_gradient)}, | ||
| x_dual::Dual, | ||
| f_dual::Dual, | ||
| prep_dual::Dual, | ||
| backend_dual::Dual, | ||
| rewrap_dual::Dual, | ||
| context_duals::Vararg{Dual}, | ||
| ) | ||
| x = primal(x_dual) | ||
| dx = tangent(x_dual) | ||
| f = primal(f_dual) | ||
| prep = primal(prep_dual) | ||
| backend = primal(backend_dual) | ||
| rewrap = primal(rewrap_dual) | ||
| contexts = map(d -> primal(d), context_duals) | ||
|
|
||
| x_with_duals = _make_dual_array(x, dx) | ||
| grad_duals = DI.shuffled_gradient(x_with_duals, f, prep, backend, rewrap, contexts...) | ||
|
|
||
| grad_primal = _extract_primals(grad_duals) | ||
| grad_tangent = _extract_tangents(grad_duals) | ||
|
|
||
| return Dual(grad_primal, grad_tangent) | ||
| end | ||
|
|
||
| # frule for shuffled_gradient! (in-place version) | ||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(DI.shuffled_gradient!)}, | ||
| grad_dual::Dual, | ||
| x_dual::Dual, | ||
| f_dual::Dual, | ||
| backend_dual::Dual, | ||
| rewrap_dual::Dual, | ||
| context_duals::Vararg{Dual}, | ||
| ) | ||
| grad = primal(grad_dual) | ||
| dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes) | ||
| x = primal(x_dual) | ||
| dx = tangent(x_dual) | ||
| f = primal(f_dual) | ||
| backend = primal(backend_dual) | ||
| rewrap = primal(rewrap_dual) | ||
| contexts = map(d -> primal(d), context_duals) | ||
|
|
||
| x_with_duals = _make_dual_array(x, dx) | ||
| # Allocate Dual buffer for in-place gradient | ||
| grad_duals = _make_dual_array(grad, similar(grad)) | ||
| DI.shuffled_gradient!(grad_duals, x_with_duals, f, backend, rewrap, contexts...) | ||
|
|
||
| # Copy primal (gradient) back to grad | ||
| grad .= _extract_primals(grad_duals) | ||
| # Copy tangent (Hvp) back to dgrad | ||
| dgrad .= _extract_tangents(grad_duals) | ||
|
|
||
| return Dual(nothing, NoTangent()) | ||
| end | ||
|
|
||
| # frule for shuffled_gradient! with prep | ||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(DI.shuffled_gradient!)}, | ||
| grad_dual::Dual, | ||
| x_dual::Dual, | ||
| f_dual::Dual, | ||
| prep_dual::Dual, | ||
| backend_dual::Dual, | ||
| rewrap_dual::Dual, | ||
| context_duals::Vararg{Dual}, | ||
| ) | ||
| grad = primal(grad_dual) | ||
| dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes) | ||
| x = primal(x_dual) | ||
| dx = tangent(x_dual) | ||
| f = primal(f_dual) | ||
| prep = primal(prep_dual) | ||
| backend = primal(backend_dual) | ||
| rewrap = primal(rewrap_dual) | ||
| contexts = map(d -> primal(d), context_duals) | ||
|
|
||
| x_with_duals = _make_dual_array(x, dx) | ||
| grad_duals = _make_dual_array(grad, similar(grad)) | ||
| DI.shuffled_gradient!(grad_duals, x_with_duals, f, prep, backend, rewrap, contexts...) | ||
|
|
||
| # Copy primal (gradient) back to grad | ||
| grad .= _extract_primals(grad_duals) | ||
| # Copy tangent (Hvp) back to dgrad | ||
| dgrad .= _extract_tangents(grad_duals) | ||
|
|
||
| return Dual(nothing, NoTangent()) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,3 +55,140 @@ verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x) | |
| function Dual(x::Type{P}, dx::NoTangent) where {P} | ||
| return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx) | ||
| end | ||
|
|
||
| # Dual of numeric types is self-tangent | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain when |
||
| @inline tangent_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T} | ||
|
|
||
| @inline zero_tangent_internal( | ||
| x::Dual{P,T}, ::MaybeCache | ||
| ) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T)) | ||
|
|
||
| @inline function randn_tangent_internal( | ||
| rng::AbstractRNG, x::Dual{P,T}, ::MaybeCache | ||
| ) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(randn(rng, P), randn(rng, T)) | ||
| end | ||
|
|
||
| @inline function increment!!(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) | ||
| end | ||
|
|
||
| @inline set_to_zero_internal!!( | ||
| ::SetToZeroCache, x::Dual{P,T} | ||
| ) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T)) | ||
|
|
||
| @inline function increment_internal!!( | ||
| ::IncCache, x::Dual{P,T}, y::Dual{P,T} | ||
| ) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) | ||
| end | ||
|
|
||
| Base.one(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(one(P), zero(T)) | ||
| function Base.one(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(one(primal(x)), zero(tangent(x))) | ||
| end | ||
|
|
||
| # Arithmetic operations | ||
| function Base.:+(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) + y, tangent(x)) | ||
| end | ||
| function Base.:+(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x + primal(y), tangent(y)) | ||
| end | ||
| function Base.:+(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) | ||
| end | ||
| function Base.:+(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) + y, tangent(x)) | ||
| end | ||
| function Base.:+(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x + primal(y), tangent(y)) | ||
| end | ||
|
|
||
| # Subtraction | ||
| Base.:-(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(-primal(x), -tangent(x)) | ||
| function Base.:-(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) - y, tangent(x)) | ||
| end | ||
| function Base.:-(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x - primal(y), -tangent(y)) | ||
| end | ||
| function Base.:-(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) - primal(y), tangent(x) - tangent(y)) | ||
| end | ||
| function Base.:-(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) - y, tangent(x)) | ||
| end | ||
| function Base.:-(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x - primal(y), -tangent(y)) | ||
| end | ||
|
|
||
| # Multiplication (product rule) | ||
| function Base.:*(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) * y, tangent(x) * y) | ||
| end | ||
| function Base.:*(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x * primal(y), x * tangent(y)) | ||
| end | ||
| function Base.:*(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) * primal(y), primal(x) * tangent(y) + tangent(x) * primal(y)) | ||
| end | ||
| function Base.:*(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) * y, tangent(x) * y) | ||
| end | ||
| function Base.:*(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x * primal(y), x * tangent(y)) | ||
| end | ||
|
|
||
| # Division (quotient rule) | ||
| function Base.:/(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x) / y, tangent(x) / y) | ||
| end | ||
| function Base.:/(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(x / primal(y), -x * tangent(y) / primal(y)^2) | ||
| end | ||
| function Base.:/(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual( | ||
| primal(x) / primal(y), | ||
| (tangent(x) * primal(y) - primal(x) * tangent(y)) / primal(y)^2, | ||
| ) | ||
| end | ||
|
|
||
| # Power (chain rule) | ||
| function Base.:^(x::Dual{P,T}, n::Integer) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x)^n, n * primal(x)^(n - 1) * tangent(x)) | ||
| end | ||
| function Base.:^(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual(primal(x)^y, y * primal(x)^(y - 1) * tangent(x)) | ||
| end | ||
|
|
||
| # Comparison (use primal for comparisons) | ||
| Base.:<(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) < y | ||
| Base.:<(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x < primal(y) | ||
| function Base.:<(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return primal(x) < primal(y) | ||
| end | ||
| Base.:>(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) > y | ||
| Base.:>(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x > primal(y) | ||
| function Base.:>(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return primal(x) > primal(y) | ||
| end | ||
| Base.:<=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) <= y | ||
| Base.:<=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x <= primal(y) | ||
| function Base.:<=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return primal(x) <= primal(y) | ||
| end | ||
| Base.:>=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) >= y | ||
| Base.:>=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x >= primal(y) | ||
| function Base.:>=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return primal(x) >= primal(y) | ||
| end | ||
|
|
||
| # Conversion and promotion | ||
| Base.convert(::Type{Dual{P,T}}, x::P) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(x, zero(T)) | ||
| function Base.promote_rule(::Type{Dual{P,T}}, ::Type{P}) where {P<:IEEEFloat,T<:IEEEFloat} | ||
| return Dual{P,T} | ||
| end | ||
|
|
||
| LinearAlgebra.transpose(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x | ||
| LinearAlgebra.adjoint(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this file to the DI repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Until @sunxd3 explains why it is necessary, I'm not convinced this file should be anywhere.