diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3a181cc58..74e42dfea 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -325,7 +325,7 @@ function __dual_init( ∂_A = partial_vals(A) ∂_b = partial_vals(b) - primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) + primal_prob = LinearProblem{SciMLBase.isinplace(prob)}(new_A, new_b; u0 = new_u0) if get_dual_type(prob.A) !== nothing dual_type = get_dual_type(prob.A) diff --git a/test/adjoint.jl b/test/adjoint.jl index 42186c376..a8f790628 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -141,3 +141,43 @@ for alg in ( @test zyg_jac ≈ fd_jac rtol = 1.0e-4 end + +struct System end +function update_A!(A, p::Vector{T}) where {T} + A[1, 1] = sum(p) + A[1, 2] = one(T) + A[2, 1] = 2one(T) + return A[2, 2] = 3one(T) +end +function update_b!(b, p::Vector{T}) where {T} + b[1] = p[1] + return b[2] = 2one(T) +end + +function SciMLBase.get_new_A_b(::System, f, p, A, b; kw...) + if eltype(A) != eltype(p) + A = similar(A, eltype(p)) + b = similar(b, eltype(p)) + end + f.update_A!(A, p) + f.update_b!(b, p) + return A, b +end + + +@testset "Avoid stackoverflow with `SciMLBase.get_new_A_b` hook" begin + # See https://github.com/SciML/LinearSolve.jl/pull/868#issuecomment-3723338914 + p = [1.0, 2.0, 3.0] + A = [6.0 1.0; 2.0 3.0] + b = [1.0, 2.0] + linfun = SciMLBase.SymbolicLinearInterface(update_A!, update_b!, System(), nothing, nothing) + + prob = LinearProblem(A, b; f = linfun) + sol = solve(prob) + + newp = [2.0, 3.0, 4.0] + @test_nowarn ForwardDiff.gradient(newp) do p + prob2 = remake(prob; p) + sum(solve(prob2).u) + end +end