From f23b8f620c2861e471e68783ecbf0de78c388d4f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 2 Nov 2025 19:36:19 +0000 Subject: [PATCH 01/19] add fixes --- src/interface.jl | 14 +++++ src/interpreter/forward_mode.jl | 59 +++++++++++++++++-- .../avoiding_non_differentiable_code.jl | 11 ++++ .../differentiation_interface.jl | 13 ++++ 4 files changed, 93 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d2875148d6..98fbbd7b9d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -194,6 +194,14 @@ struct Cache{Trule,Ty_cache,Ttangents<:Tuple} tangents::Ttangents end +tangent_type(::Type{<:Cache}) = NoTangent + +@inline zero_tangent(x::Cache) = NoTangent() + +@inline zero_tangent_internal(::Cache, ::MaybeCache) = NoTangent() + +@inline randn_tangent_internal(::AbstractRNG, ::Cache, ::MaybeCache) = NoTangent() + """ __exclude_unsupported_output(y) __exclude_func_with_unsupported_output(fx) @@ -571,3 +579,9 @@ derivative of `primal(f)` at the primal values in `x` in the direction of the ta in `f` and `x`. """ value_and_derivative!!(rule::R, fx::Vararg{Dual,N}) where {R,N} = rule(fx...) + +# Avoid differentiating cache constructors in forward mode to prevent +# forward-over-reverse from descending into interpreter/caches. +@zero_derivative MinimalCtx Tuple{typeof(prepare_pullback_cache),Vararg} ForwardMode +@zero_derivative MinimalCtx Tuple{typeof(prepare_gradient_cache),Vararg} ForwardMode +@zero_derivative MinimalCtx Tuple{typeof(prepare_derivative_cache),Vararg} ForwardMode diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index 6d0b32361c..eabe0d6be4 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -183,6 +183,38 @@ end const ATTACH_AFTER = true const ATTACH_BEFORE = false +@inline contains_bottom_type(T) = _contains_bottom_type(T, Base.IdSet{Any}()) + +function _contains_bottom_type(T, seen::Base.IdSet{Any}) + T === Union{} && return true + if T isa Union + return _contains_bottom_type(T.a, seen) || _contains_bottom_type(T.b, seen) + elseif T isa TypeVar + if T in seen + return false + end + push!(seen, T) + return _contains_bottom_type(T.ub, seen) + elseif T isa UnionAll + if T in seen + return false + end + push!(seen, T) + return _contains_bottom_type(T.body, seen) + elseif T isa DataType + if T in seen + return false + end + push!(seen, T) + for p in T.parameters + _contains_bottom_type(p, seen) && return true + end + return false + else + return false + end +end + modify_fwd_ad_stmts!(::Nothing, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing modify_fwd_ad_stmts!(::GotoNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing @@ -264,6 +296,8 @@ end function modify_fwd_ad_stmts!( stmt::UpsilonNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo ) + # In some compiler-generated UpsilonNodes the `val` field can be undefined; skip safely. 
+ isdefined(stmt, :val) || return nothing if !(stmt.val isa Union{Argument,SSAValue}) stmt = UpsilonNode(uninit_dual(get_const_primal_value(stmt.val))) end @@ -287,8 +321,10 @@ end @static if isdefined(Core, :EnterNode) function modify_fwd_ad_stmts!( - ::Core.EnterNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo + ::Core.EnterNode, dual_ir::IRCode, ssa::SSAValue, ::Vector{Any}, ::DualInfo ) + # Drop typed exception-enter nodes from dual IR to avoid optimiser assertions + replace_call!(dual_ir, ssa, nothing) return nothing end end @@ -305,6 +341,15 @@ function modify_fwd_ad_stmts!( sig_types = map(raw_args) do x return CC.widenconst(get_forward_primal_type(info.primal_ir, x)) end + if any(contains_bottom_type, sig_types) + sig_strings = join(map(x -> sprint(show, x), sig_types), ", ") + raw_strings = join(map(x -> sprint(show, x), raw_args), ", ") + @debug "forward-mode bottom argument" sig_types = sig_strings raw_args = + raw_strings stmt + filtered = [pair for pair in zip(sig_types, raw_args) if pair[1] !== Union{}] + sig_types = map(first, filtered) + raw_args = map(last, filtered) + end sig = Tuple{sig_types...} mi = isexpr(stmt, :invoke) ? stmt.args[1] : missing args = map(__inc, raw_args) @@ -360,11 +405,17 @@ function modify_fwd_ad_stmts!( new_undef_inst = new_inst(Expr(:throw_undef_if_not, stmt.args[1], ssa)) CC.insert_node!(dual_ir, ssa, new_undef_inst, ATTACH_AFTER) elseif isexpr(stmt, :enter) - # Leave this node alone + # Drop exception-handling scaffolding from the dual IR. + replace_call!(dual_ir, ssa, nothing) elseif isexpr(stmt, :leave) - # Leave this node alone + replace_call!(dual_ir, ssa, nothing) elseif isexpr(stmt, :pop_exception) - # Leave this node alone + replace_call!(dual_ir, ssa, nothing) + elseif isexpr(stmt, :the_exception) + # Preserve the primal exception object but give it a zero tangent. + inst = CC.NewInstruction(get_ir(info.primal_ir, ssa)) + ex_ssa = CC.insert_node!(dual_ir, ssa, inst, ATTACH_BEFORE) + replace_call!(dual_ir, ssa, Expr(:call, zero_dual, ex_ssa)) else msg = "Expressions of type `:$(stmt.head)` are not yet supported in forward mode" throw(ArgumentError(msg)) diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index 5852f0b7bd..38ec8c8a9a 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -88,6 +88,17 @@ import Base.CoreLogging as CoreLogging } ) +# +# Avoid differentiating Mooncake's rule construction in forward mode +# This prevents forward-over-reverse from descending into kw-wrapper exceptions and caches. 
+# +@zero_derivative MinimalCtx Tuple{typeof(build_rrule),Vararg} ForwardMode +@zero_derivative MinimalCtx Tuple{typeof(Core.kwcall),NamedTuple,typeof(build_rrule),Vararg} ForwardMode + +# Avoid differentiating tangent and cache constructors in forward mode +@zero_derivative MinimalCtx Tuple{typeof(zero_tangent),Vararg} ForwardMode +@zero_derivative MinimalCtx Tuple{typeof(zero_tangent_internal),Vararg} ForwardMode + function hand_written_rule_test_cases(rng_ctor, ::Val{:avoiding_non_differentiable_code}) _x = Ref(5.0) _dx = Ref(4.0) diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index ea74e84b6d..dd0f9e986e 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -10,3 +10,16 @@ test_differentiation( excluded=SECOND_ORDER, logging=true, ) + +# Explicit second-order sanity tests for Mooncake forward-over-reverse +@testset "Mooncake second-order examples" begin + backend = SecondOrder(AutoMooncakeForward(), AutoMooncake()) + + # Sum: Hessian is zero + @test DI.hessian(sum, backend, [2.0]) == [0.0] + + # Rosenbrock 2D at [1.2, 1.2] + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + H = DI.hessian(rosen, backend, [1.2, 1.2]) + @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) +end From b1984f00d76bca2031c8b24fe8b3eb28ff0313cf Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 2 Nov 2025 20:01:02 +0000 Subject: [PATCH 02/19] add missing imports --- .../ext/differentiation_interface/differentiation_interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index dd0f9e986e..19ad454bc2 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -1,8 +1,9 @@ -using Pkg +using Pkg, Test Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterface as DI using Mooncake: Mooncake test_differentiation( From 021e18570c5d9de31005eeac4443d5c514ca325c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 9 Nov 2025 11:44:30 +0000 Subject: [PATCH 03/19] add claude related files --- .claude/settings.local.json | 3 + CLAUDE.md | 7 + MOONCAKE_CODEBASE_MAP.md | 1003 +++++++++++++++++++++++++++++++++++ 3 files changed, 1013 insertions(+) create mode 100644 .claude/settings.local.json create mode 100644 CLAUDE.md create mode 100644 MOONCAKE_CODEBASE_MAP.md diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000000..165ebd5c83 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,3 @@ +{ + "model": "opus" +} diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..af7eb10b38 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,7 @@ +read `MOONCAKE_CODEBASE_MAP.md` to get an idea what are you working with + +read `test/runtests.jl` to understand how to run test before you run any test + +it is important that the code you write conform to the existing code in this repo + +the test infra of Mooncake is strict, and well structured, when you add tests, you must use existing infra as much as you can \ No newline at end of file diff --git a/MOONCAKE_CODEBASE_MAP.md b/MOONCAKE_CODEBASE_MAP.md new 
file mode 100644 index 0000000000..458770cbc3 --- /dev/null +++ b/MOONCAKE_CODEBASE_MAP.md @@ -0,0 +1,1003 @@ +# Mooncake.jl Codebase - Comprehensive Conceptual Map + +> **Documentation**: See docs/src/index.md for project overview, docs/src/tutorial.md for usage examples + +## Quick Start Guide + +### **First-Time Readers: Start Here** + +If you're new to the codebase, read in this order: +1. **docs/src/understanding_mooncake/introduction.md** - Prerequisites and what AD is +2. **docs/src/understanding_mooncake/algorithmic_differentiation.md** - Math foundations (skim if familiar with AD) +3. **docs/src/understanding_mooncake/rule_system.md** - Core rule interface (critical!) +4. **src/tangents.jl:1-100** - Skim to understand tangent types +5. **src/fwds_rvs_data.jl:1-100** - Skim to understand fdata/rdata split +6. **Pick a task below** and follow the relevant reading path + +### **Reading Paths by Goal** + +#### **Goal: Use Mooncake.jl** +1. docs/src/tutorial.md - DifferentiationInterface.jl usage +2. docs/src/interface.md - Native API if DI.jl insufficient +3. docs/src/known_limitations.md - What doesn't work + +#### **Goal: Debug an AD Issue** +1. docs/src/utilities/debugging_and_mwes.md - How to create MWEs +2. docs/src/utilities/debug_mode.md - Enable type checking +3. src/developer_tools.jl - Inspect generated IR +4. docs/src/developer_documentation/developer_tools.md - Using IR tools + +#### **Goal: Add a Primitive Rule** +1. docs/src/utilities/defining_rules.md - Complete guide +2. docs/src/understanding_mooncake/rule_system.md - Rule interface +3. Look at similar rule in src/rrules/ for pattern +4. src/test_utils.jl - Use `test_rule` for testing + +#### **Goal: Support a Custom Type** +1. docs/src/developer_documentation/custom_tangent_type.md - Complete guide +2. src/tangents.jl:302-496 - See default tangent_type implementation +3. ext/MooncakeDynamicExpressionsExt.jl - Complex real-world example +4. src/test_utils.jl - Use `test_data` for verification + +#### **Goal: Understand How AD Works Internally** +1. docs/src/understanding_mooncake/algorithmic_differentiation.md - Math +2. docs/src/developer_documentation/reverse_mode_design.md - Compilation overview +3. docs/src/developer_documentation/ir_representation.md - IR basics +4. src/interpreter/reverse_mode.jl:1152-1196 - Read `generate_ir` +5. src/interpreter/reverse_mode.jl:394-861 - Read `make_ad_stmts!` for one statement type + +#### **Goal: Fix a Performance Issue** +1. Profile to identify bottleneck +2. src/developer_tools.jl - Inspect generated IR +3. src/rrules/performance_patches.jl - See example performance rule +4. 
Consider adding primitive to DefaultCtx (not MinimalCtx) + +## Glossary of Terms + +| Term | Meaning | Location | +|------|---------|----------| +| **Primal** | Original computation being differentiated | Throughout | +| **Tangent** | Derivative information; input/output of `D f[x]` | src/tangents.jl:1-1426 | +| **Cotangent** | Adjoint/gradient information; input/output of `D f[x]*` | Throughout | +| **CoDual** | Pairs primal with fdata for reverse-mode | src/codual.jl:1-124 | +| **Dual** | Pairs primal with tangent for forward-mode | src/dual.jl:1-58 | +| **FData** | Forward data - mutable components of tangent | src/fwds_rvs_data.jl | +| **RData** | Reverse data - immutable components of tangent | src/fwds_rvs_data.jl | +| **Rule** | Function that computes AD (hand-written or derived) | Throughout | +| **Primitive** | Function with hand-written rule | src/interpreter/contexts.jl | +| **Derived Rule** | Auto-generated rule from IR transformation | src/interpreter/reverse_mode.jl | +| **Pullback** | Reverse-pass function that propagates gradients | Throughout | +| **IRCode** | Julia's SSA intermediate representation | Core.Compiler | +| **BBCode** | Mooncake's basic-block IR representation | src/interpreter/bbcode.jl | +| **SSA** | Static Single Assignment - each variable assigned once | Throughout | +| **PhiNode** | Merges values from different control flow paths | IR nodes | +| **Block Stack** | Tracks which blocks were visited (reverse-mode) | src/interpreter/reverse_mode.jl:81 | +| **Activity Analysis** | Determining what's differentiable vs constant | Implicit throughout | +| **SROA** | Scalar Replacement of Aggregates - compiler optimization | Julia compiler | + +## Data Flow Overview + +### **Reverse-Mode: value_and_gradient!! Call Chain** + +``` +User calls: value_and_gradient!!(rule, f, x...) + ↓ +interface.jl:169 → __value_and_gradient!!(rule, CoDual(f, tangent(f)), CoDual(x, tangent(x))...) + ↓ +interface.jl:104 → rule(map(to_fwds, coduals)...) # to_fwds extracts fdata + ↓ +FORWARD PASS (generated by build_rrule): + DerivedRule.fwds_oc(args...) runs, executing: + - Extract shared data from captures + - Create rdata Refs (optimized away by SROA) + - Run transformed statements (rrule!! calls) + - Push to block stack when needed + - Push intermediate values to communication stacks + - Return: CoDual(result, fdata), pullback + ↓ +interface.jl:110 → pullback(one(result)) # or pullback(rdata(ȳ)) + ↓ +REVERSE PASS (generated by build_rrule): + Pullback.pb_oc(dy) runs, executing: + - Pop communication stacks + - Pop block stack to determine control flow + - Run transformed statements in reverse + - Increment rdata Refs + - Return: tuple of rdata for all arguments + ↓ +interface.jl:110 → tangent(fdata, rdata) for each argument + ↓ +Return: (value, (NoTangent(), gradient_x1, gradient_x2, ...)) +``` + +### **Where Time is Spent (Performance Model)** + +1. **First call**: Rule compilation dominates (~90%+ of time) + - `build_rrule` generates and optimizes IR + - Stored in `interp.oc_cache` for reuse + +2. **Subsequent calls**: Execution time depends on: + - **Primitives** (~fast) - Hand-written rules are typically well-optimized + - **Block stack ops** (~10-30% overhead) - Reduced via unique predecessor optimization + - **Communication stacks** (~5-15% overhead) - Reduced via SingletonStack optimization + - **Memory operations** (~fast) - SROA eliminates most Ref allocations + +3. 
**Slowest operations** (in derived rules): + - Dynamic dispatch (`DynamicDerivedRule`) + - Type-unstable code + - Lots of small functions without primitives + - Large loops with value-dependent control flow + +## Critical Files (Read These First) + +### **Core Type System (Essential)** +1. **src/tangents.jl** - Tangent type system (~1400 lines, skim structure) + - Lines 1-100: Type definitions + - Lines 302-496: `tangent_type` implementation + - Lines 508-605: `zero_tangent_internal` implementation + +2. **src/fwds_rvs_data.jl** - FData/RData split (~1000 lines, skim) + - Lines 1-100: FData/RData type definitions + - Lines 155-202: `fdata_type` implementation + - Lines 433-476: `rdata_type` implementation + +3. **src/codual.jl** - CoDual type (~120 lines, read fully) + +### **Rule System (Essential)** +4. **src/interface.jl** - Public API (read lines 1-200 fully) + - Understanding `value_and_gradient!!` flow is critical + +5. **src/interpreter/contexts.jl** - What makes something primitive (~120 lines, read fully) + +### **Rule Derivation (Important)** +6. **src/interpreter/reverse_mode.jl** - Read selectively: + - Lines 1-200: Data structures (SharedDataPairs, ADInfo, etc.) + - Lines 394-500: `make_ad_stmts!` for one example (e.g., ReturnNode) + - Lines 1044-1196: `build_rrule` and `generate_ir` flow + +7. **src/interpreter/bbcode.jl** - BBCode representation (skim if working on IR) + +### **Example Rules (Learn by Example)** +8. **src/rrules/low_level_maths.jl** - Simple rules (read lines 1-200) +9. **src/rrules/blas.jl** - Complex rules (read one example, e.g., gemm!) +10. **src/rrules/builtins.jl** - Essential primitives (skim IntrinsicsWrappers module) + +## 1. **Core Architecture** + +### **Main Entry Point** +- **src/Mooncake.jl:1-172** - Module definition, includes all submodules, defines core functions (`frule!!`, `rrule!!`, `build_primitive_rrule`) +- **Documentation**: docs/src/index.md (getting started), docs/src/tutorial.md (DifferentiationInterface.jl usage) + +### **Type System Hierarchy** + +#### **Primal-Tangent Pairing** +- **src/dual.jl:1-58** - `Dual{P,T}` for forward-mode (primal + tangent) +- **src/codual.jl:1-124** - `CoDual{Tx,Tdx}` for reverse-mode (primal + fdata) + +#### **Tangent Types** +- **src/tangents.jl:1-1426** - Core tangent system: + - `NoTangent` - for non-differentiable types + - `Tangent{NamedTuple}` - for immutable structs + - `MutableTangent{NamedTuple}` - for mutable structs + - `PossiblyUninitTangent{T}` - for potentially undefined fields + - Functions: `tangent_type`, `zero_tangent`, `randn_tangent`, `increment!!`, `set_to_zero!!` + +#### **FData/RData Splitting** +- **src/fwds_rvs_data.jl:1-1026** - Separates tangents into: + - **FData** (forward data) - mutable/address-identified components, propagated on forward pass + - **RData** (reverse data) - immutable/value-identified components, propagated on reverse pass + - Key functions: `fdata_type`, `rdata_type`, `fdata`, `rdata`, `tangent(f, r)` + - `ZeroRData` handling in **src/interpreter/zero_like_rdata.jl:1-40** +- **Documentation**: docs/src/understanding_mooncake/rule_system.md - "Representing Gradients" section explains fdata/rdata design + +## 2. **Rule System** + +> **Documentation**: docs/src/understanding_mooncake/rule_system.md - complete rule interface specification + +### **Rule Interface** +- **rrule!! signature**: `(::CoDual{typeof(f)}, args::CoDual...) -> (CoDual{output}, pullback_function)` +- **frule!! signature**: `(::Dual{typeof(f)}, args::Dual...) 
-> Dual{output}` +- **Documentation**: + - docs/src/understanding_mooncake/rule_system.md - "The Rule Interface" sections + - docs/src/utilities/defining_rules.md - how to write custom rules + +### **Primitive Rules** (src/rrules/) +Hand-written rules for Julia primitives: + +| Category | File | Key Functions | +|----------|------|---------------| +| **Builtins** | src/rrules/builtins.jl:1-1000+ | `getfield`, `setfield!`, `tuple`, `===`, `isa`, `typeof`, `svec`, `ifelse` | +| **Intrinsics** | Module `IntrinsicsWrappers` in builtins.jl | `add_float`, `mul_float`, `div_float`, `neg_float`, `sqrt_llvm`, `fma_float`, `bitcast` | +| **Foreign calls** | src/rrules/foreigncall.jl:1-473 | `_foreigncall_`, `pointer_from_objref`, `unsafe_pointer_to_objref`, `unsafe_copyto!` | +| **Construction** | src/rrules/new.jl:1-212 | `_new_` (all object construction), `_splat_new_` | +| **Arrays (1.10)** | src/rrules/array_legacy.jl:1-666 | `arrayref`, `arrayset`, `_deletebeg!`, `_deleteend!`, `_growend!` | +| **Memory (1.11+)** | src/rrules/memory.jl:1-1000+ | `Memory`, `MemoryRef`, `memoryrefget`, `memoryrefset!`, `Array` construction | +| **BLAS** | src/rrules/blas.jl:1-1000+ | `gemm!`, `gemv!`, `symm!`, `symv!`, `trmv!`, `dot`, `nrm2`, `scal!`, `syrk!` | +| **LAPACK** | src/rrules/lapack.jl:1-643 | `getrf!`, `getrs!`, `getri!`, `trtrs!`, `potrf!`, `potrs!` | +| **Linear Algebra** | src/rrules/linear_algebra.jl:1-52 | `exp(::Matrix)` | +| **Low-level Math** | src/rrules/low_level_maths.jl:1-305 | `exp`, `log`, `sin`, `cos`, `tan`, `sqrt`, `cbrt`, `hypot`, etc. | +| **FastMath** | src/rrules/fastmath.jl:1-162 | Fast versions of math functions | +| **TwicePrecision** | src/rrules/twice_precision.jl:1-524 | `TwicePrecision`, `StepRangeLen`, range operations | +| **Random** | src/rrules/random.jl:1-84 | `randn`, `randexp`, `MersenneTwister` | +| **Tasks** | src/rrules/tasks.jl:1-146 | `Task`, `current_task` (limited support) | +| **IdDict** | src/rrules/iddict.jl:1-247 | `IdDict` operations | +| **MistyClosure** | src/rrules/misty_closures.jl:1-162 | Differentiation of closures with captured variables | +| **Performance** | src/rrules/performance_patches.jl:1-72 | Optimized `sum` for arrays | +| **Misc** | src/rrules/misc.jl:1-398 | `lgetfield`, `lsetfield!`, logging, string ops | +| **Avoidance** | src/rrules/avoiding_non_differentiable_code.jl:1-225 | Pointer arithmetic, logging macros, `@zero_derivative` rules | + +### **Rule Derivation** (Automatic Differentiation) + +> **Documentation**: +> - docs/src/developer_documentation/forwards_mode_design.md - forward-mode internals (planned) +> - docs/src/developer_documentation/reverse_mode_design.md - reverse-mode compilation process +> - docs/src/understanding_mooncake/algorithmic_differentiation.md - mathematical foundations + +#### **Forward Mode** +- **src/interpreter/forward_mode.jl:1-507** - Derives `frule!!` from IR: + - `build_frule` - main entry point + - `generate_dual_ir` - transforms IR to compute derivatives + - `DerivedFRule` - wrapper for derived forward rules + - `LazyFRule` - lazy rule construction + - `DynamicFRule` - dynamic dispatch +- **Documentation**: docs/src/developer_documentation/forwards_mode_design.md - detailed design document + +#### **Reverse Mode** +- **src/interpreter/reverse_mode.jl:1-1875** - Derives `rrule!!` from IR: + - `build_rrule` - main entry point (reverse_mode.jl:1044-1144) + - `generate_ir` - creates forward + reverse IR (reverse_mode.jl:1152-1196) + - `make_ad_stmts!` - transforms each IR statement 
(reverse_mode.jl:394-861) + - `DerivedRule` - wrapper for derived reverse rules (reverse_mode.jl:934-960) + - `Pullback` - callable that runs reverse pass (reverse_mode.jl:918-932) + - `LazyDerivedRule` - lazy rule construction (reverse_mode.jl:1816-1842) + - `DynamicDerivedRule` - dynamic dispatch (reverse_mode.jl:1726-1752) + - `SharedDataPairs` - manages captured data (reverse_mode.jl:13-72) + - `ADInfo` - global context for rule derivation (reverse_mode.jl:123-206) + - `BlockStack` - tracks control flow (reverse_mode.jl:81) +- **Documentation**: + - docs/src/developer_documentation/reverse_mode_design.md - compilation overview + - docs/src/understanding_mooncake/algorithmic_differentiation.md - "Reverse-Mode AD: how does it do it?" + +## 3. **IR Manipulation** + +> **Documentation**: docs/src/developer_documentation/ir_representation.md - comprehensive guide to IRCode vs BBCode + +### **IR Representations** +- **src/interpreter/bbcode.jl:1-1010** - `BBCode` data structure: + - `BBlock` - basic block with unique IDs (bbcode.jl:194-265) + - `ID` - unique identifier for blocks/statements (bbcode.jl:79-86) + - `IDPhiNode`, `IDGotoNode`, `IDGotoIfNot` - ID-based control flow (bbcode.jl:108-138) + - `Switch` - multi-way branch statement (bbcode.jl:160-168) + - Conversion: `BBCode(::IRCode)` and `IRCode(::BBCode)` (bbcode.jl:528-657) +- **Documentation**: docs/src/developer_documentation/ir_representation.md + - Julia's SSA-form IR explanation + - Control flow and PhiNodes + - Code transformation examples (replacing instructions, inserting blocks) + - When to use IRCode vs BBCode + +### **IR Utilities** +- **src/interpreter/ir_utils.jl:1-334** - IR manipulation: + - `stmt` - get statement from IR (ir_utils.jl:9) + - `set_stmt!`, `get_ir`, `set_ir!`, `replace_call!` (ir_utils.jl:18-36) + - `ircode` - construct IRCode for testing (ir_utils.jl:56-65) + - `infer_ir!` - run type inference (ir_utils.jl:101-126) + - `optimise_ir!` - optimization pipeline (ir_utils.jl:146-188) + - `lookup_ir` - get IR from signature/MethodInstance (ir_utils.jl:206-254) + +### **IR Normalization** +- **src/interpreter/ir_normalisation.jl:1-495** - Standardize IR: + - `normalise!` - main entry (ir_normalisation.jl:23-44) + - `foreigncall_to_call` - `:foreigncall` → `_foreigncall_()` call (ir_normalisation.jl:144-158) + - `new_to_call` - `:new` → `_new_()` call (ir_normalisation.jl:218) + - `splatnew_to_call` - `:splatnew` → `_splat_new_()` call (ir_normalisation.jl:229) + - `intrinsic_to_function` - intrinsics → `IntrinsicsWrappers` (ir_normalisation.jl:244-256) + - `lift_getfield_and_others` - constant field access → `lgetfield` (ir_normalisation.jl:267-290) + - `lift_gc_preservation` - GC preservation handling (ir_normalisation.jl:403-407) + - `const_prop_gotoifnots!` - constant propagation for branches (ir_normalisation.jl:416-432) + +### **Compiler Integration** +- **src/interpreter/abstract_interpretation.jl:1-223** - Custom interpreter: + - `MooncakeInterpreter{C,M}` - subtype of `AbstractInterpreter` (abstract_interpretation.jl:27-66) + - `ClosureCacheKey` - cache key for closures (abstract_interpretation.jl:13-16) + - `inlining_policy` - prevents primitive inlining (abstract_interpretation.jl:159-196) + - `get_interpreter` - returns cached interpreter (abstract_interpretation.jl:217-222) + - `GLOBAL_INTERPRETERS` - cached interpreters (abstract_interpretation.jl:204-207) +- **Documentation**: docs/src/developer_documentation/reverse_mode_design.md - explains `MooncakeInterpreter` role + +### **Contexts** +- 
**src/interpreter/contexts.jl:1-119** - AD contexts: + - `MinimalCtx` - only essential primitives (contexts.jl:8) + - `DefaultCtx` - all performance primitives (contexts.jl:17) + - `ForwardMode`, `ReverseMode` - AD mode markers (contexts.jl:32, 39) + - `is_primitive` - determines if function is primitive (contexts.jl:58-61) + - `@is_primitive` - macro to declare primitives (contexts.jl:69-118) +- **Documentation**: docs/src/developer_documentation/reverse_mode_design.md - distinction between MinimalCtx and DefaultCtx + +### **Compiler Patches** +- **src/interpreter/patch_for_319.jl:1-435** - Workarounds for Julia compiler bugs (issue #319) + +## 4. **Public Interface** + +> **Documentation**: docs/src/interface.md - complete public API reference + +- **src/interface.jl:1-588** - User-facing API: + - `value_and_gradient!!(rule, f, x...)` - compute gradient (interface.jl:169-171) + - `value_and_pullback!!(rule, ȳ, f, x...)` - compute pullback (interface.jl:142-144) + - `value_and_derivative!!(rule, f, x...)` - forward-mode (interface.jl:581) + - `prepare_gradient_cache` - pre-compile for performance (interface.jl:515-522) + - `prepare_pullback_cache` - pre-compile for pullback (interface.jl:439-458) + - `prepare_derivative_cache` - pre-compile for forward-mode (interface.jl:572) + - `__value_and_gradient!!`, `__value_and_pullback!!` - lower-level internal versions +- **Documentation**: + - docs/src/interface.md - public API docstrings + - docs/src/tutorial.md - usage examples with DifferentiationInterface.jl + +- **src/config.jl:1-18** - Configuration: + - `Config(; debug_mode, silence_debug_messages)` +- **Documentation**: docs/src/utilities/debug_mode.md - when and how to use debug mode + +- **src/public.jl:1-15** - Public API macro for Julia 1.11+ + +## 5. 
**Utilities & Tools** + +### **Rule Definition Helpers** +- **src/tools_for_rules.jl:1-698** - Macros and utilities: + - `@mooncake_overlay` - override function for AD (tools_for_rules.jl:104-112) + - `@zero_derivative` - mark functions with zero derivative (tools_for_rules.jl:248-302) + - `@zero_adjoint` - reverse-mode specific (tools_for_rules.jl:310-312) + - `@from_chainrules` - import ChainRules rrules (tools_for_rules.jl:628-687) + - `@from_rrule` - import specific rrule (tools_for_rules.jl:695-697) + - `zero_adjoint`, `zero_derivative` - function versions (tools_for_rules.jl:148-177) + - `to_cr_tangent`, `mooncake_tangent` - ChainRules conversion (tools_for_rules.jl:323-373) +- **Documentation**: docs/src/utilities/defining_rules.md + - Complete guide to all rule-writing strategies + - `@mooncake_overlay` examples + - `@zero_adjoint` usage + - `@from_rrule` / `@from_chainrules` with worked examples + - When to implement custom `rrule!!` + +### **Testing Infrastructure** +- **src/test_utils.jl:1-1680** - Comprehensive testing: + - `test_rule` - main testing function (test_utils.jl:895-986) + - `test_tangent_interface` - test tangent operations (test_utils.jl:1111-1233) + - `test_tangent_splitting` - test fdata/rdata split (test_utils.jl:1439-1522) + - `test_rule_and_type_interactions` - test primitives work (test_utils.jl:1553-1580) + - `test_data` - combined test (test_utils.jl:1672-1677) + - `has_equal_data` - structural equality (test_utils.jl:201-327) + - `populate_address_map` - track aliasing (test_utils.jl:338-414) +- **Documentation**: + - docs/src/utilities/debugging_and_mwes.md - using `test_rule` for debugging + - docs/src/developer_documentation/tangents.md - testing functions explained + - docs/src/developer_documentation/running_tests_locally.md - local testing workflow + +- **src/test_resources.jl:1-995** - Test data: + - Module `TestResources` with test types (test_resources.jl:8-989) + - `generate_test_functions` - standard test cases (test_resources.jl:699-929) + - Test types: `StructFoo`, `MutableFoo`, `TypeStableMutableStruct`, etc. 
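+
+As a quick illustration of the `test_rule` entry point described in the Testing Infrastructure bullets above, here is a minimal, hedged sketch. It only uses names that appear elsewhere in this map (`Mooncake.TestUtils.test_rule`, the `is_primitive` keyword, and the hand-written `sin` rule from src/rrules/low_level_maths.jl); exact keyword defaults may differ between versions, so treat it as a sketch rather than the canonical test-suite invocation.
+
+```julia
+using Mooncake
+using Mooncake.TestUtils: test_rule
+using Random: Xoshiro
+
+# `sin` has a hand-written rule in src/rrules/low_level_maths.jl,
+# so we assert that it is treated as a primitive.
+test_rule(Xoshiro(123), sin, 5.0; is_primitive=true)
+```
+
+The repository's own test suite drives this same entry point from `hand_written_rule_test_cases` and `TestResources.generate_test_functions()` rather than ad-hoc calls like the one above.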
+ +### **Developer Tools** +- **src/developer_tools.jl:1-155** - IR inspection: + - `primal_ir` - get primal IR (developer_tools.jl:22-24) + - `dual_ir` - get forward-mode IR (developer_tools.jl:61-68) + - `fwd_ir` - get forward pass IR (developer_tools.jl:104-111) + - `rvs_ir` - get reverse pass IR (developer_tools.jl:147-154) +- **Documentation**: docs/src/developer_documentation/developer_tools.md - IR inspection guide + +### **General Utilities** +- **src/utils.jl:1-457** - Helper functions: + - `_typeof` - stable typeof (utils.jl:6-8) + - `tuple_map` - specialized tuple mapping (utils.jl:26-50) + - `always_initialised` - field initialization info (utils.jl:218-222) + - `lgetfield` - literal getfield with `Val` (utils.jl:250) + - `lsetfield!` - literal setfield with `Val` (utils.jl:261) + - `_new_` - direct `:new` instruction (utils.jl:268-270) + - `opaque_closure`, `misty_closure` - closure construction (utils.jl:326-367) + +### **Data Structures** +- **src/stack.jl:1-40** - Specialized stack: + - `Stack{T}` - never-deallocating stack for reverse pass (stack.jl:8-34) + - `SingletonStack{T}` - zero-overhead singleton stack (stack.jl:36-39) + +### **Debug Mode** +- **src/debug_mode.jl:1-124** - Runtime type checking: + - `DebugRRule` - wraps rules with type checks (debug_mode.jl:77-95) + - `DebugPullback` - wraps pullbacks with type checks (debug_mode.jl:14-31) + - `DebugFRule` - forward-mode equivalent (debug_mode.jl:2) +- **Documentation**: docs/src/utilities/debug_mode.md + - When to use debug mode + - How it catches type errors + - Performance implications + +## 6. **Extensions** (ext/) + +| Extension | Purpose | Key Types/Functions | +|-----------|---------|---------------------| +| **MooncakeCUDAExt.jl** | CUDA support | `CuArray` tangent ops, allocation rules | +| **MooncakeAllocCheckExt.jl** | Allocation checking | `check_allocs_internal` | +| **MooncakeJETExt.jl** | Type stability | `test_opt_internal`, `report_opt_internal` | +| **MooncakeLuxLibExt.jl** | LuxLib ops | `matmul`, `conv`, `batchnorm` overlays | +| **MooncakeLuxLibSLEEFPiratesExtension.jl** | Fast activations | `sigmoid_fast`, `tanh_fast`, etc. | +| **MooncakeNNlibExt.jl** | NNlib ops | `conv`, `pooling`, `dropout`, `softmax` | +| **MooncakeSpecialFunctionsExt.jl** | Special functions | Bessel, gamma, erf functions via ChainRules | +| **MooncakeFluxExt.jl** | Flux support | Optimized `mse` loss | +| **MooncakeFunctionWrappersExt.jl** | FunctionWrapper | Custom tangent with AD through wrapper | +| **MooncakeDynamicExpressionsExt.jl** | Symbolic expressions | `TangentNode` for expression trees | + +## 7. **Key Algorithms & Concepts** + +### **Forward-Mode AD Workflow** +1. `build_frule(sig)` → generates rule +2. `lookup_ir` → get primal IR +3. `normalise!` → standardize IR +4. Transform each statement: Arguments+1, wrap in `Dual`, replace calls with `frule!!` +5. `optimise_ir!` → optimize generated IR +6. Wrap in `MistyClosure` → `DerivedFRule` + +### **Reverse-Mode AD Workflow** +1. `build_rrule(sig)` → generates rule +2. `lookup_ir` → get primal IR +3. `normalise!` → standardize IR +4. `BBCode(ir)` → convert to basic block form +5. `ADInfo` construction → setup metadata +6. `make_ad_stmts!` → transform each statement into forward/reverse instructions +7. `forwards_pass_ir` → generate forward pass (reverse_mode.jl:1294-1356) +8. `pullback_ir` → generate reverse pass (reverse_mode.jl:1376-1523) +9. `optimise_ir!` → optimize both passes +10. 
Wrap in `MistyClosure`s → `DerivedRule` + `Pullback` + +### **Critical Implementation Details** + +**Block Stack** (reverse_mode.jl:81): +- Tracks which blocks were visited during forward pass +- Used to determine control flow on reverse pass +- Optimized away for unique predecessors + +**Shared Data** (reverse_mode.jl:13-72): +- Data shared between forward/reverse passes +- Stored in `OpaqueClosure` captures +- Contains block stack, communication stacks, lazy zero rdata + +**Communication Channels** (reverse_mode.jl:1251-1287): +- Per-block stacks storing intermediate values +- Push on forward pass, pop on reverse pass +- Optimized to `SingletonStack` when possible + +**RData Refs** (reverse_mode.jl:259-273): +- Each SSA/Argument gets a `Ref` to accumulate gradients +- Initialized to zero, incremented during reverse pass +- Optimized away by SROA pass + +## 8. **File Organization Summary** + +``` +Mooncake.jl/ +├── src/ +│ ├── Mooncake.jl # Main module +│ ├── tangents.jl # Tangent type system +│ ├── fwds_rvs_data.jl # FData/RData splitting +│ ├── dual.jl # Forward-mode Dual type +│ ├── codual.jl # Reverse-mode CoDual type +│ ├── interface.jl # Public API +│ ├── config.jl # Configuration +│ ├── debug_mode.jl # Runtime type checking +│ ├── stack.jl # Block stack +│ ├── utils.jl # General utilities +│ ├── tools_for_rules.jl # Rule definition macros +│ ├── test_utils.jl # Testing infrastructure +│ ├── test_resources.jl # Test data +│ ├── developer_tools.jl # IR inspection tools +│ ├── public.jl # Public API declarations +│ ├── interpreter/ +│ │ ├── contexts.jl # AD contexts +│ │ ├── abstract_interpretation.jl # Custom interpreter +│ │ ├── bbcode.jl # BBCode IR representation +│ │ ├── ir_utils.jl # IR manipulation +│ │ ├── ir_normalisation.jl # IR standardization +│ │ ├── forward_mode.jl # Forward-mode derivation +│ │ ├── reverse_mode.jl # Reverse-mode derivation +│ │ ├── zero_like_rdata.jl # ZeroRData utilities +│ │ └── patch_for_319.jl # Compiler bug workarounds +│ └── rrules/ +│ ├── builtins.jl # Core built-in functions +│ ├── foreigncall.jl # ccall handling +│ ├── new.jl # Object construction +│ ├── misc.jl # lgetfield, lsetfield!, etc. +│ ├── blas.jl # BLAS operations +│ ├── lapack.jl # LAPACK operations +│ ├── linear_algebra.jl # High-level LinAlg +│ ├── low_level_maths.jl # Math functions +│ ├── fastmath.jl # FastMath functions +│ ├── array_legacy.jl # Array ops (1.10) +│ ├── memory.jl # Memory/Array ops (1.11+) +│ ├── random.jl # Random number generation +│ ├── twice_precision.jl # TwicePrecision/ranges +│ ├── tasks.jl # Task (limited) +│ ├── iddict.jl # IdDict operations +│ ├── misty_closures.jl # Closure differentiation +│ ├── performance_patches.jl # Performance optimizations +│ ├── avoiding_non_differentiable_code.jl # Zero derivative rules +│ └── dispatch_doctor.jl # DispatchDoctor integration +├── ext/ # Package extensions +├── test/ # Test suite (mirrors src/) +└── docs/ # Documentation +``` + +## 9. 
**Documentation Structure** (docs/src/) + +### **User-Facing** +- `index.md` - Project overview, getting started, project status +- `tutorial.md` - DifferentiationInterface.jl and native API usage +- `interface.md` - Public API documentation +- `known_limitations.md` - Mutation of globals, recursive types, pointers + +### **Understanding Mooncake** +- `understanding_mooncake/introduction.md` - Prerequisites, who docs are for +- `understanding_mooncake/algorithmic_differentiation.md` - Mathematical foundations + - Fréchet derivatives, adjoints, tangents + - Forward vs reverse mode + - Gradients and directional derivatives +- `understanding_mooncake/rule_system.md` - Core rule interface + - `rrule!!` specification + - CoDual, fdata/rdata system + - Testing with `test_rule` + +### **Utilities** +- `utilities/defining_rules.md` - How to write rules + - `@mooncake_overlay` - code simplification + - `@zero_adjoint` - zero derivative functions + - `@from_rrule` - import ChainRules + - Adding custom `rrule!!` methods +- `utilities/debug_mode.md` - `DebugRRule` for type checking +- `utilities/debugging_and_mwes.md` - Using `TestUtils.test_rule` + +### **Developer Documentation** +- `developer_documentation/running_tests_locally.md` - Test workflow with Revise.jl +- `developer_documentation/developer_tools.md` - IR inspection tools +- `developer_documentation/tangents.md` - Tangent type interface requirements +- `developer_documentation/custom_tangent_type.md` - Detailed guide for recursive types + - Complete worked example with `struct A` containing self-reference + - All required methods: `zero_tangent_internal`, `randn_tangent_internal`, `increment_internal!!`, etc. +- `developer_documentation/ir_representation.md` - IRCode vs BBCode + - SSA-form IR explanation + - Control flow and PhiNodes + - Code transformation examples +- `developer_documentation/forwards_mode_design.md` - Forward-mode AD design (unimplemented) +- `developer_documentation/reverse_mode_design.md` - Compilation process overview +- `developer_documentation/misc_internals_notes.md` - Implementation notes + - `tangent_type` generated function design + - Recursion handling via `LazyDerivedRule` + +## 10. **Testing** (test/) + +### **Test Organization** +- **test/front_matter.jl:1-163** - Common test setup, determines test group +- **test/runtests.jl:1-71** - Main test runner with group selection +- **test/run_extra.jl:1-4** - Integration test runner + +### **Test Files** (mirror src/) +- `test/tangents.jl` - Tangent type tests +- `test/codual.jl` - CoDual tests +- `test/interface.jl` - Public API tests +- `test/rrules/*.jl` - Tests for each rrules file +- `test/interpreter/*.jl` - Tests for interpreter components +- `test/integration_testing/` - Integration with other packages +- `test/ext/` - Extension tests + +## 11. **Worked Example: Tracing a Simple Gradient** + +Let's trace `value_and_gradient!!(rule, f, x)` where `f(x) = sin(x)` and `x = 5.0`: + +### **Preparation** (one-time cost) +```julia +rule = build_rrule(f, 5.0) # In src/interpreter/reverse_mode.jl:1044 +``` +1. `lookup_ir` gets IR for `f(::Float64)` → finds `sin` call +2. `is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64})` → `true` (from low_level_maths.jl) +3. Since it's primitive, `build_rrule` returns `rrule!!` directly (no IR transformation needed) + +### **Execution** (fast, repeated calls) +```julia +value_and_gradient!!(rule, f, 5.0) # In src/interface.jl:169 +``` +1. Creates `CoDual(f, NoTangent())` and `CoDual(5.0, 0.0)` (zero tangent) +2. 
Calls `rule(to_fwds(coduals)...)` → `rrule!!(CoDual(f, NoFData()), CoDual(5.0, NoFData()))` +3. **Forward pass** (in src/rrules/low_level_maths.jl): + - Computes `y = sin(5.0) = -0.9589...` + - Returns `CoDual(-0.9589, NoFData())` and pullback closure +4. **Reverse pass** - calls `pullback(1.0)`: + - Computes `∂x = cos(5.0) * 1.0 = 0.2836...` + - Returns `(NoRData(), 0.2836...)` +5. Reconstructs result: `(-0.9589, (NoTangent(), 0.2836))` + +### **Worked Example: Derived Rule** + +For `g(x) = sin(cos(x))` with `x = 5.0`: + +**build_rrule generates**: +```julia +# Forward pass IR (simplified): +%1 = rrule!!(zero_fcodual(cos), CoDual(x, NoFData())) # y=%1[1], pb_cos=%1[2] +%2 = rrule!!(zero_fcodual(sin), %1[1]) # result=%2[1], pb_sin=%2[2] +push!(comms_stack, %1[2]) # Store cos's pullback +push!(comms_stack, %2[2]) # Store sin's pullback +return %2[1] # Return CoDual(sin(cos(x)), NoFData()) + +# Reverse pass IR (simplified): +pb_sin = pop!(comms_stack) +pb_cos = pop!(comms_stack) +dy_inner = pb_sin(dy) # dy w.r.t cos(x) +dx = pb_cos(dy_inner[2]) # dy w.r.t x +return (NoRData(), dx[2]) # (df, dx) +``` + +## 12. **Key Concepts & Terminology** + +> **Documentation**: docs/src/understanding_mooncake/algorithmic_differentiation.md - complete mathematical treatment + +### **Mathematical Foundation** +- **Fréchet Derivative**: Linear operator `D f[x] : X → Y` satisfying `df = D f[x](dx)` +- **Adjoint**: Linear operator `D f[x]* : Y → X` satisfying `⟨D f[x]*(ȳ), ẋ⟩ = ⟨ȳ, D f[x](ẋ)⟩` +- **Tangent**: Input/output of derivatives (denoted with dot: `ẋ`, `ẏ`) +- **Cotangent/Gradient**: Input/output of adjoints (denoted with bar: `x̄`, `ȳ`) +- **Documentation**: docs/src/understanding_mooncake/algorithmic_differentiation.md + - Derivatives section - scalar to general Hilbert spaces + - Worked examples with matrices and mutable data + - Chain rule and how forward/reverse mode work + - Directional derivatives and gradients + +### **Implementation Details** +- **Primitive**: Function with hand-written rule, not derived from IR +- **Derived Rule**: Rule automatically generated from IR transformation +- **Activity Analysis**: Distinguishing differentiable from non-differentiable data +- **Unique Predecessor**: Block with only one possible incoming edge (optimization opportunity) +- **Lazy Rule Construction**: Defer rule compilation until first use (handles recursion) +- **Dynamic Dispatch**: Runtime rule selection based on argument types +- **Static Dispatch**: Compile-time rule selection via `:invoke` nodes + +### **Performance Optimizations** +1. **SROA** (Scalar Replacement of Aggregates) - eliminates `Ref` allocations +2. **Singleton Type Optimization** - avoids storing singleton pullbacks +3. **Block Stack Elimination** - skips block tracking for unique predecessors +4. **Constant Propagation** - folds constants, eliminates dead branches +5. **Lazy Zero RData** - defers zero construction until needed + +## 13. 
**Common Patterns & Idioms** + +### **Pattern: Accessing Rule-Generated IR** +```julia +# See what IR Mooncake generates: +using Mooncake: get_interpreter, ReverseMode, primal_ir, fwd_ir, rvs_ir + +sig = Tuple{typeof(sin), Float64} +interp = get_interpreter(ReverseMode) + +primal_ir(interp, sig) # Original function IR +fwd_ir(sig) # Forward pass IR +rvs_ir(sig) # Reverse pass IR (pullback) +``` + +### **Pattern: CoDual Construction** +```julia +# Manual CoDual construction (rarely needed): +x = [1.0, 2.0, 3.0] +dx = zero_tangent(x) # Create tangent +codual = CoDual(x, fdata(dx)) # Pair primal with fdata + +# Usually automatic: +value_and_gradient!!(rule, f, x) # Handles CoDual construction internally +``` + +### **Pattern: Checking if Something is Primitive** +```julia +using Mooncake: is_primitive, DefaultCtx, ReverseMode + +sig = Tuple{typeof(sin), Float64} +is_primitive(DefaultCtx, ReverseMode, sig) # true - has hand-written rule + +sig2 = Tuple{typeof(my_function), Float64} +is_primitive(DefaultCtx, ReverseMode, sig2) # false - will derive rule +``` + +### **Pattern: Testing a New Rule** +```julia +using Mooncake.TestUtils: test_rule +using Random: Xoshiro + +# Test your new rule: +test_rule( + Xoshiro(123), # RNG for reproducibility + my_func, # Function + arg1, arg2; # Arguments + is_primitive=true, # Expect hand-written rule + perf_flag=:stability_and_allocs, # Check performance +) +``` + +## 14. **Common Pitfalls & Solutions** + +### **Pitfall 1: StackOverflowError when calling build_rrule** +**Cause**: Recursive type without custom tangent type +**Solution**: See docs/src/developer_documentation/custom_tangent_type.md +**Example**: src/rrules/tasks.jl (TaskTangent), ext/MooncakeDynamicExpressionsExt.jl (TangentNode) + +### **Pitfall 2: MissingForeigncallRuleError** +**Cause**: Code calls a `ccall` without a rule +**Solution**: Write rule for Julia function that calls it, or the foreigncall itself +**Example**: src/rrules/blas.jl shows how to write foreigncall rules + +### **Pitfall 3: Type instability in generated code** +**Cause**: Abstract types in primal signature, missing type assertions +**Solution**: Add type assertions in rule, or make primal type-stable +**Check**: Use JET.jl via `test_rule(...; perf_flag=:stability)` + +### **Pitfall 4: Incorrect gradient for mutation** +**Cause**: Forgot to restore primal state in pullback +**Solution**: Save old values before mutation, restore in pullback +**Example**: Any `rrule!!` in src/rrules/blas.jl that saves `_copy = copy(x)` + +### **Pitfall 5: Segfault or weird errors** +**Cause**: Wrong fdata/rdata types in custom rule +**Solution**: Enable debug mode, check types with `verify_fdata_value`, `verify_rdata_value` +**How**: `Config(; debug_mode=true)` or see docs/src/utilities/debug_mode.md + +## 15. **Common Workflows** + +> **Documentation**: docs/src/utilities/ - complete workflow guides + +### **Adding a New Primitive Rule** +1. Determine signature `Tuple{typeof(f), ArgTypes...}` +2. Add `@is_primitive Context Signature` declaration +3. Implement `rrule!!(::CoDual{typeof(f)}, args::CoDual...)` +4. Implement `frule!!(::Dual{typeof(f)}, args::Dual...)` (optional) +5. Add test case to `hand_written_rule_test_cases` +6. Run `test_rule(rng, f, test_args...)` +- **Documentation**: docs/src/utilities/defining_rules.md - complete guide with examples + +### **Debugging AD Issues** +1. Enable debug mode: `Config(; debug_mode=true)` +2. Use `test_rule` to isolate issue +3. 
Inspect IR with developer tools: `primal_ir`, `fwd_ir`, `rvs_ir` +4. Check `tangent_type`, `fdata_type`, `rdata_type` are correct +5. Verify with `test_data` for custom tangent types +- **Documentation**: + - docs/src/utilities/debugging_and_mwes.md - debugging strategies + - docs/src/utilities/debug_mode.md - using debug mode effectively + +### **Supporting a New Type** +1. Implement `tangent_type(::Type{MyType})` if non-default +2. Implement tangent operations: `zero_tangent_internal`, `randn_tangent_internal`, etc. +3. Implement `fdata`, `rdata`, `tangent(f, r)` if custom splitting needed +4. Add `rrule!!` for `lgetfield`, `lsetfield!`, `_new_` as needed +5. Use `test_data(rng, instance)` to verify +- **Documentation**: docs/src/developer_documentation/custom_tangent_type.md + - Step-by-step guide with complete recursive type example + - Full checklist of required methods + - Appendix with complete implementations + +## 16. **Architecture Decisions (Why Things Are This Way)** + +### **Why FData/RData Split?** +- **Problem**: Passing entire tangents around is expensive +- **Solution**: Split into mutable (fdata - passed by reference) and immutable (rdata - passed by value) +- **Benefit**: Only propagate what's necessary on each pass +- **Doc**: docs/src/understanding_mooncake/rule_system.md - "Representing Gradients" + +### **Why BBCode Instead of Just IRCode?** +- **Problem**: Inserting basic blocks in IRCode is awkward and error-prone +- **Solution**: BBCode uses IDs instead of positions, making insertions safe +- **When Used**: Only reverse-mode (needs complex CFG modifications), forward-mode uses IRCode +- **Doc**: docs/src/developer_documentation/ir_representation.md - "An Alternative IR Datastructure" + +### **Why LazyDerivedRule?** +- **Problem**: Recursive functions cause infinite loop during rule compilation +- **Solution**: Defer rule construction until first call +- **Example**: `f(x) = x > 0 ? f(x-1) : x` would loop forever without lazy construction +- **Doc**: docs/src/developer_documentation/misc_internals_notes.md - "How Recursion Is Handled" + +### **Why Unique Tangent Types?** +- **Problem**: Multiple tangent types → type instability → 100x+ slowdowns +- **Solution**: Each primal type has exactly one tangent type +- **Benefit**: Type-stable AD code, predictable testing, clear interface +- **Doc**: docs/src/understanding_mooncake/rule_system.md - "Why Uniqueness of Type For Tangents" + +### **Why Block Stack?** +- **Problem**: Need to know which block we came from on reverse pass +- **Solution**: Push block ID on forward pass, pop on reverse pass +- **Optimization**: Skip for unique predecessors (often 50%+ of blocks) +- **Location**: src/interpreter/reverse_mode.jl:81 + +### **Why IntrinsicsWrappers Module?** +- **Problem**: All intrinsics have type `Core.IntrinsicFunction` → can't dispatch +- **Solution**: Wrap each in a regular function with unique type +- **Benefit**: Type-stable dispatch in rules +- **Doc**: src/rrules/builtins.jl:44-84 (IntrinsicsWrappers docstring) + +## 17. **Important Invariants** + +> **Documentation**: docs/src/understanding_mooncake/rule_system.md - "Why Uniqueness of Type For Tangents" + +1. **Uniqueness**: `tangent_type(P)` returns exactly one type for each primal type `P` +2. **Reconstruction**: `tangent(fdata(t), rdata(t)) === t` must hold +3. **Type Stability**: If primal is type-stable, AD should be type-stable +4. **State Restoration**: After pullback, all mutated state must be restored +5. 
**Aliasing Preservation**: Tangent structure must mirror primal aliasing +6. **No Global Mutation**: Functions must not modify global mutable state +- **Documentation**: + - docs/src/known_limitations.md - mutation of globals, recursive types, pointers + - docs/src/understanding_mooncake/rule_system.md - "Why Support Closures But Not Mutable Globals" + +## 18. **Debugging Checklist** + +When AD fails, check in this order: + +1. **Does the primal run?** + ```julia + f(x...) # Must work before differentiation + ``` + +2. **Is there a missing primitive?** + ```julia + # Look for MissingForeigncallRuleError or MissingRuleForBuiltinException + # Solution: Add rule or use @mooncake_overlay to avoid problematic code + ``` + +3. **Are tangent types correct?** + ```julia + using Mooncake: tangent_type, fdata_type, rdata_type + tangent_type(typeof(x)) # Should be a concrete type + ``` + +4. **Enable debug mode** + ```julia + rule = build_rrule(f, x...; debug_mode=true) + # Will catch fdata/rdata type mismatches + ``` + +5. **Inspect generated IR** + ```julia + using Mooncake: fwd_ir, rvs_ir + display(fwd_ir(Tuple{typeof(f), typeof(x)...})) + display(rvs_ir(Tuple{typeof(f), typeof(x)...})) + ``` + +6. **Check for unsupported features** + - See docs/src/known_limitations.md + - Global mutation, PhiCNode, UpsilonNode not supported + +## 19. **Error Types** + +- `MissingForeigncallRuleError` - No rule for ccall (foreigncall.jl:2-33) +- `MissingRuleForBuiltinException` - No rule for builtin (builtins.jl:14-42) +- `MissingIntrinsicWrapperException` - No intrinsic wrapper (builtins.jl:119-128) +- `UnhandledLanguageFeatureException` - Unsupported Julia feature (ir_utils.jl:273-282) +- `MooncakeRuleCompilationError` - Rule compilation failed (reverse_mode.jl:1012-1036) +- `InvalidFDataException` - Invalid forward data (fwds_rvs_data.jl:287-289) +- `InvalidRDataException` - Invalid reverse data (fwds_rvs_data.jl:750-752) +- `ValueAndGradientReturnTypeError` - Wrong return type for gradient (interface.jl:41-52) +- `ValueAndPullbackReturnTypeError` - Unsupported output (interface.jl:54-73) +- `AddToPrimalException` - Constructor issue in testing (tangents.jl:1062-1080) + +## 20. **FAQs for New Contributors** + +### **Q: Where do I start if I want to add support for a new package?** +A: Create an extension in `ext/` following the pattern in `ext/MooncakeNNlibExt.jl`: +1. Import necessary Mooncake utilities +2. Use `@from_rrule` for functions with ChainRules rules +3. Write custom `rrule!!` for anything else +4. Add tests in `test/ext/your_package/` + +### **Q: How do I know if I need a primitive or can rely on derivation?** +A: Run `build_rrule(f, args...)` - if it fails with missing rule error, you need a primitive. If it succeeds but is slow/incorrect, consider adding a performance primitive to `DefaultCtx`. + +### **Q: What's the difference between MinimalCtx and DefaultCtx?** +A: +- **MinimalCtx**: Only rules essential for correctness (e.g., builtins, foreigncalls) +- **DefaultCtx**: Includes performance rules (e.g., optimized BLAS calls) +- Always add correctness rules to MinimalCtx, performance rules to DefaultCtx + +### **Q: Why does my rule work but performance is terrible?** +A: Common causes: +1. Rule isn't marked `@is_primitive` → AD re-derives it every time +2. Type instability in rule → use `@inferred` to check +3. Missing `@inline` on small functions +4. 
Allocations → use `@allocations` to check, may need manual optimization + +### **Q: How do I test just my new rule?** +A: +```julia +using Mooncake.TestUtils: test_rule +test_rule(Xoshiro(123), my_func, test_args...; is_primitive=true) +``` +Add to `hand_written_rule_test_cases` in your rrules file for CI. + +### **Q: Where are Hessians computed?** +A: Apply Mooncake to itself (forward-over-reverse or reverse-over-reverse): +```julia +# Hessian = derivative of gradient +grad_func = x -> value_and_gradient!!(rule, f, x)[2][2] +hessian_rule = build_rrule(grad_func, x) +``` +See test cases for examples of higher-order AD. + +### **Q: What's the relationship to Enzyme.jl, Zygote.jl, etc?** +A: +- **Enzyme.jl**: Works at LLVM level (C++), handles more but harder to extend +- **Zygote.jl**: Source-to-source like Mooncake, but struggles with mutation +- **ReverseDiff.jl**: Tape-based, can give wrong answers with control flow +- **Mooncake.jl**: Source-to-source, first-class mutation support, pure Julia + +## 21. **Documentation Cross-Reference** + +### **Getting Started** +- **docs/src/index.md** - Project overview, goals, status, getting started +- **docs/src/tutorial.md** - DifferentiationInterface.jl and native API usage examples +- **docs/src/interface.md** - Public API reference (`Config`, `value_and_gradient!!`, etc.) + +### **Understanding AD** +- **docs/src/understanding_mooncake/introduction.md** - Who docs are for, prerequisites +- **docs/src/understanding_mooncake/algorithmic_differentiation.md** - Complete mathematical foundation + - Derivatives: scalar to Hilbert space generalization + - Forward-mode AD explanation + - Reverse-mode AD explanation (what and how) + - Directional derivatives and gradients + - Worked examples with matrices and Julia functions +- **docs/src/understanding_mooncake/rule_system.md** - Rule system specification + - Mathematical model for Julia functions + - Rule interface (forwards and reverse passes) + - CoDual, fdata, rdata system + - Testing rules + - Why uniqueness matters + +### **Practical Guides** +- **docs/src/utilities/defining_rules.md** - Complete rule-writing guide + - Using `@mooncake_overlay` to simplify code + - Using `@zero_adjoint` for zero-derivative functions + - Using `@from_rrule` / `@from_chainrules` to import existing rules + - Writing custom `rrule!!` methods +- **docs/src/utilities/debug_mode.md** - Debug mode guide + - When to enable it + - What it checks + - Performance impact +- **docs/src/utilities/debugging_and_mwes.md** - Debugging and MWE creation + - Using `test_rule` for debugging + - Creating minimal working examples + +### **Developer Internals** +- **docs/src/developer_documentation/running_tests_locally.md** - Testing workflow + - Main testing with Revise.jl + - Extension and integration testing +- **docs/src/developer_documentation/developer_tools.md** - IR inspection tools + - `primal_ir`, `dual_ir`, `fwd_ir`, `rvs_ir` usage +- **docs/src/developer_documentation/tangents.md** - Tangent interface + - `test_tangent_interface`, `test_tangent_splitting`, `test_rule_and_type_interactions` + - Complete interface requirements +- **docs/src/developer_documentation/custom_tangent_type.md** - Custom tangent types + - Why recursive types are challenging + - Step-by-step guide with struct A example + - Complete checklist of required methods + - Full working implementation in appendix +- **docs/src/developer_documentation/ir_representation.md** - IR internals + - Julia's SSA-form IR explained + - Control flow, PhiNodes, basic blocks 
+ - IRCode vs BBCode comparison + - Code transformation examples +- **docs/src/developer_documentation/forwards_mode_design.md** - Forward-mode design (planned) + - Forwards-rule interface + - Hand-written vs derived rules + - Comparison with ForwardDiff.jl +- **docs/src/developer_documentation/reverse_mode_design.md** - Reverse-mode overview + - Compilation process walkthrough + - `build_rrule` flow + - `generate_ir` explanation +- **docs/src/developer_documentation/misc_internals_notes.md** - Implementation notes + - `tangent_type` generated function design + - How recursion is handled via `LazyDerivedRule` + +### **Limitations** +- **docs/src/known_limitations.md** - Known issues and workarounds + - Mutation of global variables + - Passing differentiable data as a type + - Circular references in type declarations + - Tangent generation and pointers + +This comprehensive map provides a complete conceptual understanding of Mooncake.jl's architecture, linking every major concept to specific source locations and their corresponding documentation. From aebb345864d08fa03bc7e3d150a1c5a059c28025 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 9 Nov 2025 11:46:18 +0000 Subject: [PATCH 04/19] update agent memory --- CLAUDE.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index af7eb10b38..3c95f29cdd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,4 +4,6 @@ read `test/runtests.jl` to understand how to run test before you run any test it is important that the code you write conform to the existing code in this repo -the test infra of Mooncake is strict, and well structured, when you add tests, you must use existing infra as much as you can \ No newline at end of file +the test infra of Mooncake is strict, and well structured, when you add tests, you must use existing infra as much as you can + +you should never commit or push without asking the user \ No newline at end of file From 150721a9e29b9704fee905988b94716fd5709c8d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 10 Nov 2025 11:48:40 +0000 Subject: [PATCH 05/19] try fixing CI error --- src/interpreter/forward_mode.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index eabe0d6be4..47f5957352 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -321,10 +321,8 @@ end @static if isdefined(Core, :EnterNode) function modify_fwd_ad_stmts!( - ::Core.EnterNode, dual_ir::IRCode, ssa::SSAValue, ::Vector{Any}, ::DualInfo + ::Core.EnterNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo ) - # Drop typed exception-enter nodes from dual IR to avoid optimiser assertions - replace_call!(dual_ir, ssa, nothing) return nothing end end @@ -405,12 +403,11 @@ function modify_fwd_ad_stmts!( new_undef_inst = new_inst(Expr(:throw_undef_if_not, stmt.args[1], ssa)) CC.insert_node!(dual_ir, ssa, new_undef_inst, ATTACH_AFTER) elseif isexpr(stmt, :enter) - # Drop exception-handling scaffolding from the dual IR. - replace_call!(dual_ir, ssa, nothing) + # Leave this node alone elseif isexpr(stmt, :leave) - replace_call!(dual_ir, ssa, nothing) + # Leave this node alone elseif isexpr(stmt, :pop_exception) - replace_call!(dual_ir, ssa, nothing) + # Leave this node alone elseif isexpr(stmt, :the_exception) # Preserve the primal exception object but give it a zero tangent. 
inst = CC.NewInstruction(get_ir(info.primal_ir, ssa)) From b0b69e80bf25fffde8dd954e1acaa79281730f4f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 10 Nov 2025 12:04:28 +0000 Subject: [PATCH 06/19] add forward over reverse test --- .github/workflows/CI.yml | 1 + .../differentiation_interface.jl | 13 ------------- test/ext/forward_over_reverse/Project.toml | 5 +++++ .../forward_over_reverse.jl | 17 +++++++++++++++++ 4 files changed, 23 insertions(+), 13 deletions(-) create mode 100644 test/ext/forward_over_reverse/Project.toml create mode 100644 test/ext/forward_over_reverse/forward_over_reverse.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index af20dd53c9..f5142fec67 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -85,6 +85,7 @@ jobs: matrix: test_group: [ {test_type: 'ext', label: 'differentiation_interface'}, + {test_type: 'ext', label: 'forward_over_reverse'}, {test_type: 'ext', label: 'dynamic_expressions'}, {test_type: 'ext', label: 'flux'}, {test_type: 'ext', label: 'function_wrappers'}, diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 19ad454bc2..17853c3158 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -11,16 +11,3 @@ test_differentiation( excluded=SECOND_ORDER, logging=true, ) - -# Explicit second-order sanity tests for Mooncake forward-over-reverse -@testset "Mooncake second-order examples" begin - backend = SecondOrder(AutoMooncakeForward(), AutoMooncake()) - - # Sum: Hessian is zero - @test DI.hessian(sum, backend, [2.0]) == [0.0] - - # Rosenbrock 2D at [1.2, 1.2] - rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 - H = DI.hessian(rosen, backend, [1.2, 1.2]) - @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) -end diff --git a/test/ext/forward_over_reverse/Project.toml b/test/ext/forward_over_reverse/Project.toml new file mode 100644 index 0000000000..7639dd345f --- /dev/null +++ b/test/ext/forward_over_reverse/Project.toml @@ -0,0 +1,5 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/ext/forward_over_reverse/forward_over_reverse.jl b/test/ext/forward_over_reverse/forward_over_reverse.jl new file mode 100644 index 0000000000..b22074cb74 --- /dev/null +++ b/test/ext/forward_over_reverse/forward_over_reverse.jl @@ -0,0 +1,17 @@ +using Pkg, Test +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) + +using DifferentiationInterface +import DifferentiationInterface as DI +using Mooncake: Mooncake + +@testset "forward-over-reverse via DifferentiationInterface" begin + backend = SecondOrder(AutoMooncakeForward(), AutoMooncake()) + + @test DI.hessian(sum, backend, [2.0]) == [0.0] + + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + H = DI.hessian(rosen, backend, [1.2, 1.2]) + @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) +end From 6533eb9d7dd1a0038bbbf41bd8b154f75af9b0e5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 10 Nov 2025 16:09:21 +0000 Subject: [PATCH 07/19] add more proper tests --- Project.toml | 14 ++- test/interpreter/forward_over_reverse.jl | 104 +++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 118 
insertions(+), 1 deletion(-) create mode 100644 test/interpreter/forward_over_reverse.jl diff --git a/Project.toml b/Project.toml index b1bcd12ca1..9aab907f45 100644 --- a/Project.toml +++ b/Project.toml @@ -79,6 +79,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -87,4 +88,15 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "Logging", "Pkg", "StableRNGs", "Test"] +test = [ + "AllocCheck", + "Aqua", + "BenchmarkTools", + "DiffTests", + "FiniteDifferences", + "JET", + "Logging", + "Pkg", + "StableRNGs", + "Test", +] diff --git a/test/interpreter/forward_over_reverse.jl b/test/interpreter/forward_over_reverse.jl new file mode 100644 index 0000000000..03f4bcc516 --- /dev/null +++ b/test/interpreter/forward_over_reverse.jl @@ -0,0 +1,104 @@ +using StableRNGs: StableRNG +using Base: IEEEFloat +using FiniteDifferences +using Mooncake: ForwardMode, _typeof + +const HESSIAN_CASE_IDS = Set([3, 4, 13, 34]) +const HESSIAN_RTOL = 1e-6 +const HESSIAN_ATOL = 1e-8 +const HESSIAN_FDM = FiniteDifferences.central_fdm(5, 1) + +function low_level_gradient(rule, f, args...) + return Base.tail(Mooncake.value_and_gradient!!(rule, f, args...)[2]) +end + +function _scalar_output(f, args...) + copied = map(TestUtils._deepcopy, args) + return f(copied...) isa IEEEFloat +end + +function _hessian_supported_arg(x) + x isa Number && return true + x isa AbstractArray && return eltype(typeof(x)) <: Number + x isa Tuple && return all(_hessian_supported_arg, x) + x isa NamedTuple && return all(_hessian_supported_arg, x) + return false +end + +function _select_hessian_cases() + selected = Vector{Tuple{Int,Tuple}}() + for (n, case) in enumerate(TestResources.generate_test_functions()) + n in HESSIAN_CASE_IDS || continue + interface_only, _, _, fx... = case + interface_only && continue + f = fx[1] + args = fx[2:end] + _scalar_output(f, args...) || continue + any(!_hessian_supported_arg(arg) for arg in args) && continue + push!(selected, (n, case)) + end + return selected +end + +_as_tuple(x::Tuple) = x +_as_tuple(x) = (x,) + +function _isapprox_nested(a, b; atol, rtol) + if a isa Number && b isa Number + return isapprox(a, b; atol, rtol) + elseif a isa AbstractArray && b isa AbstractArray + return isapprox(a, b; atol, rtol) + elseif a isa Tuple && b isa Tuple + return length(a) == length(b) && + all(_isapprox_nested(ai, bi; atol, rtol) for (ai, bi) in zip(a, b)) + else + return a == b + end +end + +@testset "forward-over-reverse Hessian AD" begin + cases = _select_hessian_cases() + @info "forward-over-reverse Hessian cases" total = length(cases) + @test !isempty(cases) + for (n, (interface_only, perf_flag, _, fx...)) in cases + interp = Mooncake.get_interpreter(ForwardMode) + f = fx[1] + args = map(TestUtils._deepcopy, fx[2:end]) + rule = Mooncake.build_rrule(fx...) + sig = Tuple{ + typeof(low_level_gradient),typeof(rule),_typeof(f),map(_typeof, args)... 
+ } + grad_rule = Mooncake.build_frule(interp, sig) + rng = StableRNG(0xF0 + n) + dirs = map(arg -> Mooncake.randn_tangent(rng, arg), args) + if any(dir -> dir isa Mooncake.NoTangent, dirs) + @testset "$n - $(_typeof((fx)))" begin + @test_broken true + end + continue + end + + dual_inputs = ( + Mooncake.Dual(low_level_gradient, Mooncake.zero_tangent(low_level_gradient)), + Mooncake.Dual(rule, Mooncake.zero_tangent(rule)), + Mooncake.Dual(f, Mooncake.zero_tangent(f)), + map((arg, dir) -> Mooncake.Dual(arg, dir), args, dirs)..., + ) + dual_result = grad_rule(dual_inputs...) + pushforward = _as_tuple(Mooncake.tangent(dual_result)) + + fd_ref = FiniteDifferences.jvp( + HESSIAN_FDM, + x -> _as_tuple(low_level_gradient(rule, f, x...)), + (Tuple(args), Tuple(dirs)), + ) + reference = _as_tuple(fd_ref) + + @testset "$n - $(_typeof((fx)))" begin + @test length(pushforward) == length(reference) + for (pf, fd) in zip(pushforward, reference) + @test _isapprox_nested(pf, fd; atol=HESSIAN_ATOL, rtol=HESSIAN_RTOL) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d8e05c89ba..dfc8288ed5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ include("front_matter.jl") include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "forward_mode.jl")) include(joinpath("interpreter", "reverse_mode.jl")) + include(joinpath("interpreter", "forward_over_reverse.jl")) end include("tools_for_rules.jl") include("interface.jl") From 727481af91619136fe2edddcc7a7a40968e32a8a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 11 Nov 2025 08:21:34 +0000 Subject: [PATCH 08/19] enable all test cases --- Project.toml | 14 +--- test/interpreter/forward_over_reverse.jl | 101 +++++++++++------------ 2 files changed, 48 insertions(+), 67 deletions(-) diff --git a/Project.toml b/Project.toml index 9aab907f45..b1bcd12ca1 100644 --- a/Project.toml +++ b/Project.toml @@ -79,7 +79,6 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -88,15 +87,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = [ - "AllocCheck", - "Aqua", - "BenchmarkTools", - "DiffTests", - "FiniteDifferences", - "JET", - "Logging", - "Pkg", - "StableRNGs", - "Test", -] +test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "Logging", "Pkg", "StableRNGs", "Test"] diff --git a/test/interpreter/forward_over_reverse.jl b/test/interpreter/forward_over_reverse.jl index 03f4bcc516..ca487cba89 100644 --- a/test/interpreter/forward_over_reverse.jl +++ b/test/interpreter/forward_over_reverse.jl @@ -1,49 +1,29 @@ using StableRNGs: StableRNG using Base: IEEEFloat -using FiniteDifferences -using Mooncake: ForwardMode, _typeof +using Mooncake: ForwardMode, _typeof, _add_to_primal, _scale, _diff -const HESSIAN_CASE_IDS = Set([3, 4, 13, 34]) -const HESSIAN_RTOL = 1e-6 -const HESSIAN_ATOL = 1e-8 -const HESSIAN_FDM = FiniteDifferences.central_fdm(5, 1) +function finite_diff_jvp(func, x, dx, unsafe_perturb::Bool=false) + ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] -function low_level_gradient(rule, f, args...) 
- return Base.tail(Mooncake.value_and_gradient!!(rule, f, args...)[2]) -end - -function _scalar_output(f, args...) - copied = map(TestUtils._deepcopy, args) - return f(copied...) isa IEEEFloat -end + results = map(ε_list) do ε + x_plus = _add_to_primal(x, _scale(ε, dx), unsafe_perturb) + x_minus = _add_to_primal(x, _scale(-ε, dx), unsafe_perturb) + y_plus = func(x_plus...) + y_minus = func(x_minus...) + return _scale(1 / (2ε), _diff(y_plus, y_minus)) + end -function _hessian_supported_arg(x) - x isa Number && return true - x isa AbstractArray && return eltype(typeof(x)) <: Number - x isa Tuple && return all(_hessian_supported_arg, x) - x isa NamedTuple && return all(_hessian_supported_arg, x) - return false + return results end -function _select_hessian_cases() - selected = Vector{Tuple{Int,Tuple}}() - for (n, case) in enumerate(TestResources.generate_test_functions()) - n in HESSIAN_CASE_IDS || continue - interface_only, _, _, fx... = case - interface_only && continue - f = fx[1] - args = fx[2:end] - _scalar_output(f, args...) || continue - any(!_hessian_supported_arg(arg) for arg in args) && continue - push!(selected, (n, case)) - end - return selected +function low_level_gradient(rule, f, args...) + return Base.tail(Mooncake.value_and_gradient!!(rule, f, args...)[2]) end _as_tuple(x::Tuple) = x _as_tuple(x) = (x,) -function _isapprox_nested(a, b; atol, rtol) +function _isapprox_nested(a, b; atol=1e-8, rtol=1e-6) if a isa Number && b isa Number return isapprox(a, b; atol, rtol) elseif a isa AbstractArray && b isa AbstractArray @@ -57,14 +37,18 @@ function _isapprox_nested(a, b; atol, rtol) end @testset "forward-over-reverse Hessian AD" begin - cases = _select_hessian_cases() - @info "forward-over-reverse Hessian cases" total = length(cases) - @test !isempty(cases) - for (n, (interface_only, perf_flag, _, fx...)) in cases + @testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in + collect( + enumerate(TestResources.generate_test_functions()) + ) + # Skip interface-only tests as they don't have implementations + interface_only && continue + + @info "$n: $(_typeof((f, x...)))" + interp = Mooncake.get_interpreter(ForwardMode) - f = fx[1] - args = map(TestUtils._deepcopy, fx[2:end]) - rule = Mooncake.build_rrule(fx...) + args = map(TestUtils._deepcopy, x) + rule = Mooncake.build_rrule(f, x...) sig = Tuple{ typeof(low_level_gradient),typeof(rule),_typeof(f),map(_typeof, args)... } @@ -72,9 +56,7 @@ end rng = StableRNG(0xF0 + n) dirs = map(arg -> Mooncake.randn_tangent(rng, arg), args) if any(dir -> dir isa Mooncake.NoTangent, dirs) - @testset "$n - $(_typeof((fx)))" begin - @test_broken true - end + @test_broken true continue end @@ -87,18 +69,29 @@ end dual_result = grad_rule(dual_inputs...) 
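         # The tangent of `dual_result` is the pushforward (JVP) of the gradient along
         # `dirs`; for scalar-valued primals this is a Hessian-vector product.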
pushforward = _as_tuple(Mooncake.tangent(dual_result)) - fd_ref = FiniteDifferences.jvp( - HESSIAN_FDM, - x -> _as_tuple(low_level_gradient(rule, f, x...)), - (Tuple(args), Tuple(dirs)), - ) - reference = _as_tuple(fd_ref) + # Use our own finite difference JVP implementation + grad_func = x -> _as_tuple(low_level_gradient(rule, f, x...)) + fd_results_all = finite_diff_jvp(grad_func, args, dirs) + + # Check if any epsilon value gives a close match (following test_utils.jl pattern) + # Convert each result to tuple form for comparison + fd_results_tuples = map(res -> _as_tuple(res), fd_results_all) + + # Check which epsilon values give close results + isapprox_results = map(fd_results_tuples) do fd_ref + return all(_isapprox_nested(pf, fd) # uses default atol=1e-8, rtol=1e-6 + for (pf, fd) in zip(pushforward, fd_ref)) + end - @testset "$n - $(_typeof((fx)))" begin - @test length(pushforward) == length(reference) - for (pf, fd) in zip(pushforward, reference) - @test _isapprox_nested(pf, fd; atol=HESSIAN_ATOL, rtol=HESSIAN_RTOL) + @test length(pushforward) == length(first(fd_results_tuples)) + # At least one epsilon value should give a close result + if !any(isapprox_results) + # If none match, display values for debugging (like test_utils.jl does) + println("No epsilon gave close result. AD vs FD for each epsilon:") + for (i, fd_ref) in enumerate(fd_results_tuples) + println(" ε[$(i)]: AD=$pushforward, FD=$fd_ref") end end + @test any(isapprox_results) end end From b028e78a6dd5182676398eed2e55e48d8f92f3e9 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 14 Nov 2025 10:52:08 +0000 Subject: [PATCH 09/19] remove claude related files --- .claude/settings.local.json | 3 - CLAUDE.md | 9 - MOONCAKE_CODEBASE_MAP.md | 1003 ----------------------------------- 3 files changed, 1015 deletions(-) delete mode 100644 .claude/settings.local.json delete mode 100644 CLAUDE.md delete mode 100644 MOONCAKE_CODEBASE_MAP.md diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 165ebd5c83..0000000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "model": "opus" -} diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 3c95f29cdd..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,9 +0,0 @@ -read `MOONCAKE_CODEBASE_MAP.md` to get an idea what are you working with - -read `test/runtests.jl` to understand how to run test before you run any test - -it is important that the code you write conform to the existing code in this repo - -the test infra of Mooncake is strict, and well structured, when you add tests, you must use existing infra as much as you can - -you should never commit or push without asking the user \ No newline at end of file diff --git a/MOONCAKE_CODEBASE_MAP.md b/MOONCAKE_CODEBASE_MAP.md deleted file mode 100644 index 458770cbc3..0000000000 --- a/MOONCAKE_CODEBASE_MAP.md +++ /dev/null @@ -1,1003 +0,0 @@ -# Mooncake.jl Codebase - Comprehensive Conceptual Map - -> **Documentation**: See docs/src/index.md for project overview, docs/src/tutorial.md for usage examples - -## Quick Start Guide - -### **First-Time Readers: Start Here** - -If you're new to the codebase, read in this order: -1. **docs/src/understanding_mooncake/introduction.md** - Prerequisites and what AD is -2. **docs/src/understanding_mooncake/algorithmic_differentiation.md** - Math foundations (skim if familiar with AD) -3. **docs/src/understanding_mooncake/rule_system.md** - Core rule interface (critical!) -4. 
**src/tangents.jl:1-100** - Skim to understand tangent types -5. **src/fwds_rvs_data.jl:1-100** - Skim to understand fdata/rdata split -6. **Pick a task below** and follow the relevant reading path - -### **Reading Paths by Goal** - -#### **Goal: Use Mooncake.jl** -1. docs/src/tutorial.md - DifferentiationInterface.jl usage -2. docs/src/interface.md - Native API if DI.jl insufficient -3. docs/src/known_limitations.md - What doesn't work - -#### **Goal: Debug an AD Issue** -1. docs/src/utilities/debugging_and_mwes.md - How to create MWEs -2. docs/src/utilities/debug_mode.md - Enable type checking -3. src/developer_tools.jl - Inspect generated IR -4. docs/src/developer_documentation/developer_tools.md - Using IR tools - -#### **Goal: Add a Primitive Rule** -1. docs/src/utilities/defining_rules.md - Complete guide -2. docs/src/understanding_mooncake/rule_system.md - Rule interface -3. Look at similar rule in src/rrules/ for pattern -4. src/test_utils.jl - Use `test_rule` for testing - -#### **Goal: Support a Custom Type** -1. docs/src/developer_documentation/custom_tangent_type.md - Complete guide -2. src/tangents.jl:302-496 - See default tangent_type implementation -3. ext/MooncakeDynamicExpressionsExt.jl - Complex real-world example -4. src/test_utils.jl - Use `test_data` for verification - -#### **Goal: Understand How AD Works Internally** -1. docs/src/understanding_mooncake/algorithmic_differentiation.md - Math -2. docs/src/developer_documentation/reverse_mode_design.md - Compilation overview -3. docs/src/developer_documentation/ir_representation.md - IR basics -4. src/interpreter/reverse_mode.jl:1152-1196 - Read `generate_ir` -5. src/interpreter/reverse_mode.jl:394-861 - Read `make_ad_stmts!` for one statement type - -#### **Goal: Fix a Performance Issue** -1. Profile to identify bottleneck -2. src/developer_tools.jl - Inspect generated IR -3. src/rrules/performance_patches.jl - See example performance rule -4. 
Consider adding primitive to DefaultCtx (not MinimalCtx) - -## Glossary of Terms - -| Term | Meaning | Location | -|------|---------|----------| -| **Primal** | Original computation being differentiated | Throughout | -| **Tangent** | Derivative information; input/output of `D f[x]` | src/tangents.jl:1-1426 | -| **Cotangent** | Adjoint/gradient information; input/output of `D f[x]*` | Throughout | -| **CoDual** | Pairs primal with fdata for reverse-mode | src/codual.jl:1-124 | -| **Dual** | Pairs primal with tangent for forward-mode | src/dual.jl:1-58 | -| **FData** | Forward data - mutable components of tangent | src/fwds_rvs_data.jl | -| **RData** | Reverse data - immutable components of tangent | src/fwds_rvs_data.jl | -| **Rule** | Function that computes AD (hand-written or derived) | Throughout | -| **Primitive** | Function with hand-written rule | src/interpreter/contexts.jl | -| **Derived Rule** | Auto-generated rule from IR transformation | src/interpreter/reverse_mode.jl | -| **Pullback** | Reverse-pass function that propagates gradients | Throughout | -| **IRCode** | Julia's SSA intermediate representation | Core.Compiler | -| **BBCode** | Mooncake's basic-block IR representation | src/interpreter/bbcode.jl | -| **SSA** | Static Single Assignment - each variable assigned once | Throughout | -| **PhiNode** | Merges values from different control flow paths | IR nodes | -| **Block Stack** | Tracks which blocks were visited (reverse-mode) | src/interpreter/reverse_mode.jl:81 | -| **Activity Analysis** | Determining what's differentiable vs constant | Implicit throughout | -| **SROA** | Scalar Replacement of Aggregates - compiler optimization | Julia compiler | - -## Data Flow Overview - -### **Reverse-Mode: value_and_gradient!! Call Chain** - -``` -User calls: value_and_gradient!!(rule, f, x...) - ↓ -interface.jl:169 → __value_and_gradient!!(rule, CoDual(f, tangent(f)), CoDual(x, tangent(x))...) - ↓ -interface.jl:104 → rule(map(to_fwds, coduals)...) # to_fwds extracts fdata - ↓ -FORWARD PASS (generated by build_rrule): - DerivedRule.fwds_oc(args...) runs, executing: - - Extract shared data from captures - - Create rdata Refs (optimized away by SROA) - - Run transformed statements (rrule!! calls) - - Push to block stack when needed - - Push intermediate values to communication stacks - - Return: CoDual(result, fdata), pullback - ↓ -interface.jl:110 → pullback(one(result)) # or pullback(rdata(ȳ)) - ↓ -REVERSE PASS (generated by build_rrule): - Pullback.pb_oc(dy) runs, executing: - - Pop communication stacks - - Pop block stack to determine control flow - - Run transformed statements in reverse - - Increment rdata Refs - - Return: tuple of rdata for all arguments - ↓ -interface.jl:110 → tangent(fdata, rdata) for each argument - ↓ -Return: (value, (NoTangent(), gradient_x1, gradient_x2, ...)) -``` - -### **Where Time is Spent (Performance Model)** - -1. **First call**: Rule compilation dominates (~90%+ of time) - - `build_rrule` generates and optimizes IR - - Stored in `interp.oc_cache` for reuse - -2. **Subsequent calls**: Execution time depends on: - - **Primitives** (~fast) - Hand-written rules are typically well-optimized - - **Block stack ops** (~10-30% overhead) - Reduced via unique predecessor optimization - - **Communication stacks** (~5-15% overhead) - Reduced via SingletonStack optimization - - **Memory operations** (~fast) - SROA eliminates most Ref allocations - -3. 
**Slowest operations** (in derived rules): - - Dynamic dispatch (`DynamicDerivedRule`) - - Type-unstable code - - Lots of small functions without primitives - - Large loops with value-dependent control flow - -## Critical Files (Read These First) - -### **Core Type System (Essential)** -1. **src/tangents.jl** - Tangent type system (~1400 lines, skim structure) - - Lines 1-100: Type definitions - - Lines 302-496: `tangent_type` implementation - - Lines 508-605: `zero_tangent_internal` implementation - -2. **src/fwds_rvs_data.jl** - FData/RData split (~1000 lines, skim) - - Lines 1-100: FData/RData type definitions - - Lines 155-202: `fdata_type` implementation - - Lines 433-476: `rdata_type` implementation - -3. **src/codual.jl** - CoDual type (~120 lines, read fully) - -### **Rule System (Essential)** -4. **src/interface.jl** - Public API (read lines 1-200 fully) - - Understanding `value_and_gradient!!` flow is critical - -5. **src/interpreter/contexts.jl** - What makes something primitive (~120 lines, read fully) - -### **Rule Derivation (Important)** -6. **src/interpreter/reverse_mode.jl** - Read selectively: - - Lines 1-200: Data structures (SharedDataPairs, ADInfo, etc.) - - Lines 394-500: `make_ad_stmts!` for one example (e.g., ReturnNode) - - Lines 1044-1196: `build_rrule` and `generate_ir` flow - -7. **src/interpreter/bbcode.jl** - BBCode representation (skim if working on IR) - -### **Example Rules (Learn by Example)** -8. **src/rrules/low_level_maths.jl** - Simple rules (read lines 1-200) -9. **src/rrules/blas.jl** - Complex rules (read one example, e.g., gemm!) -10. **src/rrules/builtins.jl** - Essential primitives (skim IntrinsicsWrappers module) - -## 1. **Core Architecture** - -### **Main Entry Point** -- **src/Mooncake.jl:1-172** - Module definition, includes all submodules, defines core functions (`frule!!`, `rrule!!`, `build_primitive_rrule`) -- **Documentation**: docs/src/index.md (getting started), docs/src/tutorial.md (DifferentiationInterface.jl usage) - -### **Type System Hierarchy** - -#### **Primal-Tangent Pairing** -- **src/dual.jl:1-58** - `Dual{P,T}` for forward-mode (primal + tangent) -- **src/codual.jl:1-124** - `CoDual{Tx,Tdx}` for reverse-mode (primal + fdata) - -#### **Tangent Types** -- **src/tangents.jl:1-1426** - Core tangent system: - - `NoTangent` - for non-differentiable types - - `Tangent{NamedTuple}` - for immutable structs - - `MutableTangent{NamedTuple}` - for mutable structs - - `PossiblyUninitTangent{T}` - for potentially undefined fields - - Functions: `tangent_type`, `zero_tangent`, `randn_tangent`, `increment!!`, `set_to_zero!!` - -#### **FData/RData Splitting** -- **src/fwds_rvs_data.jl:1-1026** - Separates tangents into: - - **FData** (forward data) - mutable/address-identified components, propagated on forward pass - - **RData** (reverse data) - immutable/value-identified components, propagated on reverse pass - - Key functions: `fdata_type`, `rdata_type`, `fdata`, `rdata`, `tangent(f, r)` - - `ZeroRData` handling in **src/interpreter/zero_like_rdata.jl:1-40** -- **Documentation**: docs/src/understanding_mooncake/rule_system.md - "Representing Gradients" section explains fdata/rdata design - -## 2. **Rule System** - -> **Documentation**: docs/src/understanding_mooncake/rule_system.md - complete rule interface specification - -### **Rule Interface** -- **rrule!! signature**: `(::CoDual{typeof(f)}, args::CoDual...) -> (CoDual{output}, pullback_function)` -- **frule!! signature**: `(::Dual{typeof(f)}, args::Dual...) 
-> Dual{output}` -- **Documentation**: - - docs/src/understanding_mooncake/rule_system.md - "The Rule Interface" sections - - docs/src/utilities/defining_rules.md - how to write custom rules - -### **Primitive Rules** (src/rrules/) -Hand-written rules for Julia primitives: - -| Category | File | Key Functions | -|----------|------|---------------| -| **Builtins** | src/rrules/builtins.jl:1-1000+ | `getfield`, `setfield!`, `tuple`, `===`, `isa`, `typeof`, `svec`, `ifelse` | -| **Intrinsics** | Module `IntrinsicsWrappers` in builtins.jl | `add_float`, `mul_float`, `div_float`, `neg_float`, `sqrt_llvm`, `fma_float`, `bitcast` | -| **Foreign calls** | src/rrules/foreigncall.jl:1-473 | `_foreigncall_`, `pointer_from_objref`, `unsafe_pointer_to_objref`, `unsafe_copyto!` | -| **Construction** | src/rrules/new.jl:1-212 | `_new_` (all object construction), `_splat_new_` | -| **Arrays (1.10)** | src/rrules/array_legacy.jl:1-666 | `arrayref`, `arrayset`, `_deletebeg!`, `_deleteend!`, `_growend!` | -| **Memory (1.11+)** | src/rrules/memory.jl:1-1000+ | `Memory`, `MemoryRef`, `memoryrefget`, `memoryrefset!`, `Array` construction | -| **BLAS** | src/rrules/blas.jl:1-1000+ | `gemm!`, `gemv!`, `symm!`, `symv!`, `trmv!`, `dot`, `nrm2`, `scal!`, `syrk!` | -| **LAPACK** | src/rrules/lapack.jl:1-643 | `getrf!`, `getrs!`, `getri!`, `trtrs!`, `potrf!`, `potrs!` | -| **Linear Algebra** | src/rrules/linear_algebra.jl:1-52 | `exp(::Matrix)` | -| **Low-level Math** | src/rrules/low_level_maths.jl:1-305 | `exp`, `log`, `sin`, `cos`, `tan`, `sqrt`, `cbrt`, `hypot`, etc. | -| **FastMath** | src/rrules/fastmath.jl:1-162 | Fast versions of math functions | -| **TwicePrecision** | src/rrules/twice_precision.jl:1-524 | `TwicePrecision`, `StepRangeLen`, range operations | -| **Random** | src/rrules/random.jl:1-84 | `randn`, `randexp`, `MersenneTwister` | -| **Tasks** | src/rrules/tasks.jl:1-146 | `Task`, `current_task` (limited support) | -| **IdDict** | src/rrules/iddict.jl:1-247 | `IdDict` operations | -| **MistyClosure** | src/rrules/misty_closures.jl:1-162 | Differentiation of closures with captured variables | -| **Performance** | src/rrules/performance_patches.jl:1-72 | Optimized `sum` for arrays | -| **Misc** | src/rrules/misc.jl:1-398 | `lgetfield`, `lsetfield!`, logging, string ops | -| **Avoidance** | src/rrules/avoiding_non_differentiable_code.jl:1-225 | Pointer arithmetic, logging macros, `@zero_derivative` rules | - -### **Rule Derivation** (Automatic Differentiation) - -> **Documentation**: -> - docs/src/developer_documentation/forwards_mode_design.md - forward-mode internals (planned) -> - docs/src/developer_documentation/reverse_mode_design.md - reverse-mode compilation process -> - docs/src/understanding_mooncake/algorithmic_differentiation.md - mathematical foundations - -#### **Forward Mode** -- **src/interpreter/forward_mode.jl:1-507** - Derives `frule!!` from IR: - - `build_frule` - main entry point - - `generate_dual_ir` - transforms IR to compute derivatives - - `DerivedFRule` - wrapper for derived forward rules - - `LazyFRule` - lazy rule construction - - `DynamicFRule` - dynamic dispatch -- **Documentation**: docs/src/developer_documentation/forwards_mode_design.md - detailed design document - -#### **Reverse Mode** -- **src/interpreter/reverse_mode.jl:1-1875** - Derives `rrule!!` from IR: - - `build_rrule` - main entry point (reverse_mode.jl:1044-1144) - - `generate_ir` - creates forward + reverse IR (reverse_mode.jl:1152-1196) - - `make_ad_stmts!` - transforms each IR statement 
(reverse_mode.jl:394-861) - - `DerivedRule` - wrapper for derived reverse rules (reverse_mode.jl:934-960) - - `Pullback` - callable that runs reverse pass (reverse_mode.jl:918-932) - - `LazyDerivedRule` - lazy rule construction (reverse_mode.jl:1816-1842) - - `DynamicDerivedRule` - dynamic dispatch (reverse_mode.jl:1726-1752) - - `SharedDataPairs` - manages captured data (reverse_mode.jl:13-72) - - `ADInfo` - global context for rule derivation (reverse_mode.jl:123-206) - - `BlockStack` - tracks control flow (reverse_mode.jl:81) -- **Documentation**: - - docs/src/developer_documentation/reverse_mode_design.md - compilation overview - - docs/src/understanding_mooncake/algorithmic_differentiation.md - "Reverse-Mode AD: how does it do it?" - -## 3. **IR Manipulation** - -> **Documentation**: docs/src/developer_documentation/ir_representation.md - comprehensive guide to IRCode vs BBCode - -### **IR Representations** -- **src/interpreter/bbcode.jl:1-1010** - `BBCode` data structure: - - `BBlock` - basic block with unique IDs (bbcode.jl:194-265) - - `ID` - unique identifier for blocks/statements (bbcode.jl:79-86) - - `IDPhiNode`, `IDGotoNode`, `IDGotoIfNot` - ID-based control flow (bbcode.jl:108-138) - - `Switch` - multi-way branch statement (bbcode.jl:160-168) - - Conversion: `BBCode(::IRCode)` and `IRCode(::BBCode)` (bbcode.jl:528-657) -- **Documentation**: docs/src/developer_documentation/ir_representation.md - - Julia's SSA-form IR explanation - - Control flow and PhiNodes - - Code transformation examples (replacing instructions, inserting blocks) - - When to use IRCode vs BBCode - -### **IR Utilities** -- **src/interpreter/ir_utils.jl:1-334** - IR manipulation: - - `stmt` - get statement from IR (ir_utils.jl:9) - - `set_stmt!`, `get_ir`, `set_ir!`, `replace_call!` (ir_utils.jl:18-36) - - `ircode` - construct IRCode for testing (ir_utils.jl:56-65) - - `infer_ir!` - run type inference (ir_utils.jl:101-126) - - `optimise_ir!` - optimization pipeline (ir_utils.jl:146-188) - - `lookup_ir` - get IR from signature/MethodInstance (ir_utils.jl:206-254) - -### **IR Normalization** -- **src/interpreter/ir_normalisation.jl:1-495** - Standardize IR: - - `normalise!` - main entry (ir_normalisation.jl:23-44) - - `foreigncall_to_call` - `:foreigncall` → `_foreigncall_()` call (ir_normalisation.jl:144-158) - - `new_to_call` - `:new` → `_new_()` call (ir_normalisation.jl:218) - - `splatnew_to_call` - `:splatnew` → `_splat_new_()` call (ir_normalisation.jl:229) - - `intrinsic_to_function` - intrinsics → `IntrinsicsWrappers` (ir_normalisation.jl:244-256) - - `lift_getfield_and_others` - constant field access → `lgetfield` (ir_normalisation.jl:267-290) - - `lift_gc_preservation` - GC preservation handling (ir_normalisation.jl:403-407) - - `const_prop_gotoifnots!` - constant propagation for branches (ir_normalisation.jl:416-432) - -### **Compiler Integration** -- **src/interpreter/abstract_interpretation.jl:1-223** - Custom interpreter: - - `MooncakeInterpreter{C,M}` - subtype of `AbstractInterpreter` (abstract_interpretation.jl:27-66) - - `ClosureCacheKey` - cache key for closures (abstract_interpretation.jl:13-16) - - `inlining_policy` - prevents primitive inlining (abstract_interpretation.jl:159-196) - - `get_interpreter` - returns cached interpreter (abstract_interpretation.jl:217-222) - - `GLOBAL_INTERPRETERS` - cached interpreters (abstract_interpretation.jl:204-207) -- **Documentation**: docs/src/developer_documentation/reverse_mode_design.md - explains `MooncakeInterpreter` role - -### **Contexts** -- 
**src/interpreter/contexts.jl:1-119** - AD contexts: - - `MinimalCtx` - only essential primitives (contexts.jl:8) - - `DefaultCtx` - all performance primitives (contexts.jl:17) - - `ForwardMode`, `ReverseMode` - AD mode markers (contexts.jl:32, 39) - - `is_primitive` - determines if function is primitive (contexts.jl:58-61) - - `@is_primitive` - macro to declare primitives (contexts.jl:69-118) -- **Documentation**: docs/src/developer_documentation/reverse_mode_design.md - distinction between MinimalCtx and DefaultCtx - -### **Compiler Patches** -- **src/interpreter/patch_for_319.jl:1-435** - Workarounds for Julia compiler bugs (issue #319) - -## 4. **Public Interface** - -> **Documentation**: docs/src/interface.md - complete public API reference - -- **src/interface.jl:1-588** - User-facing API: - - `value_and_gradient!!(rule, f, x...)` - compute gradient (interface.jl:169-171) - - `value_and_pullback!!(rule, ȳ, f, x...)` - compute pullback (interface.jl:142-144) - - `value_and_derivative!!(rule, f, x...)` - forward-mode (interface.jl:581) - - `prepare_gradient_cache` - pre-compile for performance (interface.jl:515-522) - - `prepare_pullback_cache` - pre-compile for pullback (interface.jl:439-458) - - `prepare_derivative_cache` - pre-compile for forward-mode (interface.jl:572) - - `__value_and_gradient!!`, `__value_and_pullback!!` - lower-level internal versions -- **Documentation**: - - docs/src/interface.md - public API docstrings - - docs/src/tutorial.md - usage examples with DifferentiationInterface.jl - -- **src/config.jl:1-18** - Configuration: - - `Config(; debug_mode, silence_debug_messages)` -- **Documentation**: docs/src/utilities/debug_mode.md - when and how to use debug mode - -- **src/public.jl:1-15** - Public API macro for Julia 1.11+ - -## 5. 
**Utilities & Tools** - -### **Rule Definition Helpers** -- **src/tools_for_rules.jl:1-698** - Macros and utilities: - - `@mooncake_overlay` - override function for AD (tools_for_rules.jl:104-112) - - `@zero_derivative` - mark functions with zero derivative (tools_for_rules.jl:248-302) - - `@zero_adjoint` - reverse-mode specific (tools_for_rules.jl:310-312) - - `@from_chainrules` - import ChainRules rrules (tools_for_rules.jl:628-687) - - `@from_rrule` - import specific rrule (tools_for_rules.jl:695-697) - - `zero_adjoint`, `zero_derivative` - function versions (tools_for_rules.jl:148-177) - - `to_cr_tangent`, `mooncake_tangent` - ChainRules conversion (tools_for_rules.jl:323-373) -- **Documentation**: docs/src/utilities/defining_rules.md - - Complete guide to all rule-writing strategies - - `@mooncake_overlay` examples - - `@zero_adjoint` usage - - `@from_rrule` / `@from_chainrules` with worked examples - - When to implement custom `rrule!!` - -### **Testing Infrastructure** -- **src/test_utils.jl:1-1680** - Comprehensive testing: - - `test_rule` - main testing function (test_utils.jl:895-986) - - `test_tangent_interface` - test tangent operations (test_utils.jl:1111-1233) - - `test_tangent_splitting` - test fdata/rdata split (test_utils.jl:1439-1522) - - `test_rule_and_type_interactions` - test primitives work (test_utils.jl:1553-1580) - - `test_data` - combined test (test_utils.jl:1672-1677) - - `has_equal_data` - structural equality (test_utils.jl:201-327) - - `populate_address_map` - track aliasing (test_utils.jl:338-414) -- **Documentation**: - - docs/src/utilities/debugging_and_mwes.md - using `test_rule` for debugging - - docs/src/developer_documentation/tangents.md - testing functions explained - - docs/src/developer_documentation/running_tests_locally.md - local testing workflow - -- **src/test_resources.jl:1-995** - Test data: - - Module `TestResources` with test types (test_resources.jl:8-989) - - `generate_test_functions` - standard test cases (test_resources.jl:699-929) - - Test types: `StructFoo`, `MutableFoo`, `TypeStableMutableStruct`, etc. 
- -### **Developer Tools** -- **src/developer_tools.jl:1-155** - IR inspection: - - `primal_ir` - get primal IR (developer_tools.jl:22-24) - - `dual_ir` - get forward-mode IR (developer_tools.jl:61-68) - - `fwd_ir` - get forward pass IR (developer_tools.jl:104-111) - - `rvs_ir` - get reverse pass IR (developer_tools.jl:147-154) -- **Documentation**: docs/src/developer_documentation/developer_tools.md - IR inspection guide - -### **General Utilities** -- **src/utils.jl:1-457** - Helper functions: - - `_typeof` - stable typeof (utils.jl:6-8) - - `tuple_map` - specialized tuple mapping (utils.jl:26-50) - - `always_initialised` - field initialization info (utils.jl:218-222) - - `lgetfield` - literal getfield with `Val` (utils.jl:250) - - `lsetfield!` - literal setfield with `Val` (utils.jl:261) - - `_new_` - direct `:new` instruction (utils.jl:268-270) - - `opaque_closure`, `misty_closure` - closure construction (utils.jl:326-367) - -### **Data Structures** -- **src/stack.jl:1-40** - Specialized stack: - - `Stack{T}` - never-deallocating stack for reverse pass (stack.jl:8-34) - - `SingletonStack{T}` - zero-overhead singleton stack (stack.jl:36-39) - -### **Debug Mode** -- **src/debug_mode.jl:1-124** - Runtime type checking: - - `DebugRRule` - wraps rules with type checks (debug_mode.jl:77-95) - - `DebugPullback` - wraps pullbacks with type checks (debug_mode.jl:14-31) - - `DebugFRule` - forward-mode equivalent (debug_mode.jl:2) -- **Documentation**: docs/src/utilities/debug_mode.md - - When to use debug mode - - How it catches type errors - - Performance implications - -## 6. **Extensions** (ext/) - -| Extension | Purpose | Key Types/Functions | -|-----------|---------|---------------------| -| **MooncakeCUDAExt.jl** | CUDA support | `CuArray` tangent ops, allocation rules | -| **MooncakeAllocCheckExt.jl** | Allocation checking | `check_allocs_internal` | -| **MooncakeJETExt.jl** | Type stability | `test_opt_internal`, `report_opt_internal` | -| **MooncakeLuxLibExt.jl** | LuxLib ops | `matmul`, `conv`, `batchnorm` overlays | -| **MooncakeLuxLibSLEEFPiratesExtension.jl** | Fast activations | `sigmoid_fast`, `tanh_fast`, etc. | -| **MooncakeNNlibExt.jl** | NNlib ops | `conv`, `pooling`, `dropout`, `softmax` | -| **MooncakeSpecialFunctionsExt.jl** | Special functions | Bessel, gamma, erf functions via ChainRules | -| **MooncakeFluxExt.jl** | Flux support | Optimized `mse` loss | -| **MooncakeFunctionWrappersExt.jl** | FunctionWrapper | Custom tangent with AD through wrapper | -| **MooncakeDynamicExpressionsExt.jl** | Symbolic expressions | `TangentNode` for expression trees | - -## 7. **Key Algorithms & Concepts** - -### **Forward-Mode AD Workflow** -1. `build_frule(sig)` → generates rule -2. `lookup_ir` → get primal IR -3. `normalise!` → standardize IR -4. Transform each statement: Arguments+1, wrap in `Dual`, replace calls with `frule!!` -5. `optimise_ir!` → optimize generated IR -6. Wrap in `MistyClosure` → `DerivedFRule` - -### **Reverse-Mode AD Workflow** -1. `build_rrule(sig)` → generates rule -2. `lookup_ir` → get primal IR -3. `normalise!` → standardize IR -4. `BBCode(ir)` → convert to basic block form -5. `ADInfo` construction → setup metadata -6. `make_ad_stmts!` → transform each statement into forward/reverse instructions -7. `forwards_pass_ir` → generate forward pass (reverse_mode.jl:1294-1356) -8. `pullback_ir` → generate reverse pass (reverse_mode.jl:1376-1523) -9. `optimise_ir!` → optimize both passes -10. 
Wrap in `MistyClosure`s → `DerivedRule` + `Pullback` - -### **Critical Implementation Details** - -**Block Stack** (reverse_mode.jl:81): -- Tracks which blocks were visited during forward pass -- Used to determine control flow on reverse pass -- Optimized away for unique predecessors - -**Shared Data** (reverse_mode.jl:13-72): -- Data shared between forward/reverse passes -- Stored in `OpaqueClosure` captures -- Contains block stack, communication stacks, lazy zero rdata - -**Communication Channels** (reverse_mode.jl:1251-1287): -- Per-block stacks storing intermediate values -- Push on forward pass, pop on reverse pass -- Optimized to `SingletonStack` when possible - -**RData Refs** (reverse_mode.jl:259-273): -- Each SSA/Argument gets a `Ref` to accumulate gradients -- Initialized to zero, incremented during reverse pass -- Optimized away by SROA pass - -## 8. **File Organization Summary** - -``` -Mooncake.jl/ -├── src/ -│ ├── Mooncake.jl # Main module -│ ├── tangents.jl # Tangent type system -│ ├── fwds_rvs_data.jl # FData/RData splitting -│ ├── dual.jl # Forward-mode Dual type -│ ├── codual.jl # Reverse-mode CoDual type -│ ├── interface.jl # Public API -│ ├── config.jl # Configuration -│ ├── debug_mode.jl # Runtime type checking -│ ├── stack.jl # Block stack -│ ├── utils.jl # General utilities -│ ├── tools_for_rules.jl # Rule definition macros -│ ├── test_utils.jl # Testing infrastructure -│ ├── test_resources.jl # Test data -│ ├── developer_tools.jl # IR inspection tools -│ ├── public.jl # Public API declarations -│ ├── interpreter/ -│ │ ├── contexts.jl # AD contexts -│ │ ├── abstract_interpretation.jl # Custom interpreter -│ │ ├── bbcode.jl # BBCode IR representation -│ │ ├── ir_utils.jl # IR manipulation -│ │ ├── ir_normalisation.jl # IR standardization -│ │ ├── forward_mode.jl # Forward-mode derivation -│ │ ├── reverse_mode.jl # Reverse-mode derivation -│ │ ├── zero_like_rdata.jl # ZeroRData utilities -│ │ └── patch_for_319.jl # Compiler bug workarounds -│ └── rrules/ -│ ├── builtins.jl # Core built-in functions -│ ├── foreigncall.jl # ccall handling -│ ├── new.jl # Object construction -│ ├── misc.jl # lgetfield, lsetfield!, etc. -│ ├── blas.jl # BLAS operations -│ ├── lapack.jl # LAPACK operations -│ ├── linear_algebra.jl # High-level LinAlg -│ ├── low_level_maths.jl # Math functions -│ ├── fastmath.jl # FastMath functions -│ ├── array_legacy.jl # Array ops (1.10) -│ ├── memory.jl # Memory/Array ops (1.11+) -│ ├── random.jl # Random number generation -│ ├── twice_precision.jl # TwicePrecision/ranges -│ ├── tasks.jl # Task (limited) -│ ├── iddict.jl # IdDict operations -│ ├── misty_closures.jl # Closure differentiation -│ ├── performance_patches.jl # Performance optimizations -│ ├── avoiding_non_differentiable_code.jl # Zero derivative rules -│ └── dispatch_doctor.jl # DispatchDoctor integration -├── ext/ # Package extensions -├── test/ # Test suite (mirrors src/) -└── docs/ # Documentation -``` - -## 9. 
**Documentation Structure** (docs/src/) - -### **User-Facing** -- `index.md` - Project overview, getting started, project status -- `tutorial.md` - DifferentiationInterface.jl and native API usage -- `interface.md` - Public API documentation -- `known_limitations.md` - Mutation of globals, recursive types, pointers - -### **Understanding Mooncake** -- `understanding_mooncake/introduction.md` - Prerequisites, who docs are for -- `understanding_mooncake/algorithmic_differentiation.md` - Mathematical foundations - - Fréchet derivatives, adjoints, tangents - - Forward vs reverse mode - - Gradients and directional derivatives -- `understanding_mooncake/rule_system.md` - Core rule interface - - `rrule!!` specification - - CoDual, fdata/rdata system - - Testing with `test_rule` - -### **Utilities** -- `utilities/defining_rules.md` - How to write rules - - `@mooncake_overlay` - code simplification - - `@zero_adjoint` - zero derivative functions - - `@from_rrule` - import ChainRules - - Adding custom `rrule!!` methods -- `utilities/debug_mode.md` - `DebugRRule` for type checking -- `utilities/debugging_and_mwes.md` - Using `TestUtils.test_rule` - -### **Developer Documentation** -- `developer_documentation/running_tests_locally.md` - Test workflow with Revise.jl -- `developer_documentation/developer_tools.md` - IR inspection tools -- `developer_documentation/tangents.md` - Tangent type interface requirements -- `developer_documentation/custom_tangent_type.md` - Detailed guide for recursive types - - Complete worked example with `struct A` containing self-reference - - All required methods: `zero_tangent_internal`, `randn_tangent_internal`, `increment_internal!!`, etc. -- `developer_documentation/ir_representation.md` - IRCode vs BBCode - - SSA-form IR explanation - - Control flow and PhiNodes - - Code transformation examples -- `developer_documentation/forwards_mode_design.md` - Forward-mode AD design (unimplemented) -- `developer_documentation/reverse_mode_design.md` - Compilation process overview -- `developer_documentation/misc_internals_notes.md` - Implementation notes - - `tangent_type` generated function design - - Recursion handling via `LazyDerivedRule` - -## 10. **Testing** (test/) - -### **Test Organization** -- **test/front_matter.jl:1-163** - Common test setup, determines test group -- **test/runtests.jl:1-71** - Main test runner with group selection -- **test/run_extra.jl:1-4** - Integration test runner - -### **Test Files** (mirror src/) -- `test/tangents.jl` - Tangent type tests -- `test/codual.jl` - CoDual tests -- `test/interface.jl` - Public API tests -- `test/rrules/*.jl` - Tests for each rrules file -- `test/interpreter/*.jl` - Tests for interpreter components -- `test/integration_testing/` - Integration with other packages -- `test/ext/` - Extension tests - -## 11. **Worked Example: Tracing a Simple Gradient** - -Let's trace `value_and_gradient!!(rule, f, x)` where `f(x) = sin(x)` and `x = 5.0`: - -### **Preparation** (one-time cost) -```julia -rule = build_rrule(f, 5.0) # In src/interpreter/reverse_mode.jl:1044 -``` -1. `lookup_ir` gets IR for `f(::Float64)` → finds `sin` call -2. `is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64})` → `true` (from low_level_maths.jl) -3. Since it's primitive, `build_rrule` returns `rrule!!` directly (no IR transformation needed) - -### **Execution** (fast, repeated calls) -```julia -value_and_gradient!!(rule, f, 5.0) # In src/interface.jl:169 -``` -1. Creates `CoDual(f, NoTangent())` and `CoDual(5.0, 0.0)` (zero tangent) -2. 
Calls `rule(to_fwds(coduals)...)` → `rrule!!(CoDual(f, NoFData()), CoDual(5.0, NoFData()))` -3. **Forward pass** (in src/rrules/low_level_maths.jl): - - Computes `y = sin(5.0) = -0.9589...` - - Returns `CoDual(-0.9589, NoFData())` and pullback closure -4. **Reverse pass** - calls `pullback(1.0)`: - - Computes `∂x = cos(5.0) * 1.0 = 0.2836...` - - Returns `(NoRData(), 0.2836...)` -5. Reconstructs result: `(-0.9589, (NoTangent(), 0.2836))` - -### **Worked Example: Derived Rule** - -For `g(x) = sin(cos(x))` with `x = 5.0`: - -**build_rrule generates**: -```julia -# Forward pass IR (simplified): -%1 = rrule!!(zero_fcodual(cos), CoDual(x, NoFData())) # y=%1[1], pb_cos=%1[2] -%2 = rrule!!(zero_fcodual(sin), %1[1]) # result=%2[1], pb_sin=%2[2] -push!(comms_stack, %1[2]) # Store cos's pullback -push!(comms_stack, %2[2]) # Store sin's pullback -return %2[1] # Return CoDual(sin(cos(x)), NoFData()) - -# Reverse pass IR (simplified): -pb_sin = pop!(comms_stack) -pb_cos = pop!(comms_stack) -dy_inner = pb_sin(dy) # dy w.r.t cos(x) -dx = pb_cos(dy_inner[2]) # dy w.r.t x -return (NoRData(), dx[2]) # (df, dx) -``` - -## 12. **Key Concepts & Terminology** - -> **Documentation**: docs/src/understanding_mooncake/algorithmic_differentiation.md - complete mathematical treatment - -### **Mathematical Foundation** -- **Fréchet Derivative**: Linear operator `D f[x] : X → Y` satisfying `df = D f[x](dx)` -- **Adjoint**: Linear operator `D f[x]* : Y → X` satisfying `⟨D f[x]*(ȳ), ẋ⟩ = ⟨ȳ, D f[x](ẋ)⟩` -- **Tangent**: Input/output of derivatives (denoted with dot: `ẋ`, `ẏ`) -- **Cotangent/Gradient**: Input/output of adjoints (denoted with bar: `x̄`, `ȳ`) -- **Documentation**: docs/src/understanding_mooncake/algorithmic_differentiation.md - - Derivatives section - scalar to general Hilbert spaces - - Worked examples with matrices and mutable data - - Chain rule and how forward/reverse mode work - - Directional derivatives and gradients - -### **Implementation Details** -- **Primitive**: Function with hand-written rule, not derived from IR -- **Derived Rule**: Rule automatically generated from IR transformation -- **Activity Analysis**: Distinguishing differentiable from non-differentiable data -- **Unique Predecessor**: Block with only one possible incoming edge (optimization opportunity) -- **Lazy Rule Construction**: Defer rule compilation until first use (handles recursion) -- **Dynamic Dispatch**: Runtime rule selection based on argument types -- **Static Dispatch**: Compile-time rule selection via `:invoke` nodes - -### **Performance Optimizations** -1. **SROA** (Scalar Replacement of Aggregates) - eliminates `Ref` allocations -2. **Singleton Type Optimization** - avoids storing singleton pullbacks -3. **Block Stack Elimination** - skips block tracking for unique predecessors -4. **Constant Propagation** - folds constants, eliminates dead branches -5. **Lazy Zero RData** - defers zero construction until needed - -## 13. 
**Common Patterns & Idioms** - -### **Pattern: Accessing Rule-Generated IR** -```julia -# See what IR Mooncake generates: -using Mooncake: get_interpreter, ReverseMode, primal_ir, fwd_ir, rvs_ir - -sig = Tuple{typeof(sin), Float64} -interp = get_interpreter(ReverseMode) - -primal_ir(interp, sig) # Original function IR -fwd_ir(sig) # Forward pass IR -rvs_ir(sig) # Reverse pass IR (pullback) -``` - -### **Pattern: CoDual Construction** -```julia -# Manual CoDual construction (rarely needed): -x = [1.0, 2.0, 3.0] -dx = zero_tangent(x) # Create tangent -codual = CoDual(x, fdata(dx)) # Pair primal with fdata - -# Usually automatic: -value_and_gradient!!(rule, f, x) # Handles CoDual construction internally -``` - -### **Pattern: Checking if Something is Primitive** -```julia -using Mooncake: is_primitive, DefaultCtx, ReverseMode - -sig = Tuple{typeof(sin), Float64} -is_primitive(DefaultCtx, ReverseMode, sig) # true - has hand-written rule - -sig2 = Tuple{typeof(my_function), Float64} -is_primitive(DefaultCtx, ReverseMode, sig2) # false - will derive rule -``` - -### **Pattern: Testing a New Rule** -```julia -using Mooncake.TestUtils: test_rule -using Random: Xoshiro - -# Test your new rule: -test_rule( - Xoshiro(123), # RNG for reproducibility - my_func, # Function - arg1, arg2; # Arguments - is_primitive=true, # Expect hand-written rule - perf_flag=:stability_and_allocs, # Check performance -) -``` - -## 14. **Common Pitfalls & Solutions** - -### **Pitfall 1: StackOverflowError when calling build_rrule** -**Cause**: Recursive type without custom tangent type -**Solution**: See docs/src/developer_documentation/custom_tangent_type.md -**Example**: src/rrules/tasks.jl (TaskTangent), ext/MooncakeDynamicExpressionsExt.jl (TangentNode) - -### **Pitfall 2: MissingForeigncallRuleError** -**Cause**: Code calls a `ccall` without a rule -**Solution**: Write rule for Julia function that calls it, or the foreigncall itself -**Example**: src/rrules/blas.jl shows how to write foreigncall rules - -### **Pitfall 3: Type instability in generated code** -**Cause**: Abstract types in primal signature, missing type assertions -**Solution**: Add type assertions in rule, or make primal type-stable -**Check**: Use JET.jl via `test_rule(...; perf_flag=:stability)` - -### **Pitfall 4: Incorrect gradient for mutation** -**Cause**: Forgot to restore primal state in pullback -**Solution**: Save old values before mutation, restore in pullback -**Example**: Any `rrule!!` in src/rrules/blas.jl that saves `_copy = copy(x)` - -### **Pitfall 5: Segfault or weird errors** -**Cause**: Wrong fdata/rdata types in custom rule -**Solution**: Enable debug mode, check types with `verify_fdata_value`, `verify_rdata_value` -**How**: `Config(; debug_mode=true)` or see docs/src/utilities/debug_mode.md - -## 15. **Common Workflows** - -> **Documentation**: docs/src/utilities/ - complete workflow guides - -### **Adding a New Primitive Rule** -1. Determine signature `Tuple{typeof(f), ArgTypes...}` -2. Add `@is_primitive Context Signature` declaration -3. Implement `rrule!!(::CoDual{typeof(f)}, args::CoDual...)` -4. Implement `frule!!(::Dual{typeof(f)}, args::Dual...)` (optional) -5. Add test case to `hand_written_rule_test_cases` -6. Run `test_rule(rng, f, test_args...)` -- **Documentation**: docs/src/utilities/defining_rules.md - complete guide with examples - -### **Debugging AD Issues** -1. Enable debug mode: `Config(; debug_mode=true)` -2. Use `test_rule` to isolate issue -3. 
Inspect IR with developer tools: `primal_ir`, `fwd_ir`, `rvs_ir` -4. Check `tangent_type`, `fdata_type`, `rdata_type` are correct -5. Verify with `test_data` for custom tangent types -- **Documentation**: - - docs/src/utilities/debugging_and_mwes.md - debugging strategies - - docs/src/utilities/debug_mode.md - using debug mode effectively - -### **Supporting a New Type** -1. Implement `tangent_type(::Type{MyType})` if non-default -2. Implement tangent operations: `zero_tangent_internal`, `randn_tangent_internal`, etc. -3. Implement `fdata`, `rdata`, `tangent(f, r)` if custom splitting needed -4. Add `rrule!!` for `lgetfield`, `lsetfield!`, `_new_` as needed -5. Use `test_data(rng, instance)` to verify -- **Documentation**: docs/src/developer_documentation/custom_tangent_type.md - - Step-by-step guide with complete recursive type example - - Full checklist of required methods - - Appendix with complete implementations - -## 16. **Architecture Decisions (Why Things Are This Way)** - -### **Why FData/RData Split?** -- **Problem**: Passing entire tangents around is expensive -- **Solution**: Split into mutable (fdata - passed by reference) and immutable (rdata - passed by value) -- **Benefit**: Only propagate what's necessary on each pass -- **Doc**: docs/src/understanding_mooncake/rule_system.md - "Representing Gradients" - -### **Why BBCode Instead of Just IRCode?** -- **Problem**: Inserting basic blocks in IRCode is awkward and error-prone -- **Solution**: BBCode uses IDs instead of positions, making insertions safe -- **When Used**: Only reverse-mode (needs complex CFG modifications), forward-mode uses IRCode -- **Doc**: docs/src/developer_documentation/ir_representation.md - "An Alternative IR Datastructure" - -### **Why LazyDerivedRule?** -- **Problem**: Recursive functions cause infinite loop during rule compilation -- **Solution**: Defer rule construction until first call -- **Example**: `f(x) = x > 0 ? f(x-1) : x` would loop forever without lazy construction -- **Doc**: docs/src/developer_documentation/misc_internals_notes.md - "How Recursion Is Handled" - -### **Why Unique Tangent Types?** -- **Problem**: Multiple tangent types → type instability → 100x+ slowdowns -- **Solution**: Each primal type has exactly one tangent type -- **Benefit**: Type-stable AD code, predictable testing, clear interface -- **Doc**: docs/src/understanding_mooncake/rule_system.md - "Why Uniqueness of Type For Tangents" - -### **Why Block Stack?** -- **Problem**: Need to know which block we came from on reverse pass -- **Solution**: Push block ID on forward pass, pop on reverse pass -- **Optimization**: Skip for unique predecessors (often 50%+ of blocks) -- **Location**: src/interpreter/reverse_mode.jl:81 - -### **Why IntrinsicsWrappers Module?** -- **Problem**: All intrinsics have type `Core.IntrinsicFunction` → can't dispatch -- **Solution**: Wrap each in a regular function with unique type -- **Benefit**: Type-stable dispatch in rules -- **Doc**: src/rrules/builtins.jl:44-84 (IntrinsicsWrappers docstring) - -## 17. **Important Invariants** - -> **Documentation**: docs/src/understanding_mooncake/rule_system.md - "Why Uniqueness of Type For Tangents" - -1. **Uniqueness**: `tangent_type(P)` returns exactly one type for each primal type `P` -2. **Reconstruction**: `tangent(fdata(t), rdata(t)) === t` must hold -3. **Type Stability**: If primal is type-stable, AD should be type-stable -4. **State Restoration**: After pullback, all mutated state must be restored -5. 
**Aliasing Preservation**: Tangent structure must mirror primal aliasing -6. **No Global Mutation**: Functions must not modify global mutable state -- **Documentation**: - - docs/src/known_limitations.md - mutation of globals, recursive types, pointers - - docs/src/understanding_mooncake/rule_system.md - "Why Support Closures But Not Mutable Globals" - -## 18. **Debugging Checklist** - -When AD fails, check in this order: - -1. **Does the primal run?** - ```julia - f(x...) # Must work before differentiation - ``` - -2. **Is there a missing primitive?** - ```julia - # Look for MissingForeigncallRuleError or MissingRuleForBuiltinException - # Solution: Add rule or use @mooncake_overlay to avoid problematic code - ``` - -3. **Are tangent types correct?** - ```julia - using Mooncake: tangent_type, fdata_type, rdata_type - tangent_type(typeof(x)) # Should be a concrete type - ``` - -4. **Enable debug mode** - ```julia - rule = build_rrule(f, x...; debug_mode=true) - # Will catch fdata/rdata type mismatches - ``` - -5. **Inspect generated IR** - ```julia - using Mooncake: fwd_ir, rvs_ir - display(fwd_ir(Tuple{typeof(f), typeof(x)...})) - display(rvs_ir(Tuple{typeof(f), typeof(x)...})) - ``` - -6. **Check for unsupported features** - - See docs/src/known_limitations.md - - Global mutation, PhiCNode, UpsilonNode not supported - -## 19. **Error Types** - -- `MissingForeigncallRuleError` - No rule for ccall (foreigncall.jl:2-33) -- `MissingRuleForBuiltinException` - No rule for builtin (builtins.jl:14-42) -- `MissingIntrinsicWrapperException` - No intrinsic wrapper (builtins.jl:119-128) -- `UnhandledLanguageFeatureException` - Unsupported Julia feature (ir_utils.jl:273-282) -- `MooncakeRuleCompilationError` - Rule compilation failed (reverse_mode.jl:1012-1036) -- `InvalidFDataException` - Invalid forward data (fwds_rvs_data.jl:287-289) -- `InvalidRDataException` - Invalid reverse data (fwds_rvs_data.jl:750-752) -- `ValueAndGradientReturnTypeError` - Wrong return type for gradient (interface.jl:41-52) -- `ValueAndPullbackReturnTypeError` - Unsupported output (interface.jl:54-73) -- `AddToPrimalException` - Constructor issue in testing (tangents.jl:1062-1080) - -## 20. **FAQs for New Contributors** - -### **Q: Where do I start if I want to add support for a new package?** -A: Create an extension in `ext/` following the pattern in `ext/MooncakeNNlibExt.jl`: -1. Import necessary Mooncake utilities -2. Use `@from_rrule` for functions with ChainRules rules -3. Write custom `rrule!!` for anything else -4. Add tests in `test/ext/your_package/` - -### **Q: How do I know if I need a primitive or can rely on derivation?** -A: Run `build_rrule(f, args...)` - if it fails with missing rule error, you need a primitive. If it succeeds but is slow/incorrect, consider adding a performance primitive to `DefaultCtx`. - -### **Q: What's the difference between MinimalCtx and DefaultCtx?** -A: -- **MinimalCtx**: Only rules essential for correctness (e.g., builtins, foreigncalls) -- **DefaultCtx**: Includes performance rules (e.g., optimized BLAS calls) -- Always add correctness rules to MinimalCtx, performance rules to DefaultCtx - -### **Q: Why does my rule work but performance is terrible?** -A: Common causes: -1. Rule isn't marked `@is_primitive` → AD re-derives it every time -2. Type instability in rule → use `@inferred` to check -3. Missing `@inline` on small functions -4. 
Allocations → use `@allocations` to check, may need manual optimization - -### **Q: How do I test just my new rule?** -A: -```julia -using Mooncake.TestUtils: test_rule -test_rule(Xoshiro(123), my_func, test_args...; is_primitive=true) -``` -Add to `hand_written_rule_test_cases` in your rrules file for CI. - -### **Q: Where are Hessians computed?** -A: Apply Mooncake to itself (forward-over-reverse or reverse-over-reverse): -```julia -# Hessian = derivative of gradient -grad_func = x -> value_and_gradient!!(rule, f, x)[2][2] -hessian_rule = build_rrule(grad_func, x) -``` -See test cases for examples of higher-order AD. - -### **Q: What's the relationship to Enzyme.jl, Zygote.jl, etc?** -A: -- **Enzyme.jl**: Works at LLVM level (C++), handles more but harder to extend -- **Zygote.jl**: Source-to-source like Mooncake, but struggles with mutation -- **ReverseDiff.jl**: Tape-based, can give wrong answers with control flow -- **Mooncake.jl**: Source-to-source, first-class mutation support, pure Julia - -## 21. **Documentation Cross-Reference** - -### **Getting Started** -- **docs/src/index.md** - Project overview, goals, status, getting started -- **docs/src/tutorial.md** - DifferentiationInterface.jl and native API usage examples -- **docs/src/interface.md** - Public API reference (`Config`, `value_and_gradient!!`, etc.) - -### **Understanding AD** -- **docs/src/understanding_mooncake/introduction.md** - Who docs are for, prerequisites -- **docs/src/understanding_mooncake/algorithmic_differentiation.md** - Complete mathematical foundation - - Derivatives: scalar to Hilbert space generalization - - Forward-mode AD explanation - - Reverse-mode AD explanation (what and how) - - Directional derivatives and gradients - - Worked examples with matrices and Julia functions -- **docs/src/understanding_mooncake/rule_system.md** - Rule system specification - - Mathematical model for Julia functions - - Rule interface (forwards and reverse passes) - - CoDual, fdata, rdata system - - Testing rules - - Why uniqueness matters - -### **Practical Guides** -- **docs/src/utilities/defining_rules.md** - Complete rule-writing guide - - Using `@mooncake_overlay` to simplify code - - Using `@zero_adjoint` for zero-derivative functions - - Using `@from_rrule` / `@from_chainrules` to import existing rules - - Writing custom `rrule!!` methods -- **docs/src/utilities/debug_mode.md** - Debug mode guide - - When to enable it - - What it checks - - Performance impact -- **docs/src/utilities/debugging_and_mwes.md** - Debugging and MWE creation - - Using `test_rule` for debugging - - Creating minimal working examples - -### **Developer Internals** -- **docs/src/developer_documentation/running_tests_locally.md** - Testing workflow - - Main testing with Revise.jl - - Extension and integration testing -- **docs/src/developer_documentation/developer_tools.md** - IR inspection tools - - `primal_ir`, `dual_ir`, `fwd_ir`, `rvs_ir` usage -- **docs/src/developer_documentation/tangents.md** - Tangent interface - - `test_tangent_interface`, `test_tangent_splitting`, `test_rule_and_type_interactions` - - Complete interface requirements -- **docs/src/developer_documentation/custom_tangent_type.md** - Custom tangent types - - Why recursive types are challenging - - Step-by-step guide with struct A example - - Complete checklist of required methods - - Full working implementation in appendix -- **docs/src/developer_documentation/ir_representation.md** - IR internals - - Julia's SSA-form IR explained - - Control flow, PhiNodes, basic blocks 
- - IRCode vs BBCode comparison - - Code transformation examples -- **docs/src/developer_documentation/forwards_mode_design.md** - Forward-mode design (planned) - - Forwards-rule interface - - Hand-written vs derived rules - - Comparison with ForwardDiff.jl -- **docs/src/developer_documentation/reverse_mode_design.md** - Reverse-mode overview - - Compilation process walkthrough - - `build_rrule` flow - - `generate_ir` explanation -- **docs/src/developer_documentation/misc_internals_notes.md** - Implementation notes - - `tangent_type` generated function design - - How recursion is handled via `LazyDerivedRule` - -### **Limitations** -- **docs/src/known_limitations.md** - Known issues and workarounds - - Mutation of global variables - - Passing differentiable data as a type - - Circular references in type declarations - - Tangent generation and pointers - -This comprehensive map provides a complete conceptual understanding of Mooncake.jl's architecture, linking every major concept to specific source locations and their corresponding documentation. From 324b121ac1ff075fc8a9c6480c3295166a6717bd Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 14 Nov 2025 10:56:22 +0000 Subject: [PATCH 10/19] remove repeated tests --- .github/workflows/CI.yml | 1 - .../differentiation_interface.jl | 4 +--- test/ext/forward_over_reverse/Project.toml | 5 ----- .../forward_over_reverse.jl | 17 ----------------- 4 files changed, 1 insertion(+), 26 deletions(-) delete mode 100644 test/ext/forward_over_reverse/Project.toml delete mode 100644 test/ext/forward_over_reverse/forward_over_reverse.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f5142fec67..af20dd53c9 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -85,7 +85,6 @@ jobs: matrix: test_group: [ {test_type: 'ext', label: 'differentiation_interface'}, - {test_type: 'ext', label: 'forward_over_reverse'}, {test_type: 'ext', label: 'dynamic_expressions'}, {test_type: 'ext', label: 'flux'}, {test_type: 'ext', label: 'function_wrappers'}, diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 17853c3158..4f47e24091 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -1,13 +1,11 @@ -using Pkg, Test +using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest -import DifferentiationInterface as DI using Mooncake: Mooncake test_differentiation( [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; - excluded=SECOND_ORDER, logging=true, ) diff --git a/test/ext/forward_over_reverse/Project.toml b/test/ext/forward_over_reverse/Project.toml deleted file mode 100644 index 7639dd345f..0000000000 --- a/test/ext/forward_over_reverse/Project.toml +++ /dev/null @@ -1,5 +0,0 @@ -[deps] -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/ext/forward_over_reverse/forward_over_reverse.jl b/test/ext/forward_over_reverse/forward_over_reverse.jl deleted file mode 100644 index b22074cb74..0000000000 --- a/test/ext/forward_over_reverse/forward_over_reverse.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Pkg, Test -Pkg.activate(@__DIR__) 
-Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) - -using DifferentiationInterface -import DifferentiationInterface as DI -using Mooncake: Mooncake - -@testset "forward-over-reverse via DifferentiationInterface" begin - backend = SecondOrder(AutoMooncakeForward(), AutoMooncake()) - - @test DI.hessian(sum, backend, [2.0]) == [0.0] - - rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 - H = DI.hessian(rosen, backend, [1.2, 1.2]) - @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) -end From 2acded61d5f1d7beefacb91130901ac863cf827b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 14 Nov 2025 12:40:24 +0000 Subject: [PATCH 11/19] format --- .../ext/differentiation_interface/differentiation_interface.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 4f47e24091..3a375299ee 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -6,6 +6,5 @@ using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake test_differentiation( - [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; - logging=true, + [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; logging=true ) From 933efc42daead095840d95d9c83996f0502ba107 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 14 Nov 2025 13:02:06 +0000 Subject: [PATCH 12/19] fix format --- test/interpreter/forward_over_reverse.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/interpreter/forward_over_reverse.jl b/test/interpreter/forward_over_reverse.jl index ca487cba89..06c7368b28 100644 --- a/test/interpreter/forward_over_reverse.jl +++ b/test/interpreter/forward_over_reverse.jl @@ -48,6 +48,12 @@ end interp = Mooncake.get_interpreter(ForwardMode) args = map(TestUtils._deepcopy, x) + primal_args = map(TestUtils._deepcopy, x) + primal_value = f(primal_args...) + if !(primal_value isa IEEEFloat) + @test_broken false + continue + end rule = Mooncake.build_rrule(f, x...) sig = Tuple{ typeof(low_level_gradient),typeof(rule),_typeof(f),map(_typeof, args)... @@ -56,21 +62,21 @@ end rng = StableRNG(0xF0 + n) dirs = map(arg -> Mooncake.randn_tangent(rng, arg), args) if any(dir -> dir isa Mooncake.NoTangent, dirs) - @test_broken true + @test_broken false continue end dual_inputs = ( - Mooncake.Dual(low_level_gradient, Mooncake.zero_tangent(low_level_gradient)), - Mooncake.Dual(rule, Mooncake.zero_tangent(rule)), - Mooncake.Dual(f, Mooncake.zero_tangent(f)), + Mooncake.Dual(low_level_gradient, Mooncake.NoTangent()), + Mooncake.Dual(rule, Mooncake.NoTangent()), + Mooncake.Dual(f, Mooncake.NoTangent()), map((arg, dir) -> Mooncake.Dual(arg, dir), args, dirs)..., ) dual_result = grad_rule(dual_inputs...) pushforward = _as_tuple(Mooncake.tangent(dual_result)) # Use our own finite difference JVP implementation - grad_func = x -> _as_tuple(low_level_gradient(rule, f, x...)) + grad_func = (xs...) 
-> _as_tuple(low_level_gradient(rule, f, xs...)) fd_results_all = finite_diff_jvp(grad_func, args, dirs) # Check if any epsilon value gives a close match (following test_utils.jl pattern) From 7fc700774ff40782d1f657a387df52ca64aa24ba Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 14 Nov 2025 13:03:04 +0000 Subject: [PATCH 13/19] Revert "fix format" This reverts commit 933efc42daead095840d95d9c83996f0502ba107. --- test/interpreter/forward_over_reverse.jl | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/test/interpreter/forward_over_reverse.jl b/test/interpreter/forward_over_reverse.jl index 06c7368b28..ca487cba89 100644 --- a/test/interpreter/forward_over_reverse.jl +++ b/test/interpreter/forward_over_reverse.jl @@ -48,12 +48,6 @@ end interp = Mooncake.get_interpreter(ForwardMode) args = map(TestUtils._deepcopy, x) - primal_args = map(TestUtils._deepcopy, x) - primal_value = f(primal_args...) - if !(primal_value isa IEEEFloat) - @test_broken false - continue - end rule = Mooncake.build_rrule(f, x...) sig = Tuple{ typeof(low_level_gradient),typeof(rule),_typeof(f),map(_typeof, args)... @@ -62,21 +56,21 @@ end rng = StableRNG(0xF0 + n) dirs = map(arg -> Mooncake.randn_tangent(rng, arg), args) if any(dir -> dir isa Mooncake.NoTangent, dirs) - @test_broken false + @test_broken true continue end dual_inputs = ( - Mooncake.Dual(low_level_gradient, Mooncake.NoTangent()), - Mooncake.Dual(rule, Mooncake.NoTangent()), - Mooncake.Dual(f, Mooncake.NoTangent()), + Mooncake.Dual(low_level_gradient, Mooncake.zero_tangent(low_level_gradient)), + Mooncake.Dual(rule, Mooncake.zero_tangent(rule)), + Mooncake.Dual(f, Mooncake.zero_tangent(f)), map((arg, dir) -> Mooncake.Dual(arg, dir), args, dirs)..., ) dual_result = grad_rule(dual_inputs...) pushforward = _as_tuple(Mooncake.tangent(dual_result)) # Use our own finite difference JVP implementation - grad_func = (xs...) 
-> _as_tuple(low_level_gradient(rule, f, xs...)) + grad_func = x -> _as_tuple(low_level_gradient(rule, f, x...)) fd_results_all = finite_diff_jvp(grad_func, args, dirs) # Check if any epsilon value gives a close match (following test_utils.jl pattern) From 4dd6362eb3c6b8b5e6b76b1519ffab0dbdeb871a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 18 Nov 2025 14:12:01 +0000 Subject: [PATCH 14/19] disable second order DI tests --- .../differentiation_interface/differentiation_interface.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 3a375299ee..ea74e84b6d 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -6,5 +6,7 @@ using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake test_differentiation( - [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; logging=true + [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; + excluded=SECOND_ORDER, + logging=true, ) From ca5f6fa1aa7e8fb2030c8d408aa43b699a10f5aa Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 19 Nov 2025 09:47:18 +0000 Subject: [PATCH 15/19] us DITests Hessian test --- .../differentiation_interface.jl | 15 ++- test/interpreter/forward_over_reverse.jl | 97 ------------------- test/runtests.jl | 1 - 3 files changed, 14 insertions(+), 99 deletions(-) delete mode 100644 test/interpreter/forward_over_reverse.jl diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index ea74e84b6d..512592db35 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -4,9 +4,22 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake +using Test test_differentiation( [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; - excluded=SECOND_ORDER, + excluded=[:hvp, :second_derivative], # Enable hessian tests only logging=true, ) + +@testset "Mooncake second-order examples" begin + backend = SecondOrder(AutoMooncake(; config=nothing), AutoMooncake(; config=nothing)) + + # Sum: Hessian is zero + @test DI.hessian(sum, backend, [2.0]) == [0.0;;] + + # Rosenbrock 2D at [1.2, 1.2] + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + H = DI.hessian(rosen, backend, [1.2, 1.2]) + @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) +end diff --git a/test/interpreter/forward_over_reverse.jl b/test/interpreter/forward_over_reverse.jl deleted file mode 100644 index ca487cba89..0000000000 --- a/test/interpreter/forward_over_reverse.jl +++ /dev/null @@ -1,97 +0,0 @@ -using StableRNGs: StableRNG -using Base: IEEEFloat -using Mooncake: ForwardMode, _typeof, _add_to_primal, _scale, _diff - -function finite_diff_jvp(func, x, dx, unsafe_perturb::Bool=false) - ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] - - results = map(ε_list) do ε - x_plus = _add_to_primal(x, _scale(ε, dx), unsafe_perturb) - x_minus = _add_to_primal(x, _scale(-ε, dx), unsafe_perturb) - y_plus = func(x_plus...) - y_minus = func(x_minus...) 
- return _scale(1 / (2ε), _diff(y_plus, y_minus)) - end - - return results -end - -function low_level_gradient(rule, f, args...) - return Base.tail(Mooncake.value_and_gradient!!(rule, f, args...)[2]) -end - -_as_tuple(x::Tuple) = x -_as_tuple(x) = (x,) - -function _isapprox_nested(a, b; atol=1e-8, rtol=1e-6) - if a isa Number && b isa Number - return isapprox(a, b; atol, rtol) - elseif a isa AbstractArray && b isa AbstractArray - return isapprox(a, b; atol, rtol) - elseif a isa Tuple && b isa Tuple - return length(a) == length(b) && - all(_isapprox_nested(ai, bi; atol, rtol) for (ai, bi) in zip(a, b)) - else - return a == b - end -end - -@testset "forward-over-reverse Hessian AD" begin - @testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in - collect( - enumerate(TestResources.generate_test_functions()) - ) - # Skip interface-only tests as they don't have implementations - interface_only && continue - - @info "$n: $(_typeof((f, x...)))" - - interp = Mooncake.get_interpreter(ForwardMode) - args = map(TestUtils._deepcopy, x) - rule = Mooncake.build_rrule(f, x...) - sig = Tuple{ - typeof(low_level_gradient),typeof(rule),_typeof(f),map(_typeof, args)... - } - grad_rule = Mooncake.build_frule(interp, sig) - rng = StableRNG(0xF0 + n) - dirs = map(arg -> Mooncake.randn_tangent(rng, arg), args) - if any(dir -> dir isa Mooncake.NoTangent, dirs) - @test_broken true - continue - end - - dual_inputs = ( - Mooncake.Dual(low_level_gradient, Mooncake.zero_tangent(low_level_gradient)), - Mooncake.Dual(rule, Mooncake.zero_tangent(rule)), - Mooncake.Dual(f, Mooncake.zero_tangent(f)), - map((arg, dir) -> Mooncake.Dual(arg, dir), args, dirs)..., - ) - dual_result = grad_rule(dual_inputs...) - pushforward = _as_tuple(Mooncake.tangent(dual_result)) - - # Use our own finite difference JVP implementation - grad_func = x -> _as_tuple(low_level_gradient(rule, f, x...)) - fd_results_all = finite_diff_jvp(grad_func, args, dirs) - - # Check if any epsilon value gives a close match (following test_utils.jl pattern) - # Convert each result to tuple form for comparison - fd_results_tuples = map(res -> _as_tuple(res), fd_results_all) - - # Check which epsilon values give close results - isapprox_results = map(fd_results_tuples) do fd_ref - return all(_isapprox_nested(pf, fd) # uses default atol=1e-8, rtol=1e-6 - for (pf, fd) in zip(pushforward, fd_ref)) - end - - @test length(pushforward) == length(first(fd_results_tuples)) - # At least one epsilon value should give a close result - if !any(isapprox_results) - # If none match, display values for debugging (like test_utils.jl does) - println("No epsilon gave close result. 
AD vs FD for each epsilon:") - for (i, fd_ref) in enumerate(fd_results_tuples) - println(" ε[$(i)]: AD=$pushforward, FD=$fd_ref") - end - end - @test any(isapprox_results) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index dfc8288ed5..d8e05c89ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,6 @@ include("front_matter.jl") include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "forward_mode.jl")) include(joinpath("interpreter", "reverse_mode.jl")) - include(joinpath("interpreter", "forward_over_reverse.jl")) end include("tools_for_rules.jl") include("interface.jl") From c383c05357eefd913833f682fb50d0053a954fe5 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 19 Nov 2025 17:32:30 +0000 Subject: [PATCH 16/19] bring back old tests --- .../differentiation_interface.jl | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 512592db35..22257ad440 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -3,23 +3,37 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterface as DI using Mooncake: Mooncake using Test test_differentiation( [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; - excluded=[:hvp, :second_derivative], # Enable hessian tests only + excluded=SECOND_ORDER, logging=true, ) -@testset "Mooncake second-order examples" begin - backend = SecondOrder(AutoMooncake(; config=nothing), AutoMooncake(; config=nothing)) +# Test Hessian computation using forward-over-reverse (it hangs) +# test_differentiation( +# [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; +# excluded=vcat(FIRST_ORDER, [:hvp, :second_derivative]), # Only test hessian +# logging=true, +# ) + +@testset "Mooncake Hessian tests" begin + backend = SecondOrder( + AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing) + ) # Sum: Hessian is zero - @test DI.hessian(sum, backend, [2.0]) == [0.0;;] + @testset "sum" begin + @test DI.hessian(sum, backend, [2.0]) == [0.0] + end # Rosenbrock 2D at [1.2, 1.2] - rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 - H = DI.hessian(rosen, backend, [1.2, 1.2]) - @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) + @testset "Rosenbrock" begin + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + H = DI.hessian(rosen, backend, [1.2, 1.2]) + @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) + end end From 8f96ece4b26b79143fd2e7d51a709ae6cc99eb28 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 1 Dec 2025 21:10:01 +0000 Subject: [PATCH 17/19] pushing the envelope --- src/rrules/low_level_maths.jl | 17 ++++++++++++++ src/rrules/memory.jl | 10 ++++++++ .../differentiation_interface.jl | 23 ++++++++++++++----- 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index e156811abd..c8fc110746 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -71,6 +71,23 @@ @from_chainrules MinimalCtx Tuple{typeof(deg2rad),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(rad2deg),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(^),P,P} 
where {P<:IEEEFloat} + +# ^(Float, Int) for literal integer powers like x^4. Forward and reverse modes +# individually derive rules fine, but forward-over-reverse (Hessian) compilation +# hangs without this as it tries to differentiate through Julia's power implementation. +@is_primitive MinimalCtx Tuple{typeof(^),P,Integer} where {P<:IEEEFloat} +function frule!!(::Dual{typeof(^)}, x::Dual{P}, p::Dual{<:Integer}) where {P<:IEEEFloat} + _x, _p = primal(x), primal(p) + return Dual(_x^_p, _p * _x^(_p - 1) * tangent(x)) +end +function rrule!!( + ::CoDual{typeof(^)}, x::CoDual{P}, p::CoDual{<:Integer} +) where {P<:IEEEFloat} + _x, _p = primal(x), primal(p) + pow_int_pb(dy::P) = NoRData(), dy * _p * _x^(_p - 1), NoRData() + return zero_fcodual(_x^_p), pow_int_pb +end + @from_chainrules MinimalCtx Tuple{typeof(atan),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(max),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(min),P,P} where {P<:IEEEFloat} diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index 7329e89bc4..2115a47d34 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -605,6 +605,16 @@ function rrule!!( return CoDual(x, dx), NoPullback((NoRData(), NoRData(), NoRData())) end +# Core.memorynew is called inside rrule!! for Memory allocation. Forward-over-reverse +# (Hessian) needs this frule when forward mode differentiates through reverse mode code. +@inline function frule!!( + ::Dual{typeof(Core.memorynew)}, ::Dual{Type{Memory{P}}}, n::Dual{Int} +) where {P} + x = Core.memorynew(Memory{P}, primal(n)) + dx = zero_tangent_internal(x, NoCache()) + return Dual(x, dx) +end + function rrule!!( ::CoDual{typeof(_new_)}, ::CoDual{Type{MemoryRef{P}}}, diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 22257ad440..023b033175 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -13,12 +13,17 @@ test_differentiation( logging=true, ) -# Test Hessian computation using forward-over-reverse (it hangs) -# test_differentiation( -# [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; -# excluded=vcat(FIRST_ORDER, [:hvp, :second_derivative]), # Only test hessian -# logging=true, -# ) +# Test Hessian computation using forward-over-reverse with DITest scenarios. +# Using linalg=false to select loop-based test functions instead of the default +# versions that use broadcasting and linear algebra ops (vec, transpose). +# Broadcasting in forward-over-reverse mode causes compilation hangs due to +# complex nested types that overwhelm Julia's type inference. 
+test_differentiation( + [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; + scenarios=default_scenarios(; linalg=false), + excluded=vcat(FIRST_ORDER, [:hvp, :second_derivative]), + logging=true, +) @testset "Mooncake Hessian tests" begin backend = SecondOrder( @@ -36,4 +41,10 @@ test_differentiation( H = DI.hessian(rosen, backend, [1.2, 1.2]) @test isapprox(H, [1250.0 -480.0; -480.0 200.0]; rtol=1e-10, atol=1e-12) end + + # Test higher integer powers (fixed by adding frule for ^(Float, Int)) + @testset "higher powers" begin + @test DI.hessian(x -> x[1]^4, backend, [2.0]) ≈ [48.0] + @test DI.hessian(x -> x[1]^6, backend, [2.0]) ≈ [480.0] + end end From 726ce5d313fadaf39b83de95a46c78aa944389fa Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 4 Dec 2025 14:15:54 +0000 Subject: [PATCH 18/19] add DI extension, seems DI wrap reversemode as a closure --- Project.toml | 3 + ext/MooncakeDifferentiationInterfaceExt.jl | 147 +++++++++++++++ src/Mooncake.jl | 1 + src/dual.jl | 125 +++++++++++++ src/fwds_rvs_data.jl | 10 + src/interface.jl | 36 +++- src/interpreter/forward_mode.jl | 7 - .../avoiding_non_differentiable_code.jl | 2 - src/rrules/dual_arithmetic.jl | 172 ++++++++++++++++++ src/rrules/low_level_maths.jl | 17 -- src/rrules/memory.jl | 10 - .../differentiation_interface.jl | 5 - 12 files changed, 490 insertions(+), 45 deletions(-) create mode 100644 ext/MooncakeDifferentiationInterfaceExt.jl create mode 100644 src/rrules/dual_arithmetic.jl diff --git a/Project.toml b/Project.toml index 950b690d43..31deb7bf64 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -33,6 +34,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [extensions] MooncakeAllocCheckExt = "AllocCheck" MooncakeCUDAExt = "CUDA" +MooncakeDifferentiationInterfaceExt = "DifferentiationInterface" MooncakeDynamicExpressionsExt = "DynamicExpressions" MooncakeFluxExt = "Flux" MooncakeFunctionWrappersExt = "FunctionWrappers" @@ -51,6 +53,7 @@ CUDA = "5" ChainRules = "1.71.0" ChainRulesCore = "1" DiffTests = "0.1" +DifferentiationInterface = "0.7" DispatchDoctor = "0.4.26" DynamicExpressions = "2" ExprTools = "0.1" diff --git a/ext/MooncakeDifferentiationInterfaceExt.jl b/ext/MooncakeDifferentiationInterfaceExt.jl new file mode 100644 index 0000000000..600ee38ff5 --- /dev/null +++ b/ext/MooncakeDifferentiationInterfaceExt.jl @@ -0,0 +1,147 @@ +module MooncakeDifferentiationInterfaceExt + +using Mooncake: + Mooncake, @is_primitive, MinimalCtx, ForwardMode, Dual, primal, tangent, NoTangent +import DifferentiationInterface as DI + +# Mark shuffled_gradient as forward-mode primitive to avoid expensive type inference hang. +# This prevents build_frule from trying to derive rules for the complex gradient closure. 
+@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient),Vararg} +@is_primitive MinimalCtx ForwardMode Tuple{typeof(DI.shuffled_gradient!),Vararg} + +# Helper to create Dual array from primal and tangent arrays +_make_dual_array(x::AbstractArray, dx::AbstractArray) = Dual.(x, dx) +_make_dual_array(x, dx) = Dual(x, dx) + +# Helper to extract primal and tangent from Dual array +_extract_primals(arr::AbstractArray{<:Dual}) = primal.(arr) +_extract_primals(d::Dual) = primal(d) +_extract_tangents(arr::AbstractArray{<:Dual}) = tangent.(arr) +_extract_tangents(d::Dual) = tangent(d) + +# frule for shuffled_gradient without prep +# shuffled_gradient(x, f, backend, rewrap, contexts...) -> gradient(f, backend, x, contexts...) +function Mooncake.frule!!( + ::Dual{typeof(DI.shuffled_gradient)}, + x_dual::Dual, + f_dual::Dual, + backend_dual::Dual, + rewrap_dual::Dual, + context_duals::Vararg{Dual}, +) + # Extract primals and tangents + x = primal(x_dual) + dx = tangent(x_dual) + f = primal(f_dual) + backend = primal(backend_dual) + rewrap = primal(rewrap_dual) + contexts = map(d -> primal(d), context_duals) + + # Create Dual inputs: each element is Dual(x[i], dx[i]) + # This allows the Hvp to be computed via forward-over-reverse + x_with_duals = _make_dual_array(x, dx) + + # Call gradient with Dual inputs + # Since Dual{Float64,Float64} is self-tangent, reverse mode handles it correctly + grad_duals = DI.shuffled_gradient(x_with_duals, f, backend, rewrap, contexts...) + + # Extract primal (gradient) and tangent (Hvp) from the Dual outputs + grad_primal = _extract_primals(grad_duals) + grad_tangent = _extract_tangents(grad_duals) + + return Dual(grad_primal, grad_tangent) +end + +# frule for shuffled_gradient with prep +function Mooncake.frule!!( + ::Dual{typeof(DI.shuffled_gradient)}, + x_dual::Dual, + f_dual::Dual, + prep_dual::Dual, + backend_dual::Dual, + rewrap_dual::Dual, + context_duals::Vararg{Dual}, +) + x = primal(x_dual) + dx = tangent(x_dual) + f = primal(f_dual) + prep = primal(prep_dual) + backend = primal(backend_dual) + rewrap = primal(rewrap_dual) + contexts = map(d -> primal(d), context_duals) + + x_with_duals = _make_dual_array(x, dx) + grad_duals = DI.shuffled_gradient(x_with_duals, f, prep, backend, rewrap, contexts...) + + grad_primal = _extract_primals(grad_duals) + grad_tangent = _extract_tangents(grad_duals) + + return Dual(grad_primal, grad_tangent) +end + +# frule for shuffled_gradient! (in-place version) +function Mooncake.frule!!( + ::Dual{typeof(DI.shuffled_gradient!)}, + grad_dual::Dual, + x_dual::Dual, + f_dual::Dual, + backend_dual::Dual, + rewrap_dual::Dual, + context_duals::Vararg{Dual}, +) + grad = primal(grad_dual) + dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes) + x = primal(x_dual) + dx = tangent(x_dual) + f = primal(f_dual) + backend = primal(backend_dual) + rewrap = primal(rewrap_dual) + contexts = map(d -> primal(d), context_duals) + + x_with_duals = _make_dual_array(x, dx) + # Allocate Dual buffer for in-place gradient + grad_duals = _make_dual_array(grad, similar(grad)) + DI.shuffled_gradient!(grad_duals, x_with_duals, f, backend, rewrap, contexts...) + + # Copy primal (gradient) back to grad + grad .= _extract_primals(grad_duals) + # Copy tangent (Hvp) back to dgrad + dgrad .= _extract_tangents(grad_duals) + + return Dual(nothing, NoTangent()) +end + +# frule for shuffled_gradient! 
with prep +function Mooncake.frule!!( + ::Dual{typeof(DI.shuffled_gradient!)}, + grad_dual::Dual, + x_dual::Dual, + f_dual::Dual, + prep_dual::Dual, + backend_dual::Dual, + rewrap_dual::Dual, + context_duals::Vararg{Dual}, +) + grad = primal(grad_dual) + dgrad = tangent(grad_dual) # Tangent storage for gradient (where Hvp goes) + x = primal(x_dual) + dx = tangent(x_dual) + f = primal(f_dual) + prep = primal(prep_dual) + backend = primal(backend_dual) + rewrap = primal(rewrap_dual) + contexts = map(d -> primal(d), context_duals) + + x_with_duals = _make_dual_array(x, dx) + grad_duals = _make_dual_array(grad, similar(grad)) + DI.shuffled_gradient!(grad_duals, x_with_duals, f, prep, backend, rewrap, contexts...) + + # Copy primal (gradient) back to grad + grad .= _extract_primals(grad_duals) + # Copy tangent (Hvp) back to dgrad + dgrad .= _extract_tangents(grad_duals) + + return Dual(nothing, NoTangent()) +end + +end diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 87ef7903f9..b7f7318331 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -151,6 +151,7 @@ else include(joinpath("rrules", "array_legacy.jl")) end include(joinpath("rrules", "performance_patches.jl")) +include(joinpath("rrules", "dual_arithmetic.jl")) include("interface.jl") include("config.jl") diff --git a/src/dual.jl b/src/dual.jl index 65cf53532a..e51a5786c7 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -55,3 +55,128 @@ verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x) function Dual(x::Type{P}, dx::NoTangent) where {P} return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx) end + +# Dual of numeric types is self-tangent +@inline tangent_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T} + +@inline zero_tangent_internal( + x::Dual{P,T}, ::MaybeCache +) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T)) + +@inline function randn_tangent_internal( + rng::AbstractRNG, x::Dual{P,T}, ::MaybeCache +) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(randn(rng, P), randn(rng, T)) +end + +@inline function increment!!(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) +end + +@inline set_to_zero_internal!!( + ::SetToZeroCache, x::Dual{P,T} +) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T)) + +@inline function increment_internal!!( + ::IncCache, x::Dual{P,T}, y::Dual{P,T} +) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) +end + +Base.one(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(one(P), zero(T)) +function Base.one(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(one(primal(x)), zero(tangent(x))) +end + +# Arithmetic operations +function Base.:+(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) + y, tangent(x)) +end +function Base.:+(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x + primal(y), tangent(y)) +end +function Base.:+(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) +end + +# Subtraction +Base.:-(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(-primal(x), -tangent(x)) +function Base.:-(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) - y, tangent(x)) +end +function Base.:-(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x - primal(y), -tangent(y)) +end +function Base.:-(x::Dual{P,T}, y::Dual{P,T}) where 
{P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) - primal(y), tangent(x) - tangent(y)) +end + +# Multiplication (product rule) +function Base.:*(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) * y, tangent(x) * y) +end +function Base.:*(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x * primal(y), x * tangent(y)) +end +function Base.:*(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) * primal(y), primal(x) * tangent(y) + tangent(x) * primal(y)) +end +function Base.:*(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) * y, tangent(x) * y) +end +function Base.:*(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x * primal(y), x * tangent(y)) +end + +# Division (quotient rule) +function Base.:/(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) / y, tangent(x) / y) +end +function Base.:/(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x / primal(y), -x * tangent(y) / primal(y)^2) +end +function Base.:/(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual( + primal(x) / primal(y), + (tangent(x) * primal(y) - primal(x) * tangent(y)) / primal(y)^2, + ) +end + +# Power (chain rule) +function Base.:^(x::Dual{P,T}, n::Integer) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x)^n, n * primal(x)^(n - 1) * tangent(x)) +end +function Base.:^(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x)^y, y * primal(x)^(y - 1) * tangent(x)) +end + +# Comparison (use primal for comparisons) +Base.:<(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) < y +Base.:<(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x < primal(y) +function Base.:<(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return primal(x) < primal(y) +end +Base.:>(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) > y +Base.:>(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x > primal(y) +function Base.:>(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return primal(x) > primal(y) +end +Base.:<=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) <= y +Base.:<=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x <= primal(y) +function Base.:<=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return primal(x) <= primal(y) +end +Base.:>=(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} = primal(x) >= y +Base.:>=(x::P, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x >= primal(y) +function Base.:>=(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return primal(x) >= primal(y) +end + +# Conversion and promotion +Base.convert(::Type{Dual{P,T}}, x::P) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(x, zero(T)) +function Base.promote_rule(::Type{Dual{P,T}}, ::Type{P}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual{P,T} +end + +LinearAlgebra.transpose(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x +LinearAlgebra.adjoint(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = x diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index e5b46319d0..3d303e6902 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -159,6 +159,7 @@ fdata_type(x) = throw(error("$x is not a type. 
Perhaps you meant typeof(x)?")) @foldable fdata_type(::Type{Union{}}) = Union{} fdata_type(::Type{T}) where {T<:IEEEFloat} = NoFData +fdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = NoFData function fdata_type(::Type{PossiblyUninitTangent{T}}) where {T} Tfields = fdata_type(T) @@ -437,6 +438,7 @@ rdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?")) @foldable rdata_type(::Type{Union{}}) = Union{} rdata_type(::Type{T}) where {T<:IEEEFloat} = T +rdata_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual{P,T} function rdata_type(::Type{PossiblyUninitTangent{T}}) where {T} return PossiblyUninitTangent{rdata_type(T)} @@ -587,6 +589,7 @@ Given value `p`, return the zero element associated to its reverse data type. zero_rdata(p) zero_rdata(p::IEEEFloat) = zero(p) +zero_rdata(p::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(zero(P), zero(T)) @generated function zero_rdata(p::P) where {P} Rs = rdata_field_types_exprs(P) @@ -654,6 +657,9 @@ obtained from `P` alone. end @foldable can_produce_zero_rdata_from_type(::Type{<:IEEEFloat}) = true +@foldable can_produce_zero_rdata_from_type( + ::Type{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} = true @foldable can_produce_zero_rdata_from_type(::Type{<:Type}) = true @@ -737,6 +743,9 @@ function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple} end zero_rdata_from_type(::Type{P}) where {P<:IEEEFloat} = zero(P) +function zero_rdata_from_type(::Type{Dual{P,T}}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(zero(P), zero(T)) +end zero_rdata_from_type(::Type{<:Type}) = NoRData() @@ -951,6 +960,7 @@ Reconstruct the tangent `t` for which `fdata(t) == f` and `rdata(t) == r`. """ tangent(::NoFData, ::NoRData) = NoTangent() tangent(::NoFData, r::IEEEFloat) = r +tangent(::NoFData, r::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = r tangent(f::Array, ::NoRData) = f # Tuples diff --git a/src/interface.jl b/src/interface.jl index 0d3623ca9c..9786371fb3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -106,7 +106,8 @@ function __value_and_gradient!!(rule::R, fx::Vararg{CoDual,N}) where {R,N} __verify_sig(rule, fx_fwds) out, pb!! = rule(fx_fwds...) y = primal(out) - y isa IEEEFloat || throw_val_and_grad_ret_type_error(y) + (y isa IEEEFloat || y isa Dual{<:IEEEFloat,<:IEEEFloat}) || + throw_val_and_grad_ret_type_error(y) return y, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(one(y))) end @@ -543,7 +544,9 @@ The API guarantees that tangents are initialized at zero before the first autodi rule = build_rrule(fx...; kwargs...) tangents = map(zero_tangent, fx) y, rvs!! = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...) - primal(y) isa IEEEFloat || throw_val_and_grad_ret_type_error(primal(y)) + _y = primal(y) + (_y isa IEEEFloat || _y isa Dual{<:IEEEFloat,<:IEEEFloat}) || + throw_val_and_grad_ret_type_error(_y) rvs!!(zero_tangent(primal(y))) # run reverse-pass to reset stacks + state return Cache(rule, nothing, tangents) end @@ -617,8 +620,33 @@ in `f` and `x`. """ value_and_derivative!!(rule::R, fx::Vararg{Dual,N}) where {R,N} = rule(fx...) -# Avoid differentiating cache constructors in forward mode to prevent -# forward-over-reverse from descending into interpreter/caches. 
@zero_derivative MinimalCtx Tuple{typeof(prepare_pullback_cache),Vararg} ForwardMode @zero_derivative MinimalCtx Tuple{typeof(prepare_gradient_cache),Vararg} ForwardMode @zero_derivative MinimalCtx Tuple{typeof(prepare_derivative_cache),Vararg} ForwardMode + +@is_primitive MinimalCtx Tuple{typeof(value_and_gradient!!),Cache,Vararg} ForwardMode + +function frule!!( + ::Dual{typeof(value_and_gradient!!)}, + cache_dual::Dual{<:Cache}, + f_dual::Dual, + x_duals::Vararg{Dual}, +) + # Extract primals and tangents + cache = primal(cache_dual) + f = primal(f_dual) + xs = map(primal, x_duals) + dxs = map(tangent, x_duals) + + y, grads = value_and_gradient!!(cache, f, xs...) + + df = tangent(f_dual) + dy = df + for (g, dx) in zip(grads, dxs) + dy = dy + g * dx + end + + dgrads = map(zero_tangent, grads) + + return Dual(y, dy), Dual(grads, dgrads) +end diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index 0629f1296f..150038ec03 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -324,8 +324,6 @@ end function modify_fwd_ad_stmts!( stmt::UpsilonNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo ) - # In some compiler-generated UpsilonNodes the `val` field can be undefined; skip safely. - isdefined(stmt, :val) || return nothing if !(stmt.val isa Union{Argument,SSAValue}) stmt = UpsilonNode(uninit_dual(get_const_primal_value(stmt.val))) end @@ -437,11 +435,6 @@ function modify_fwd_ad_stmts!( # Leave this node alone elseif isexpr(stmt, :pop_exception) # Leave this node alone - elseif isexpr(stmt, :the_exception) - # Preserve the primal exception object but give it a zero tangent. - inst = CC.NewInstruction(get_ir(info.primal_ir, ssa)) - ex_ssa = CC.insert_node!(dual_ir, ssa, inst, ATTACH_BEFORE) - replace_call!(dual_ir, ssa, Expr(:call, zero_dual, ex_ssa)) else msg = "Expressions of type `:$(stmt.head)` are not yet supported in forward mode" throw(ArgumentError(msg)) diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index df750bc1cf..85c25c4930 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -107,10 +107,8 @@ end } ) -# # Avoid differentiating Mooncake's rule construction in forward mode # This prevents forward-over-reverse from descending into kw-wrapper exceptions and caches. 
-# @zero_derivative MinimalCtx Tuple{typeof(build_rrule),Vararg} ForwardMode @zero_derivative MinimalCtx Tuple{typeof(Core.kwcall),NamedTuple,typeof(build_rrule),Vararg} ForwardMode diff --git a/src/rrules/dual_arithmetic.jl b/src/rrules/dual_arithmetic.jl new file mode 100644 index 0000000000..6b091b5834 --- /dev/null +++ b/src/rrules/dual_arithmetic.jl @@ -0,0 +1,172 @@ +@inline function _dual_add_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return NoRData(), dy, dy +end + +@inline function _dual_sub_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return NoRData(), dy, Dual(-primal(dy), -tangent(dy)) +end + +@inline function _dual_neg_pullback(dy::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return NoRData(), Dual(-primal(dy), -tangent(dy)) +end + +@is_primitive MinimalCtx Tuple{ + typeof(+),Dual{P,T},Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) + primal(y) + return CoDual(z, NoFData()), _dual_add_pullback +end + +@is_primitive MinimalCtx Tuple{typeof(+),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{P} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) + primal(y) + pb!! = dy -> (NoRData(), dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(+),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) + primal(y) + pb!! = dy -> (NoRData(), NoRData(), dy) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(-),Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = -primal(x) + return CoDual(z, NoFData()), _dual_neg_pullback +end + +@is_primitive MinimalCtx Tuple{ + typeof(-),Dual{P,T},Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) - primal(y) + return CoDual(z, NoFData()), _dual_sub_pullback +end + +@is_primitive MinimalCtx Tuple{typeof(-),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{P} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) - primal(y) + pb!! = dy -> (NoRData(), dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(-),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{P}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) - primal(y) + pb!! = dy -> (NoRData(), NoRData(), Dual(-primal(dy), -tangent(dy))) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{ + typeof(*),Dual{P,T},Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px * py + function mul_dual_dual_pb!!(dy::Dual{P,T}) + dx = py * dy + dy_out = px * dy + return NoRData(), dx, dy_out + end + return CoDual(z, NoFData()), mul_dual_dual_pb!! 
+end + +@is_primitive MinimalCtx Tuple{typeof(*),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{P} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px * py + pb!! = dy -> (NoRData(), py * dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(*),P,Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px * py + pb!! = dy -> (NoRData(), NoRData(), px * dy) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(/),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(/)}, x::CoDual{Dual{P,T}}, y::CoDual{P} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px / py + pb!! = dy -> (NoRData(), dy / py, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{ + typeof(*),Integer,Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px * py + pb!! = dy -> (NoRData(), NoRData(), px * dy) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{ + typeof(*),Dual{P,T},Integer +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px * py + pb!! = dy -> (NoRData(), py * dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(^),Dual{P,T},Int} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(^)}, x::CoDual{Dual{P,T}}, n::CoDual{Int} +) where {P<:IEEEFloat,T<:IEEEFloat} + px = primal(x) + pn = primal(n) + z = px^pn + function pow_dual_int_pb!!(dy::Dual{P,T}) + dx = pn * px^(pn - 1) * dy + return NoRData(), dx, NoRData() + end + return CoDual(z, NoFData()), pow_dual_int_pb!! +end + +@is_primitive MinimalCtx Tuple{typeof(^),Dual{P,T},P} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(^)}, x::CoDual{Dual{P,T}}, y::CoDual{P} +) where {P<:IEEEFloat,T<:IEEEFloat} + px, py = primal(x), primal(y) + z = px^py + function pow_dual_float_pb!!(dy::Dual{P,T}) + dx = py * px^(py - one(P)) * dy + return NoRData(), dx, NoRData() + end + return CoDual(z, NoFData()), pow_dual_float_pb!! +end diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index c8fc110746..e156811abd 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -71,23 +71,6 @@ @from_chainrules MinimalCtx Tuple{typeof(deg2rad),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(rad2deg),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat} - -# ^(Float, Int) for literal integer powers like x^4. Forward and reverse modes -# individually derive rules fine, but forward-over-reverse (Hessian) compilation -# hangs without this as it tries to differentiate through Julia's power implementation. 
-@is_primitive MinimalCtx Tuple{typeof(^),P,Integer} where {P<:IEEEFloat} -function frule!!(::Dual{typeof(^)}, x::Dual{P}, p::Dual{<:Integer}) where {P<:IEEEFloat} - _x, _p = primal(x), primal(p) - return Dual(_x^_p, _p * _x^(_p - 1) * tangent(x)) -end -function rrule!!( - ::CoDual{typeof(^)}, x::CoDual{P}, p::CoDual{<:Integer} -) where {P<:IEEEFloat} - _x, _p = primal(x), primal(p) - pow_int_pb(dy::P) = NoRData(), dy * _p * _x^(_p - 1), NoRData() - return zero_fcodual(_x^_p), pow_int_pb -end - @from_chainrules MinimalCtx Tuple{typeof(atan),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(max),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(min),P,P} where {P<:IEEEFloat} diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index 2115a47d34..7329e89bc4 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -605,16 +605,6 @@ function rrule!!( return CoDual(x, dx), NoPullback((NoRData(), NoRData(), NoRData())) end -# Core.memorynew is called inside rrule!! for Memory allocation. Forward-over-reverse -# (Hessian) needs this frule when forward mode differentiates through reverse mode code. -@inline function frule!!( - ::Dual{typeof(Core.memorynew)}, ::Dual{Type{Memory{P}}}, n::Dual{Int} -) where {P} - x = Core.memorynew(Memory{P}, primal(n)) - dx = zero_tangent_internal(x, NoCache()) - return Dual(x, dx) -end - function rrule!!( ::CoDual{typeof(_new_)}, ::CoDual{Type{MemoryRef{P}}}, diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 023b033175..7cd69c2701 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -14,13 +14,8 @@ test_differentiation( ) # Test Hessian computation using forward-over-reverse with DITest scenarios. -# Using linalg=false to select loop-based test functions instead of the default -# versions that use broadcasting and linear algebra ops (vec, transpose). -# Broadcasting in forward-over-reverse mode causes compilation hangs due to -# complex nested types that overwhelm Julia's type inference. 
test_differentiation( [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; - scenarios=default_scenarios(; linalg=false), excluded=vcat(FIRST_ORDER, [:hvp, :second_derivative]), logging=true, ) From c5bb54b2b3f5ed3f4ec2a211f2d6fc00db3def2d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 5 Dec 2025 17:44:53 +0000 Subject: [PATCH 19/19] fix and add Tim Holy's example --- src/dual.jl | 12 ++ src/rrules/dual_arithmetic.jl | 44 ++++++ .../differentiation_interface.jl | 145 ++++++++++++++++++ 3 files changed, 201 insertions(+) diff --git a/src/dual.jl b/src/dual.jl index e51a5786c7..34e563bb08 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -98,6 +98,12 @@ end function Base.:+(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) end +function Base.:+(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) + y, tangent(x)) +end +function Base.:+(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x + primal(y), tangent(y)) +end # Subtraction Base.:-(x::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} = Dual(-primal(x), -tangent(x)) @@ -110,6 +116,12 @@ end function Base.:-(x::Dual{P,T}, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} return Dual(primal(x) - primal(y), tangent(x) - tangent(y)) end +function Base.:-(x::Dual{P,T}, y::Integer) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(primal(x) - y, tangent(x)) +end +function Base.:-(x::Integer, y::Dual{P,T}) where {P<:IEEEFloat,T<:IEEEFloat} + return Dual(x - primal(y), -tangent(y)) +end # Multiplication (product rule) function Base.:*(x::Dual{P,T}, y::P) where {P<:IEEEFloat,T<:IEEEFloat} diff --git a/src/rrules/dual_arithmetic.jl b/src/rrules/dual_arithmetic.jl index 6b091b5834..10ce53949a 100644 --- a/src/rrules/dual_arithmetic.jl +++ b/src/rrules/dual_arithmetic.jl @@ -38,6 +38,28 @@ function rrule!!( return CoDual(z, NoFData()), pb!! end +@is_primitive MinimalCtx Tuple{ + typeof(+),Dual{P,T},Integer +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) + primal(y) + pb!! = dy -> (NoRData(), dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{ + typeof(+),Integer,Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) + primal(y) + pb!! = dy -> (NoRData(), NoRData(), dy) + return CoDual(z, NoFData()), pb!! +end + @is_primitive MinimalCtx Tuple{typeof(-),Dual{P,T}} where {P<:IEEEFloat,T<:IEEEFloat} function rrule!!( ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}} @@ -74,6 +96,28 @@ function rrule!!( return CoDual(z, NoFData()), pb!! end +@is_primitive MinimalCtx Tuple{ + typeof(-),Dual{P,T},Integer +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{Dual{P,T}}, y::CoDual{<:Integer} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) - primal(y) + pb!! = dy -> (NoRData(), dy, NoRData()) + return CoDual(z, NoFData()), pb!! +end + +@is_primitive MinimalCtx Tuple{ + typeof(-),Integer,Dual{P,T} +} where {P<:IEEEFloat,T<:IEEEFloat} +function rrule!!( + ::CoDual{typeof(-)}, x::CoDual{<:Integer}, y::CoDual{Dual{P,T}} +) where {P<:IEEEFloat,T<:IEEEFloat} + z = primal(x) - primal(y) + pb!! = dy -> (NoRData(), NoRData(), Dual(-primal(dy), -tangent(dy))) + return CoDual(z, NoFData()), pb!! 
+end + @is_primitive MinimalCtx Tuple{ typeof(*),Dual{P,T},Dual{P,T} } where {P<:IEEEFloat,T<:IEEEFloat} diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 7cd69c2701..7b81715a44 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -42,4 +42,149 @@ test_differentiation( @test DI.hessian(x -> x[1]^4, backend, [2.0]) ≈ [48.0] @test DI.hessian(x -> x[1]^6, backend, [2.0]) ≈ [480.0] end + + @testset "https://github.com/chalk-lab/Mooncake.jl/issues/632" begin + function gams_objective(x) + return ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + ( + x[1] * + x[1] + + x[10] * + x[10] + ) * + ( + x[1] * + x[1] + + x[10] * + x[10] + ) - + 4 * + x[1] + ) + + 3 + ) + + ( + x[2] * + x[2] + + x[10] * + x[10] + ) * + ( + x[2] * + x[2] + + x[10] * + x[10] + ) + ) - + 4 * + x[2] + ) + + 3 + ) + + ( + x[3] * + x[3] + + x[10] * + x[10] + ) * + ( + x[3] * + x[3] + + x[10] * + x[10] + ) + ) - + 4 * + x[3] + ) + + 3 + ) + + ( + x[4] * + x[4] + + x[10] * + x[10] + ) * ( + x[4] * + x[4] + + x[10] * + x[10] + ) + ) - + 4 * x[4] + ) + 3 + ) + + ( + x[5] * x[5] + + x[10] * x[10] + ) * ( + x[5] * x[5] + + x[10] * x[10] + ) + ) - 4 * x[5] + ) + 3 + ) + + (x[6] * x[6] + x[10] * x[10]) * + (x[6] * x[6] + x[10] * x[10]) + ) - 4 * x[6] + ) + 3 + ) + + (x[7] * x[7] + x[10] * x[10]) * + (x[7] * x[7] + x[10] * x[10]) + ) - 4 * x[7] + ) + 3 + ) + + (x[8] * x[8] + x[10] * x[10]) * + (x[8] * x[8] + x[10] * x[10]) + ) - 4 * x[8] + ) + 3 + ) + (x[9] * x[9] + x[10] * x[10]) * (x[9] * x[9] + x[10] * x[10]) + ) - 4 * x[9] + ) + 3 + ) + end + x0 = [0.0; fill(1.0, 9)] + H = DI.hessian(gams_objective, backend, x0) + + # Expected Hessian at x0: + # - H[1,1] = 4 (since x₁=0) + # - H[i,i] = 16 for i ∈ 2:9 (since xᵢ=1, x₁₀=1) + # - H[10,10] = 140 (sum of contributions from all 9 terms) + # - H[i,10] = H[10,i] = 8xᵢx₁₀ = 0 for i=1, 8 for i∈2:9 + H_expected = zeros(10, 10) + H_expected[1, 1] = 4.0 + for i in 2:9 + H_expected[i, i] = 16.0 + H_expected[i, 10] = 8.0 + H_expected[10, i] = 8.0 + end + H_expected[10, 10] = 140.0 + + @test H ≈ H_expected + end end
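
For reference, a minimal end-user sketch of the forward-over-reverse Hessian path that the tests above exercise. It only uses names that already appear in the patched test file (`SecondOrder`, `AutoMooncake`, `AutoMooncakeForward`, `DI.hessian`) and assumes the same DifferentiationInterface / Mooncake versions as that test environment; it is an illustrative sketch, not additional API.

```julia
# Usage sketch, assuming the packages and backend constructors used in
# test/ext/differentiation_interface/differentiation_interface.jl above.
using DifferentiationInterface
import DifferentiationInterface as DI
using Mooncake: Mooncake

# Forward mode (outer) differentiates the Mooncake reverse-mode gradient (inner),
# i.e. forward-over-reverse.
backend = SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))

rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2
H = DI.hessian(rosen, backend, [1.2, 1.2])  # expected ≈ [1250.0 -480.0; -480.0 200.0]
```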