Skip to content

Commit e7a25d7

Browse files
committed
fix parameter substitution
take the MTKParameters values into account and substitute the tunable parameters with the appropriate symbolic expression
1 parent 668bd84 commit e7a25d7

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

src/systems/optimal_control_interface.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ function process_DynamicOptProblem(
234234
warn_overdetermined(sys, op)
235235
ctrls = unbound_inputs(sys)
236236
states = unknowns(sys)
237-
params = tune_parameters ? tunable_parameters(sys) : []
237+
tunable_params = tune_parameters ? tunable_parameters(sys) : []
238238

239239
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
240240
op = Dict([default_toterm(value(k)) => v for (k, v) in op])
@@ -248,9 +248,16 @@ function process_DynamicOptProblem(
248248
model_tspan, steps, is_free_t = process_tspan(tspan, dt, steps)
249249
warn_overdetermined(sys, op)
250250

251-
pmap = filter(p -> (first(p) Set(unknowns(sys))), op)
252-
pmap = recursive_unwrap(AnyDict(pmap))
253-
evaluate_varmap!(pmap, keys(pmap))
251+
# Build pmap for symbolic substitution in costs/constraints/bounds
252+
all_parameters = default_toterm.(parameters(sys))
253+
# Extract all parameter values from processed p (which has defaults filled in)
254+
getter = SymbolicIndexingInterface.getp(sys, all_parameters)
255+
pmap = AnyDict(all_parameters .=> getter(p))
256+
257+
# Filter out tunable parameters - they should remain symbolic
258+
tunable_set = Set(default_toterm.(tunable_params))
259+
pmap = filter(kvp -> first(kvp) tunable_set, pmap)
260+
254261
c0 = value.([pmap[c] for c in ctrls])
255262
p0, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
256263

@@ -259,10 +266,16 @@ function process_DynamicOptProblem(
259266
generate_time_variable!(model, model_tspan, tsteps)
260267
U = generate_state_variable!(model, u0, length(states), tsteps)
261268
V = generate_input_variable!(model, c0, length(ctrls), tsteps)
262-
P = generate_tunable_params!(model, p0, length(params))
269+
P = generate_tunable_params!(model, p0, length(tunable_params))
263270
tₛ = generate_timescale!(model, get(pmap, tspan[2], tspan[2]), is_free_t)
264271
fullmodel = model_type(model, U, V, P, tₛ, is_free_t)
265272

273+
# Add the symbolic representation of the tunable parameters to the map
274+
# The order of the Tunable section is given by the tunable_parameters(sys)
275+
# Some backends need symbolic accessors instead of raw variables
276+
P_syms = [get_param_for_pmap(fullmodel, P, i) for i in eachindex(tunable_params)]
277+
merge!(pmap, Dict(tunable_params .=> P_syms))
278+
266279
set_variable_bounds!(fullmodel, sys, pmap, tspan[2])
267280
add_cost_function!(fullmodel, sys, tspan, pmap)
268281
add_user_constraints!(fullmodel, sys, tspan, pmap)
@@ -279,6 +292,22 @@ function generate_tunable_params! end
279292
function generate_timescale! end
280293
function add_initial_constraints! end
281294
function add_constraint! end
295+
# Default: return P[i] directly. Symbolic backends (like Pyomo) can override this.
296+
get_param_for_pmap(model, P, i) = P isa AbstractArray ? P[i] : P
297+
298+
function f_wrapper(f, Uₙ, Vₙ, p, P, t)
299+
if isempty(P)
300+
# no tunable parameters
301+
return f(Uₙ, Vₙ, p, t)
302+
end
303+
if SciMLStructures.isscimlstructure(p)
304+
_, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
305+
p′ = repack(P)
306+
f(Uₙ, Vₙ, p′, t)
307+
else
308+
f(Uₙ, Vₙ, P, t)
309+
end
310+
end
282311

283312
function set_variable_bounds!(m, sys, pmap, tf)
284313
@unpack model, U, V, tₛ = m
@@ -394,6 +423,7 @@ function process_integral_bounds end
394423
function lowered_integral end
395424
function lowered_derivative end
396425
function lowered_var end
426+
function lowered_param end
397427
function fixed_t_map end
398428

399429
function add_user_constraints!(model, sys, tspan, pmap)
@@ -445,8 +475,8 @@ function substitute_toterm(vars, exprs)
445475
exprs = map(c -> Symbolics.fast_substitute(c, toterm_map), exprs)
446476
end
447477

448-
function substitute_params(pmap, exprs)
449-
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
478+
function substitute_params(pmap::Dict, exprs)
479+
exprs = map(c -> Symbolics.fixpoint_sub(c, pmap), exprs)
450480
end
451481

452482
function check_constraint_vars(vars)

0 commit comments

Comments
 (0)