Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
fd8f884
Fix forward-over-reverse, add tests, DI doesn't work yet
sunxd3 Dec 6, 2025
1bb341d
fix version specific issue
sunxd3 Dec 7, 2025
4869aa9
Use NativeInterpreter in frule_type to avoid maybe_primitive recursion
sunxd3 Dec 7, 2025
d6480be
add frule for _build_rule!
sunxd3 Dec 7, 2025
9159ef9
revert the nativeinterpreter change
sunxd3 Dec 7, 2025
a426235
add has_equal_data_internal for MistyClosureTangent
sunxd3 Dec 7, 2025
231a696
add rule for jl_genericmemory_owner to fix 1.11
sunxd3 Dec 7, 2025
1d28b33
add frule for jl_alloc_array_1d to fix 1.10 error
sunxd3 Dec 7, 2025
4bb3c61
aboid hardcoded UInt length
sunxd3 Dec 7, 2025
64410c6
formatting
sunxd3 Dec 7, 2025
3f8d7eb
more x86 compacy
sunxd3 Dec 8, 2025
029a777
deal with Union{} bottom type, make rule type Any for LazyFRule
sunxd3 Dec 8, 2025
03eb3c0
revert the Any type change for LazyFRule
sunxd3 Dec 8, 2025
14b6fcf
test forward over forward to see
sunxd3 Dec 8, 2025
4339183
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 8, 2025
5db1ce2
remove added DI tests
sunxd3 Dec 8, 2025
6578aa6
Update forward_over_reverse.jl
yebai Dec 9, 2025
5db16f0
Add skip_world_age_check kwarg to build_frule for MistyClosure support
sunxd3 Dec 9, 2025
09c496a
Add rrule!! for _build_rule! that throws on reverse-over-reverse
sunxd3 Dec 9, 2025
b7413d5
Remove unnecessary rules: literal_pow, push!, jl_genericmemory_owner
sunxd3 Dec 9, 2025
c2f14b0
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 12, 2025
2a9a680
Preserve primitive inlining policy in optimise_ir! for forward-over-r…
sunxd3 Dec 17, 2025
e18edfd
Restrict getfield frule!! to AbstractArray to avoid ambiguity with St…
sunxd3 Dec 17, 2025
e324759
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 17, 2025
3168a13
Fix Julia 1.12 Future assertion by using sv.interp for wrapper interp…
sunxd3 Dec 17, 2025
ed10d38
Revert interpreter forwarding changes (too hairy for now)
sunxd3 Dec 17, 2025
d79f375
Add back jl_genericmemory_owner frule/rrule for Julia 1.11+
sunxd3 Dec 17, 2025
c9109fe
Move `jl_genericmemory_owner` frule to test file.
sunxd3 Dec 17, 2025
7d258bf
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 17, 2025
f6ba807
version bump
sunxd3 Dec 17, 2025
d90977a
Merge branch 'main' into sunxd/f-o-r
yebai Dec 19, 2025
3334682
Add frules to make DI interface work.
sunxd3 Dec 29, 2025
ae7dea4
Actually enable f-o-r tests
sunxd3 Dec 29, 2025
5769127
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 29, 2025
7a78468
Move DI interface rules to avoiding_non_differentiable_code.jl
sunxd3 Dec 29, 2025
403f610
Move interface.jl include before rules
sunxd3 Dec 29, 2025
344cb39
Move jl_genericmemory_owner frule/rrule to main source for DI tests
sunxd3 Dec 29, 2025
f269afe
Move higher-order differentiation rules to high_order_derivative_patc…
sunxd3 Dec 29, 2025
edc3aca
Split DI tests into first-order and second-order CI jobs
sunxd3 Dec 30, 2025
7ed82fd
Restrict second-order DI tests to hessian only
sunxd3 Dec 30, 2025
b10b372
Move FoR tests to rules/, add world age docs for _dual_mc
sunxd3 Dec 30, 2025
c1c2be9
Merge branch 'main' into sunxd/f-o-r
sunxd3 Dec 30, 2025
d2713d4
Bump version to 0.4.193 (#913)
Copilot Dec 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
'rules/tasks',
'rules/twice_precision',
'rules/performance_patches',
'rules/high_order_derivative_patches',
]
version:
- 'lts'
Expand Down Expand Up @@ -87,6 +88,7 @@ jobs:
matrix:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'differentiation_interface_second_order'},
{test_type: 'ext', label: 'dynamic_expressions'},
{test_type: 'ext', label: 'flux'},
{test_type: 'ext', label: 'function_wrappers'},
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.192"
version = "0.4.193"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
5 changes: 3 additions & 2 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ end
include("tools_for_rules.jl")
@unstable include("test_utils.jl")
@unstable include("test_resources.jl")
include("interface.jl")

include(joinpath("rules", "avoiding_non_differentiable_code.jl"))
include(joinpath("rules", "blas.jl"))
Expand All @@ -151,10 +152,10 @@ else
include(joinpath("rules", "array_legacy.jl"))
end

# Including this in DispatchDoctor causes precompilation error.
# Including this in DispatchDoctor causes precompilation error.
@unstable include(joinpath("rules", "performance_patches.jl"))
include(joinpath("rules", "high_order_derivative_patches.jl"))

include("interface.jl")
include("config.jl")
include("developer_tools.jl")

Expand Down
46 changes: 43 additions & 3 deletions src/interpreter/forward_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
# Check if a type contains Union{} (bottom type) anywhere in its structure.
# This can happen with unreachable code or failed type inference.
@inline contains_bottom_type(T) = _contains_bottom_type(T, Base.IdSet{Any}())

function _contains_bottom_type(T, seen::Base.IdSet{Any})
T === Union{} && return true
if T isa Union
return _contains_bottom_type(T.a, seen) || _contains_bottom_type(T.b, seen)
elseif T isa TypeVar
T in seen && return false
push!(seen, T)
return _contains_bottom_type(T.ub, seen)
elseif T isa UnionAll
T in seen && return false
push!(seen, T)
return _contains_bottom_type(T.body, seen)
elseif T isa DataType
T in seen && return false
push!(seen, T)
for p in T.parameters
_contains_bottom_type(p, seen) && return true
end
return false
else
return false
end
end

function build_frule(args...; debug_mode=false, silence_debug_messages=true)
sig = _typeof(TestUtils.__get_primals(args))
interp = get_interpreter(ForwardMode)
Expand All @@ -16,19 +44,27 @@ end
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
skip_world_age_check=false,
) where {C}

Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if
`sig_or_mi` is not a primitive.

Set `skip_world_age_check=true` when the interpreter's world age is intentionally older
than the current world (e.g., when building rules for MistyClosure which uses its own world).
"""
function build_frule(
interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
interp::MooncakeInterpreter{C},
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
skip_world_age_check=false,
) where {C}
@nospecialize sig_or_mi

# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
if !skip_world_age_check && Base.get_world_counter() > interp.world
throw(
ArgumentError(
"World age associated to interp is behind current world age. Please " *
Expand Down Expand Up @@ -331,7 +367,11 @@ function modify_fwd_ad_stmts!(
if isexpr(stmt, :invoke) || isexpr(stmt, :call)
raw_args = isexpr(stmt, :invoke) ? stmt.args[2:end] : stmt.args
sig_types = map(raw_args) do x
return CC.widenconst(get_forward_primal_type(info.primal_ir, x))
t = CC.widenconst(get_forward_primal_type(info.primal_ir, x))
# Replace types containing Union{} (unreachable code/failed inference)
# with Any. This allows the code to proceed; is_primitive will return
# false and we'll use dynamic rules that resolve types at runtime.
return contains_bottom_type(t) ? Any : t
end
sig = Tuple{sig_types...}
mi = isexpr(stmt, :invoke) ? get_mi(stmt.args[1]) : missing
Expand Down
205 changes: 205 additions & 0 deletions src/rules/high_order_derivative_patches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Forward-mode primitive for _build_rule! on LazyDerivedRule.
# This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter.
# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building.
# Reverse-over-reverse is not supported; an rrule!! that throws is provided below.
@is_primitive MinimalCtx Tuple{typeof(_build_rule!),LazyDerivedRule,Tuple}

function frule!!(
::Dual{typeof(_build_rule!)},
lazy_rule_dual::Dual{<:LazyDerivedRule{sig}},
args_dual::Dual{<:Tuple},
) where {sig}
lazy_rule = primal(lazy_rule_dual)
lazy_tangent = tangent(lazy_rule_dual)
primal_args = primal(args_dual)
tangent_args = tangent(args_dual)

# Build rrule if not built (primal operation, no differentiation needed)
if !isdefined(lazy_rule, :rule)
interp = get_interpreter(ReverseMode)
lazy_rule.rule = build_rrule(interp, lazy_rule.mi; debug_mode=lazy_rule.debug_mode)
end
derived_rule = lazy_rule.rule

# Initialize the tangent of the derived rule if needed
rule_tangent_field = lazy_tangent.fields.rule
if !isdefined(rule_tangent_field, :tangent)
# Need to update the MutableTangent's fields with a new PossiblyUninitTangent
new_rule_tangent = PossiblyUninitTangent(zero_tangent(derived_rule))
lazy_tangent.fields = merge(lazy_tangent.fields, (; rule=new_rule_tangent))
rule_tangent_field = new_rule_tangent
end
derived_tangent = rule_tangent_field.tangent

# Forward-differentiate through the DerivedRule call.
# DerivedRule(args...) internally calls fwds_oc(args...) and returns (CoDual, Pullback)
fwds_oc = derived_rule.fwds_oc
fwds_oc_tangent = derived_tangent.fields.fwds_oc

# Handle varargs unflattening
isva = _isva(derived_rule)
nargs = derived_rule.nargs
N = length(primal_args)
uf_primal_args = __unflatten_codual_varargs(isva, primal_args, nargs)
uf_tangent_args = __unflatten_tangent_varargs(isva, tangent_args, nargs)

# Create dual args for frule!! call
dual_args = map(Dual, uf_primal_args, uf_tangent_args)

# Call frule!! on fwds_oc to get forward-differentiated result
dual_fwds_oc = Dual(fwds_oc, fwds_oc_tangent)
codual_result_dual = frule!!(dual_fwds_oc, dual_args...)

# Create Pullback and its tangent
pb_oc_ref = derived_rule.pb_oc_ref
pb_primal = Pullback(sig, pb_oc_ref, isva, N)
pb_tangent = Tangent((; pb_oc=derived_tangent.fields.pb_oc_ref))

# Return Dual of (CoDual, Pullback)
primal_result = (primal(codual_result_dual), pb_primal)
tangent_result = (tangent(codual_result_dual), pb_tangent)
return Dual(primal_result, tangent_result)
end

# Helper to unflatten tangent args similar to __unflatten_codual_varargs
function __unflatten_tangent_varargs(isva::Bool, tangent_args, ::Val{nargs}) where {nargs}
isva || return tangent_args
group_tangent = tangent_args[nargs:end]
return (tangent_args[1:(nargs - 1)]..., group_tangent)
end

# Reverse-over-reverse is not supported. Throw an informative error.
function rrule!!(
::CoDual{typeof(_build_rule!)}, ::CoDual{<:LazyDerivedRule}, ::CoDual{<:Tuple}
)
throw(
ArgumentError(
"Reverse-over-reverse differentiation is not supported. " *
"Encountered attempt to differentiate _build_rule! in reverse mode.",
),
)
end

# TODO: This is a workaround for forward-over-reverse. Primitives in reverse mode can get
# inlined when building the forward rule, exposing internal ccalls that lack an frule!!.
# For example, `dataids` is a reverse-mode primitive, but inlining it exposes
# `jl_genericmemory_owner`. The proper fix is to prevent primitive inlining during
# forward-over-reverse by forwarding `inlining_policy` through `BugPatchInterpreter` to
# `MooncakeInterpreter` during `optimise_ir!`, but this causes allocation regressions.
# See https://github.com/chalk-lab/Mooncake.jl/pull/878 for details.
@static if VERSION >= v"1.11-"
function frule!!(
::Dual{typeof(_foreigncall_)},
::Dual{Val{:jl_genericmemory_owner}},
::Dual{Val{Any}},
::Dual{Tuple{Val{Any}}},
::Dual{Val{0}},
::Dual{Val{:ccall}},
a::Dual{<:Memory},
)
return zero_dual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a)))
end
function rrule!!(
::CoDual{typeof(_foreigncall_)},
::CoDual{Val{:jl_genericmemory_owner}},
::CoDual{Val{Any}},
::CoDual{Tuple{Val{Any}}},
::CoDual{Val{0}},
::CoDual{Val{:ccall}},
a::CoDual{<:Memory},
)
y = zero_fcodual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a)))
return y, NoPullback(ntuple(_ -> NoRData(), 7))
end
end

# Avoid differentiating through AD infrastructure during second-order differentiation.
@zero_derivative MinimalCtx Tuple{
typeof(Core.kwcall),NamedTuple,typeof(prepare_gradient_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{
typeof(Core.kwcall),NamedTuple,typeof(prepare_derivative_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{
typeof(Core.kwcall),NamedTuple,typeof(prepare_pullback_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{typeof(zero_tangent),Any}

@static if VERSION < v"1.11-"
@generated function frule!!(
::Dual{typeof(_foreigncall_)},
::Dual{Val{:jl_alloc_array_1d}},
::Dual{Val{Vector{P}}},
::Dual{Tuple{Val{Any},Val{Int}}},
::Dual{Val{0}},
::Dual{Val{:ccall}},
::Dual{Type{Vector{P}}},
n::Dual{Int},
args::Vararg{Dual},
) where {P}
T = tangent_type(P)
return quote
_n = primal(n)
y = ccall(:jl_alloc_array_1d, Vector{$P}, (Any, Int), Vector{$P}, _n)
dy = ccall(:jl_alloc_array_1d, Vector{$T}, (Any, Int), Vector{$T}, _n)
return Dual(y, dy)
end
end
@generated function frule!!(
::Dual{typeof(_foreigncall_)},
::Dual{Val{:jl_alloc_array_2d}},
::Dual{Val{Matrix{P}}},
::Dual{Tuple{Val{Any},Val{Int},Val{Int}}},
::Dual{Val{0}},
::Dual{Val{:ccall}},
::Dual{Type{Matrix{P}}},
m::Dual{Int},
n::Dual{Int},
args::Vararg{Dual},
) where {P}
T = tangent_type(P)
return quote
_m, _n = primal(m), primal(n)
y = ccall(:jl_alloc_array_2d, Matrix{$P}, (Any, Int, Int), Matrix{$P}, _m, _n)
dy = ccall(:jl_alloc_array_2d, Matrix{$T}, (Any, Int, Int), Matrix{$T}, _m, _n)
return Dual(y, dy)
end
end
@generated function frule!!(
::Dual{typeof(_foreigncall_)},
::Dual{Val{:jl_alloc_array_3d}},
::Dual{Val{Array{P,3}}},
::Dual{Tuple{Val{Any},Val{Int},Val{Int},Val{Int}}},
::Dual{Val{0}},
::Dual{Val{:ccall}},
::Dual{Type{Array{P,3}}},
l::Dual{Int},
m::Dual{Int},
n::Dual{Int},
args::Vararg{Dual},
) where {P}
T = tangent_type(P)
return quote
_l, _m, _n = primal(l), primal(m), primal(n)
y = ccall(
:jl_alloc_array_3d,
Array{$P,3},
(Any, Int, Int, Int),
Array{$P,3},
_l,
_m,
_n,
)
dy = ccall(
:jl_alloc_array_3d,
Array{$T,3},
(Any, Int, Int, Int),
Array{$T,3},
_l,
_m,
_n,
)
return Dual(y, dy)
end
end
end
1 change: 1 addition & 0 deletions src/rules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
@from_chainrules MinimalCtx Tuple{typeof(deg2rad),IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(rad2deg),IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat}

@from_chainrules MinimalCtx Tuple{typeof(atan),P,P} where {P<:IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(max),P,P} where {P<:IEEEFloat}
@from_chainrules MinimalCtx Tuple{typeof(min),P,P} where {P<:IEEEFloat}
Expand Down
29 changes: 29 additions & 0 deletions src/rules/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,24 @@ end

# _new_ and _new_-adjacent rules for Memory, MemoryRef, and Array.

@static if VERSION >= v"1.12-"
@is_primitive MinimalCtx Tuple{typeof(Core.memorynew),Type{<:Memory},Int}
function frule!!(
::Dual{typeof(Core.memorynew)}, ::Dual{Type{Memory{P}}}, n::Dual{Int}
) where {P}
x = Core.memorynew(Memory{P}, primal(n))
dx = Core.memorynew(Memory{tangent_type(P)}, primal(n))
return Dual(x, dx)
end
function rrule!!(
::CoDual{typeof(Core.memorynew)}, ::CoDual{Type{Memory{P}}}, n::CoDual{Int}
) where {P}
x = Core.memorynew(Memory{P}, primal(n))
dx = Core.memorynew(Memory{tangent_type(P)}, primal(n))
return CoDual(x, dx), NoPullback((NoRData(), NoRData(), NoRData()))
end
end

@is_primitive MinimalCtx Tuple{Type{<:Memory},UndefInitializer,Int}
function frule!!(::Dual{Type{Memory{P}}}, ::Dual{UndefInitializer}, n::Dual{Int}) where {P}
x = Memory{P}(undef, primal(n))
Expand Down Expand Up @@ -908,6 +926,17 @@ function hand_written_rule_test_cases(rng_ctor, ::Val{:memory})
zip(mem_refs, sample_mem_ref_values),
)
test_cases = vcat(
@static(
if VERSION >= v"1.12-"
[
(true, :stability, nothing, Core.memorynew, Memory{Float64}, 5),
(true, :stability, nothing, Core.memorynew, Memory{Float64}, 10),
(true, :stability, nothing, Core.memorynew, Memory{Int}, 5),
]
else
[]
end
),

# Rules for `Memory`
(true, :stability, nothing, Memory{Float64}, undef, 5),
Expand Down
Loading
Loading