diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1d34086329..7dc92a393f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -46,6 +46,7 @@ jobs: 'rules/tasks', 'rules/twice_precision', 'rules/performance_patches', + 'rules/high_order_derivative_patches', ] version: - 'lts' @@ -87,6 +88,7 @@ jobs: matrix: test_group: [ {test_type: 'ext', label: 'differentiation_interface'}, + {test_type: 'ext', label: 'differentiation_interface_second_order'}, {test_type: 'ext', label: 'dynamic_expressions'}, {test_type: 'ext', label: 'flux'}, {test_type: 'ext', label: 'function_wrappers'}, diff --git a/Project.toml b/Project.toml index 2e7ac651cb..419911d896 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.192" +version = "0.4.193" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 75aeef6e0c..b9c863a0e9 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -128,6 +128,7 @@ end include("tools_for_rules.jl") @unstable include("test_utils.jl") @unstable include("test_resources.jl") +include("interface.jl") include(joinpath("rules", "avoiding_non_differentiable_code.jl")) include(joinpath("rules", "blas.jl")) @@ -151,10 +152,10 @@ else include(joinpath("rules", "array_legacy.jl")) end -# Including this in DispatchDoctor causes precompilation error. +# Including this in DispatchDoctor causes precompilation error. @unstable include(joinpath("rules", "performance_patches.jl")) +include(joinpath("rules", "high_order_derivative_patches.jl")) -include("interface.jl") include("config.jl") include("developer_tools.jl") diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index c0ce3ae9f5..f9c6f9de14 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -1,3 +1,31 @@ +# Check if a type contains Union{} (bottom type) anywhere in its structure. +# This can happen with unreachable code or failed type inference. +@inline contains_bottom_type(T) = _contains_bottom_type(T, Base.IdSet{Any}()) + +function _contains_bottom_type(T, seen::Base.IdSet{Any}) + T === Union{} && return true + if T isa Union + return _contains_bottom_type(T.a, seen) || _contains_bottom_type(T.b, seen) + elseif T isa TypeVar + T in seen && return false + push!(seen, T) + return _contains_bottom_type(T.ub, seen) + elseif T isa UnionAll + T in seen && return false + push!(seen, T) + return _contains_bottom_type(T.body, seen) + elseif T isa DataType + T in seen && return false + push!(seen, T) + for p in T.parameters + _contains_bottom_type(p, seen) && return true + end + return false + else + return false + end +end + function build_frule(args...; debug_mode=false, silence_debug_messages=true) sig = _typeof(TestUtils.__get_primals(args)) interp = get_interpreter(ForwardMode) @@ -16,19 +44,27 @@ end sig_or_mi; debug_mode=false, silence_debug_messages=true, + skip_world_age_check=false, ) where {C} Returns a function which performs forward-mode AD for `sig_or_mi`. Will derive a rule if `sig_or_mi` is not a primitive. + +Set `skip_world_age_check=true` when the interpreter's world age is intentionally older +than the current world (e.g., when building rules for MistyClosure which uses its own world). """ function build_frule( - interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true + interp::MooncakeInterpreter{C}, + sig_or_mi; + debug_mode=false, + silence_debug_messages=true, + skip_world_age_check=false, ) where {C} @nospecialize sig_or_mi # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. - if Base.get_world_counter() > interp.world + if !skip_world_age_check && Base.get_world_counter() > interp.world throw( ArgumentError( "World age associated to interp is behind current world age. Please " * @@ -331,7 +367,11 @@ function modify_fwd_ad_stmts!( if isexpr(stmt, :invoke) || isexpr(stmt, :call) raw_args = isexpr(stmt, :invoke) ? stmt.args[2:end] : stmt.args sig_types = map(raw_args) do x - return CC.widenconst(get_forward_primal_type(info.primal_ir, x)) + t = CC.widenconst(get_forward_primal_type(info.primal_ir, x)) + # Replace types containing Union{} (unreachable code/failed inference) + # with Any. This allows the code to proceed; is_primitive will return + # false and we'll use dynamic rules that resolve types at runtime. + return contains_bottom_type(t) ? Any : t end sig = Tuple{sig_types...} mi = isexpr(stmt, :invoke) ? get_mi(stmt.args[1]) : missing diff --git a/src/rules/high_order_derivative_patches.jl b/src/rules/high_order_derivative_patches.jl new file mode 100644 index 0000000000..08add2507f --- /dev/null +++ b/src/rules/high_order_derivative_patches.jl @@ -0,0 +1,205 @@ +# Forward-mode primitive for _build_rule! on LazyDerivedRule. +# This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter. +# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building. +# Reverse-over-reverse is not supported; an rrule!! that throws is provided below. +@is_primitive MinimalCtx Tuple{typeof(_build_rule!),LazyDerivedRule,Tuple} + +function frule!!( + ::Dual{typeof(_build_rule!)}, + lazy_rule_dual::Dual{<:LazyDerivedRule{sig}}, + args_dual::Dual{<:Tuple}, +) where {sig} + lazy_rule = primal(lazy_rule_dual) + lazy_tangent = tangent(lazy_rule_dual) + primal_args = primal(args_dual) + tangent_args = tangent(args_dual) + + # Build rrule if not built (primal operation, no differentiation needed) + if !isdefined(lazy_rule, :rule) + interp = get_interpreter(ReverseMode) + lazy_rule.rule = build_rrule(interp, lazy_rule.mi; debug_mode=lazy_rule.debug_mode) + end + derived_rule = lazy_rule.rule + + # Initialize the tangent of the derived rule if needed + rule_tangent_field = lazy_tangent.fields.rule + if !isdefined(rule_tangent_field, :tangent) + # Need to update the MutableTangent's fields with a new PossiblyUninitTangent + new_rule_tangent = PossiblyUninitTangent(zero_tangent(derived_rule)) + lazy_tangent.fields = merge(lazy_tangent.fields, (; rule=new_rule_tangent)) + rule_tangent_field = new_rule_tangent + end + derived_tangent = rule_tangent_field.tangent + + # Forward-differentiate through the DerivedRule call. + # DerivedRule(args...) internally calls fwds_oc(args...) and returns (CoDual, Pullback) + fwds_oc = derived_rule.fwds_oc + fwds_oc_tangent = derived_tangent.fields.fwds_oc + + # Handle varargs unflattening + isva = _isva(derived_rule) + nargs = derived_rule.nargs + N = length(primal_args) + uf_primal_args = __unflatten_codual_varargs(isva, primal_args, nargs) + uf_tangent_args = __unflatten_tangent_varargs(isva, tangent_args, nargs) + + # Create dual args for frule!! call + dual_args = map(Dual, uf_primal_args, uf_tangent_args) + + # Call frule!! on fwds_oc to get forward-differentiated result + dual_fwds_oc = Dual(fwds_oc, fwds_oc_tangent) + codual_result_dual = frule!!(dual_fwds_oc, dual_args...) + + # Create Pullback and its tangent + pb_oc_ref = derived_rule.pb_oc_ref + pb_primal = Pullback(sig, pb_oc_ref, isva, N) + pb_tangent = Tangent((; pb_oc=derived_tangent.fields.pb_oc_ref)) + + # Return Dual of (CoDual, Pullback) + primal_result = (primal(codual_result_dual), pb_primal) + tangent_result = (tangent(codual_result_dual), pb_tangent) + return Dual(primal_result, tangent_result) +end + +# Helper to unflatten tangent args similar to __unflatten_codual_varargs +function __unflatten_tangent_varargs(isva::Bool, tangent_args, ::Val{nargs}) where {nargs} + isva || return tangent_args + group_tangent = tangent_args[nargs:end] + return (tangent_args[1:(nargs - 1)]..., group_tangent) +end + +# Reverse-over-reverse is not supported. Throw an informative error. +function rrule!!( + ::CoDual{typeof(_build_rule!)}, ::CoDual{<:LazyDerivedRule}, ::CoDual{<:Tuple} +) + throw( + ArgumentError( + "Reverse-over-reverse differentiation is not supported. " * + "Encountered attempt to differentiate _build_rule! in reverse mode.", + ), + ) +end + +# TODO: This is a workaround for forward-over-reverse. Primitives in reverse mode can get +# inlined when building the forward rule, exposing internal ccalls that lack an frule!!. +# For example, `dataids` is a reverse-mode primitive, but inlining it exposes +# `jl_genericmemory_owner`. The proper fix is to prevent primitive inlining during +# forward-over-reverse by forwarding `inlining_policy` through `BugPatchInterpreter` to +# `MooncakeInterpreter` during `optimise_ir!`, but this causes allocation regressions. +# See https://github.com/chalk-lab/Mooncake.jl/pull/878 for details. +@static if VERSION >= v"1.11-" + function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_genericmemory_owner}}, + ::Dual{Val{Any}}, + ::Dual{Tuple{Val{Any}}}, + ::Dual{Val{0}}, + ::Dual{Val{:ccall}}, + a::Dual{<:Memory}, + ) + return zero_dual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a))) + end + function rrule!!( + ::CoDual{typeof(_foreigncall_)}, + ::CoDual{Val{:jl_genericmemory_owner}}, + ::CoDual{Val{Any}}, + ::CoDual{Tuple{Val{Any}}}, + ::CoDual{Val{0}}, + ::CoDual{Val{:ccall}}, + a::CoDual{<:Memory}, + ) + y = zero_fcodual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a))) + return y, NoPullback(ntuple(_ -> NoRData(), 7)) + end +end + +# Avoid differentiating through AD infrastructure during second-order differentiation. +@zero_derivative MinimalCtx Tuple{ + typeof(Core.kwcall),NamedTuple,typeof(prepare_gradient_cache),Vararg +} +@zero_derivative MinimalCtx Tuple{ + typeof(Core.kwcall),NamedTuple,typeof(prepare_derivative_cache),Vararg +} +@zero_derivative MinimalCtx Tuple{ + typeof(Core.kwcall),NamedTuple,typeof(prepare_pullback_cache),Vararg +} +@zero_derivative MinimalCtx Tuple{typeof(zero_tangent),Any} + +@static if VERSION < v"1.11-" + @generated function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_alloc_array_1d}}, + ::Dual{Val{Vector{P}}}, + ::Dual{Tuple{Val{Any},Val{Int}}}, + ::Dual{Val{0}}, + ::Dual{Val{:ccall}}, + ::Dual{Type{Vector{P}}}, + n::Dual{Int}, + args::Vararg{Dual}, + ) where {P} + T = tangent_type(P) + return quote + _n = primal(n) + y = ccall(:jl_alloc_array_1d, Vector{$P}, (Any, Int), Vector{$P}, _n) + dy = ccall(:jl_alloc_array_1d, Vector{$T}, (Any, Int), Vector{$T}, _n) + return Dual(y, dy) + end + end + @generated function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_alloc_array_2d}}, + ::Dual{Val{Matrix{P}}}, + ::Dual{Tuple{Val{Any},Val{Int},Val{Int}}}, + ::Dual{Val{0}}, + ::Dual{Val{:ccall}}, + ::Dual{Type{Matrix{P}}}, + m::Dual{Int}, + n::Dual{Int}, + args::Vararg{Dual}, + ) where {P} + T = tangent_type(P) + return quote + _m, _n = primal(m), primal(n) + y = ccall(:jl_alloc_array_2d, Matrix{$P}, (Any, Int, Int), Matrix{$P}, _m, _n) + dy = ccall(:jl_alloc_array_2d, Matrix{$T}, (Any, Int, Int), Matrix{$T}, _m, _n) + return Dual(y, dy) + end + end + @generated function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_alloc_array_3d}}, + ::Dual{Val{Array{P,3}}}, + ::Dual{Tuple{Val{Any},Val{Int},Val{Int},Val{Int}}}, + ::Dual{Val{0}}, + ::Dual{Val{:ccall}}, + ::Dual{Type{Array{P,3}}}, + l::Dual{Int}, + m::Dual{Int}, + n::Dual{Int}, + args::Vararg{Dual}, + ) where {P} + T = tangent_type(P) + return quote + _l, _m, _n = primal(l), primal(m), primal(n) + y = ccall( + :jl_alloc_array_3d, + Array{$P,3}, + (Any, Int, Int, Int), + Array{$P,3}, + _l, + _m, + _n, + ) + dy = ccall( + :jl_alloc_array_3d, + Array{$T,3}, + (Any, Int, Int, Int), + Array{$T,3}, + _l, + _m, + _n, + ) + return Dual(y, dy) + end + end +end diff --git a/src/rules/low_level_maths.jl b/src/rules/low_level_maths.jl index b5536a7534..ea86e71588 100644 --- a/src/rules/low_level_maths.jl +++ b/src/rules/low_level_maths.jl @@ -71,6 +71,7 @@ @from_chainrules MinimalCtx Tuple{typeof(deg2rad),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(rad2deg),IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat} + @from_chainrules MinimalCtx Tuple{typeof(atan),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(max),P,P} where {P<:IEEEFloat} @from_chainrules MinimalCtx Tuple{typeof(min),P,P} where {P<:IEEEFloat} diff --git a/src/rules/memory.jl b/src/rules/memory.jl index 7329e89bc4..7657b6a206 100644 --- a/src/rules/memory.jl +++ b/src/rules/memory.jl @@ -592,6 +592,24 @@ end # _new_ and _new_-adjacent rules for Memory, MemoryRef, and Array. +@static if VERSION >= v"1.12-" + @is_primitive MinimalCtx Tuple{typeof(Core.memorynew),Type{<:Memory},Int} + function frule!!( + ::Dual{typeof(Core.memorynew)}, ::Dual{Type{Memory{P}}}, n::Dual{Int} + ) where {P} + x = Core.memorynew(Memory{P}, primal(n)) + dx = Core.memorynew(Memory{tangent_type(P)}, primal(n)) + return Dual(x, dx) + end + function rrule!!( + ::CoDual{typeof(Core.memorynew)}, ::CoDual{Type{Memory{P}}}, n::CoDual{Int} + ) where {P} + x = Core.memorynew(Memory{P}, primal(n)) + dx = Core.memorynew(Memory{tangent_type(P)}, primal(n)) + return CoDual(x, dx), NoPullback((NoRData(), NoRData(), NoRData())) + end +end + @is_primitive MinimalCtx Tuple{Type{<:Memory},UndefInitializer,Int} function frule!!(::Dual{Type{Memory{P}}}, ::Dual{UndefInitializer}, n::Dual{Int}) where {P} x = Memory{P}(undef, primal(n)) @@ -908,6 +926,17 @@ function hand_written_rule_test_cases(rng_ctor, ::Val{:memory}) zip(mem_refs, sample_mem_ref_values), ) test_cases = vcat( + @static( + if VERSION >= v"1.12-" + [ + (true, :stability, nothing, Core.memorynew, Memory{Float64}, 5), + (true, :stability, nothing, Core.memorynew, Memory{Float64}, 10), + (true, :stability, nothing, Core.memorynew, Memory{Int}, 5), + ] + else + [] + end + ), # Rules for `Memory` (true, :stability, nothing, Memory{Float64}, undef, 5), diff --git a/src/rules/misty_closures.jl b/src/rules/misty_closures.jl index 22f415eff0..20a3b9e004 100644 --- a/src/rules/misty_closures.jl +++ b/src/rules/misty_closures.jl @@ -17,7 +17,25 @@ struct MistyClosureTangent dual_callable::Any end -_dual_mc(p::MistyClosure) = build_frule(get_interpreter(ForwardMode), p) +# Build a forward-mode rule for a MistyClosure using its original world age. +# +# We cannot use the current world age because the MistyClosure's IR (p.ir[]) has a +# valid_worlds range set at creation time. On Julia 1.12+, generate_dual_ir calls +# set_valid_world!(ir, interp.world), which throws if the world is outside this range. +# If methods were defined after the MistyClosure was created, the current world would +# fall outside valid_worlds and cause an error. +# +# Using the original world age is safe because lookup_ir for MistyClosure returns mc.ir[] +# directly, bypassing method table lookups. Nested non-primitive calls use LazyFRule or +# DynamicFRule, which obtain a current-world interpreter via get_interpreter() at runtime. +# We pass skip_world_age_check=true since build_frule's safety check would incorrectly +# reject our intentionally-older interpreter. +# +function _dual_mc(p::MistyClosure) + mc_world = UInt(p.oc.world) + interp = MooncakeInterpreter(DefaultCtx, ForwardMode; world=mc_world) + return build_frule(interp, p; skip_world_age_check=true) +end tangent_type(::Type{<:MistyClosure}) = MistyClosureTangent @@ -61,13 +79,27 @@ function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:MistyClosure return T(captures_tangent, t.dual_callable) end -import .TestUtils: populate_address_map_internal, AddressMap +import .TestUtils: populate_address_map_internal, AddressMap, has_equal_data_internal function populate_address_map_internal( m::AddressMap, p::MistyClosure, t::MistyClosureTangent ) return populate_address_map_internal(m, p.oc.captures, t.captures_tangent) end +function has_equal_data_internal( + x::MistyClosureTangent, + y::MistyClosureTangent, + equal_undefs::Bool, + d::Dict{Tuple{UInt,UInt},Bool}, +) + # Only compare captures_tangent. The dual_callable field is a forward-mode rule + # built on-demand by _dual_mc, which creates a new interpreter each time. Different + # interpreter instances produce different rule objects, even for the same MistyClosure. + # Since dual_callable is just a computational tool (not part of the tangent's value), + # two tangents with identical captures_tangent are mathematically equal. + return has_equal_data_internal(x.captures_tangent, y.captures_tangent, equal_undefs, d) +end + struct MistyClosureFData captures_fdata::Any dual_callable::Any diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index ea74e84b6d..42ceb0cb96 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -5,6 +5,7 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake +# Test first-order differentiation (reverse mode) test_differentiation( [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; excluded=SECOND_ORDER, diff --git a/test/ext/differentiation_interface_second_order/Project.toml b/test/ext/differentiation_interface_second_order/Project.toml new file mode 100644 index 0000000000..7639dd345f --- /dev/null +++ b/test/ext/differentiation_interface_second_order/Project.toml @@ -0,0 +1,5 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/ext/differentiation_interface_second_order/differentiation_interface_second_order.jl b/test/ext/differentiation_interface_second_order/differentiation_interface_second_order.jl new file mode 100644 index 0000000000..6092c4eb04 --- /dev/null +++ b/test/ext/differentiation_interface_second_order/differentiation_interface_second_order.jl @@ -0,0 +1,13 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) + +using DifferentiationInterface, DifferentiationInterfaceTest +using Mooncake: Mooncake + +# Test second-order differentiation (forward-over-reverse) +test_differentiation( + [SecondOrder(AutoMooncakeForward(; config=nothing), AutoMooncake(; config=nothing))]; + excluded=[FIRST_ORDER..., :hvp, :second_derivative], # testing only :hessian + logging=true, +) diff --git a/test/front_matter.jl b/test/front_matter.jl index 5dc7363673..dc7fc0f140 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -28,11 +28,13 @@ using Mooncake: _add_to_primal, _diff, _dot, + Dual, zero_dual, zero_codual, codual_type, rrule!!, build_rrule, + build_frule, value_and_gradient!!, value_and_pullback!!, NoFData, @@ -44,7 +46,8 @@ using Mooncake: get_interpreter, Mode, ForwardMode, - ReverseMode + ReverseMode, + MistyClosureTangent using Mooncake: CC, diff --git a/test/rules/high_order_derivative_patches.jl b/test/rules/high_order_derivative_patches.jl new file mode 100644 index 0000000000..b638dcb031 --- /dev/null +++ b/test/rules/high_order_derivative_patches.jl @@ -0,0 +1,118 @@ +function _compute_grad(rule, f, x::Vector{Float64}, x_fdata::Vector{Float64}) + fill!(x_fdata, 0.0) + _, pb!! = rule(zero_fcodual(f), CoDual(x, x_fdata)) + pb!!(1.0) + return copy(x_fdata) +end + +function _hessian_column(f, x::Vector{Float64}, i::Int) + x_fdata = fdata(zero_tangent(x)) + rule = build_rrule(f, x) + frule = build_frule(_compute_grad, rule, f, x, x_fdata) + + x_tangent = zeros(length(x)) + x_tangent[i] = 1.0 + fill!(x_fdata, 0.0) + + result = frule( + zero_dual(_compute_grad), + zero_dual(rule), + zero_dual(f), + Dual(x, x_tangent), + Dual(x_fdata, zeros(length(x))), + ) + return primal(result), tangent(result) +end + +function _compute_hessian(f, x::Vector{Float64}) + n = length(x) + H = zeros(n, n) + for i in 1:n + _, H[:, i] = _hessian_column(f, x, i) + end + return H +end + +@testset "hessian_scalar_functions" begin + @testset "sum" begin + g(x) = sum(x) + x = [2.0] + grad, hess_col = _hessian_column(g, x, 1) + @test grad ≈ [1.0] + @test hess_col ≈ [0.0] + end + + @testset "x^4.0" begin + f(x) = x[1]^4.0 + x = [2.0] + grad, hess_col = _hessian_column(f, x, 1) + @test grad ≈ [32.0] + @test hess_col ≈ [48.0] + end + + @testset "x^4" begin + f(x) = x[1]^4 + x = [2.0] + grad, hess_col = _hessian_column(f, x, 1) + @test grad ≈ [32.0] + @test hess_col ≈ [48.0] + end + + @testset "x^6" begin + f(x) = x[1]^6 + x = [2.0] + grad, hess_col = _hessian_column(f, x, 1) + @test grad ≈ [192.0] + @test hess_col ≈ [480.0] + end +end + +@testset "hessian_multivariate" begin + @testset "Rosenbrock" begin + rosen(z) = (1.0 - z[1])^2 + 100.0 * (z[2] - z[1]^2)^2 + z = [1.2, 1.2] + H = _compute_hessian(rosen, z) + expected_H = [1250.0 -480.0; -480.0 200.0] + @test H ≈ expected_H rtol = 1e-10 + end + + @testset "sum of squares" begin + f(x) = sum([x[1] * x[1], x[2] * x[2]]) + x = [2.0, 3.0] + grad, hess_col = _hessian_column(f, x, 1) + @test grad ≈ [4.0, 6.0] rtol = 1e-10 + @test hess_col ≈ [2.0, 0.0] rtol = 1e-10 + end + + @testset "broadcast sum of squares" begin + # Tests broadcast operations: x .* x uses broadcasting + f(x) = sum(x .* x) + x = [2.0, 3.0] + H = _compute_hessian(f, x) + # f(x) = x₁² + x₂², so ∇f = [2x₁, 2x₂] and H = 2I + @test H ≈ [2.0 0.0; 0.0 2.0] rtol = 1e-10 + end + + @testset "GAMS objective" begin + function gams_objective(x) + #! format: off + objvar = (((((((((((((((((((((((((((x[1] * x[1] + x[10] * x[10]) * (x[1] * x[1] + x[10] * x[10]) - 4 * x[1]) + 3) + (x[2] * x[2] + x[10] * x[10]) * (x[2] * x[2] + x[10] * x[10])) - 4 * x[2]) + 3) + (x[3] * x[3] + x[10] * x[10]) * (x[3] * x[3] + x[10] * x[10])) - 4 * x[3]) + 3) + (x[4] * x[4] + x[10] * x[10]) * (x[4] * x[4] + x[10] * x[10])) - 4 * x[4]) + 3) + (x[5] * x[5] + x[10] * x[10]) * (x[5] * x[5] + x[10] * x[10])) - 4 * x[5]) + 3) + (x[6] * x[6] + x[10] * x[10]) * (x[6] * x[6] + x[10] * x[10])) - 4 * x[6]) + 3) + (x[7] * x[7] + x[10] * x[10]) * (x[7] * x[7] + x[10] * x[10])) - 4 * x[7]) + 3) + (x[8] * x[8] + x[10] * x[10]) * (x[8] * x[8] + x[10] * x[10])) - 4 * x[8]) + 3) + (x[9] * x[9] + x[10] * x[10]) * (x[9] * x[9] + x[10] * x[10])) - 4 * x[9]) + 3) - 0 + #! format: on + return objvar + end + + x0 = [0.0; fill(1.0, 9)] + H = _compute_hessian(gams_objective, x0) + + H_expected = zeros(10, 10) + H_expected[1, 1] = 4.0 + for i in 2:9 + H_expected[i, i] = 16.0 + H_expected[i, 10] = 8.0 + H_expected[10, i] = 8.0 + end + H_expected[10, 10] = 140.0 + + @test H ≈ H_expected rtol = 1e-10 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 90911f3587..e5a820a3a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,8 @@ include("front_matter.jl") include(joinpath("rules", "performance_patches.jl")) elseif test_group == "rules/dispatch_doctor" include(joinpath("rules", "dispatch_doctor.jl")) + elseif test_group == "rules/high_order_derivative_patches" + include(joinpath("rules", "high_order_derivative_patches.jl")) else throw(error("test_group=$(test_group) is not recognised")) end