Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand All @@ -29,7 +28,6 @@ LinearAlgebra = "1.11.0"
OrderedCollections = "1"
Setfield = "1.1.1"
SparseArrays = "1.11.0"
UnPack = "1.0.2"
julia = "1.9"

[extras]
Expand Down
253 changes: 133 additions & 120 deletions lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,136 @@ end

struct NotInferredTimeDomain end

struct InferEquationClosure
varsbuf::Set{SymbolicT}
# variables in each argument to an operator
arg_varsbuf::Set{SymbolicT}
# hyperedge for each equation
hyperedge::Set{ClockVertex.Type}
# hyperedge for each argument to an operator
arg_hyperedge::Set{ClockVertex.Type}
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
relative_hyperedges::Dict{Int, Set{ClockVertex.Type}}
var_to_idx::Dict{SymbolicT, Int}
inference_graph::HyperGraph{ClockVertex.Type}
end

function InferEquationClosure(var_to_idx, inference_graph)
InferEquationClosure(Set{SymbolicT}(), Set{SymbolicT}(), Set{ClockVertex.Type}(),
Set{ClockVertex.Type}(), Dict{Int, Set{ClockVertex.Type}}(),
var_to_idx, inference_graph)
end

function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_equation::Bool)
(; varsbuf, arg_varsbuf, hyperedge, arg_hyperedge, relative_hyperedges) = iec
(; var_to_idx, inference_graph) = iec
empty!(varsbuf)
empty!(hyperedge)
# get variables in equation
SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}())
# add the equation to the hyperedge
eq_node = if is_initialization_equation
ClockVertex.InitEquation(ieq)
else
ClockVertex.Equation(ieq)
end
push!(hyperedge, eq_node)
for var in varsbuf
idx = get(var_to_idx, var, nothing)
# if this is just a single variable, add it to the hyperedge
if idx isa Int
push!(hyperedge, ClockVertex.Variable(idx))
# we don't immediately `continue` here because this variable might be a
# `Sample` or similar and we want the clock information from it if it is.
end
# now we only care about synchronous operators
op, args = @match var begin
BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args)
_ => continue
end

# arguments and corresponding time domains
tdomains = input_timedomain(op)::Vector{InputTimeDomainElT}
nargs = length(args)
ndoms = length(tdomains)
if nargs != ndoms
throw(ArgumentError("""
Operator $op applied to $nargs arguments $args but only returns $ndoms \
domains $tdomains from `input_timedomain`.
"""))
end

# each relative clock mapping is only valid per operator application
empty!(relative_hyperedges)
for (arg, domain) in zip(args, tdomains)
empty!(arg_varsbuf)
empty!(arg_hyperedge)
# get variables in argument
SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}())
# get hyperedge for involved variables
for v in arg_varsbuf
vidx = get(var_to_idx, v, nothing)
vidx === nothing && continue
push!(arg_hyperedge, ClockVertex.Variable(vidx))
end

@match domain begin
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
x::SciMLBase.AbstractClock => begin
# add the clock to the edge
push!(arg_hyperedge, ClockVertex.Clock(x))
# add the edge to the graph
add_edge!(inference_graph, arg_hyperedge)
end
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
# involved variables have the same clock.
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
# add the edge, and instead add this to the `relative_hyperedges` mapping.
InferredClock.InferredDiscrete(i) => begin
relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i)
union!(relative_edge, arg_hyperedge)
end
end
end

outdomain = output_timedomain(op)
@match outdomain begin
x::SciMLBase.AbstractClock => begin
push!(hyperedge, ClockVertex.Clock(x))
end
InferredClock.Inferred() => nothing
InferredClock.InferredDiscrete(i) => begin
buffer = get(relative_hyperedges, i, nothing)
if buffer !== nothing
union!(hyperedge, buffer)
delete!(relative_hyperedges, i)
end
end
end

for (_, relative_edge) in relative_hyperedges
add_edge!(inference_graph, relative_edge)
end
end

add_edge!(inference_graph, hyperedge)
end

"""
Update the equation-to-time domain mapping by inferring the time domain from the variables.
"""
function infer_clocks!(ci::ClockInference)
(; ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph) = ci
(; var_to_diff, graph) = ts.structure
sys = get_sys(ts)
fullvars = StateSelection.get_fullvars(ts)
isempty(inferred) && return ci

var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars))
var_to_idx = Dict{SymbolicT, Int}()
sizehint!(var_to_idx, length(fullvars))
for (i, v) in enumerate(fullvars)
var_to_idx[v] = i
end

# all shifted variables have the same clock as the unshifted variant
for (i, v) in enumerate(fullvars)
Expand All @@ -81,112 +200,8 @@ function infer_clocks!(ci::ClockInference)
_ => nothing
end
end
infer_equation = InferEquationClosure(var_to_idx, inference_graph)

# preallocated buffers:
# variables in each equation
varsbuf = Set{SymbolicT}()
# variables in each argument to an operator
arg_varsbuf = Set{SymbolicT}()
# hyperedge for each equation
hyperedge = Set{ClockVertex.Type}()
# hyperedge for each argument to an operator
arg_hyperedge = Set{ClockVertex.Type}()
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}()

function infer_equation(ieq, eq, is_initialization_equation)
empty!(varsbuf)
empty!(hyperedge)
# get variables in equation
SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}())
# add the equation to the hyperedge
eq_node = if is_initialization_equation
ClockVertex.InitEquation(ieq)
else
ClockVertex.Equation(ieq)
end
push!(hyperedge, eq_node)
for var in varsbuf
idx = get(var_to_idx, var, nothing)
# if this is just a single variable, add it to the hyperedge
if idx isa Int
push!(hyperedge, ClockVertex.Variable(idx))
# we don't immediately `continue` here because this variable might be a
# `Sample` or similar and we want the clock information from it if it is.
end
# now we only care about synchronous operators
op, args = @match var begin
BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args)
_ => continue
end

# arguments and corresponding time domains
tdomains = input_timedomain(op)::Vector{InputTimeDomainElT}
nargs = length(args)
ndoms = length(tdomains)
if nargs != ndoms
throw(ArgumentError("""
Operator $op applied to $nargs arguments $args but only returns $ndoms \
domains $tdomains from `input_timedomain`.
"""))
end

# each relative clock mapping is only valid per operator application
empty!(relative_hyperedges)
for (arg, domain) in zip(args, tdomains)
empty!(arg_varsbuf)
empty!(arg_hyperedge)
# get variables in argument
SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}())
# get hyperedge for involved variables
for v in arg_varsbuf
vidx = get(var_to_idx, v, nothing)
vidx === nothing && continue
push!(arg_hyperedge, ClockVertex.Variable(vidx))
end

@match domain begin
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
x::SciMLBase.AbstractClock => begin
# add the clock to the edge
push!(arg_hyperedge, ClockVertex.Clock(x))
# add the edge to the graph
add_edge!(inference_graph, arg_hyperedge)
end
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
# involved variables have the same clock.
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
# add the edge, and instead add this to the `relative_hyperedges` mapping.
InferredClock.InferredDiscrete(i) => begin
relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i)
union!(relative_edge, arg_hyperedge)
end
end
end

outdomain = output_timedomain(op)
@match outdomain begin
x::SciMLBase.AbstractClock => begin
push!(hyperedge, ClockVertex.Clock(x))
end
InferredClock.Inferred() => nothing
InferredClock.InferredDiscrete(i) => begin
buffer = get(relative_hyperedges, i, nothing)
if buffer !== nothing
union!(hyperedge, buffer)
delete!(relative_hyperedges, i)
end
end
end

for (_, relative_edge) in relative_hyperedges
add_edge!(inference_graph, relative_edge)
end
end

add_edge!(inference_graph, hyperedge)
end
for (ieq, eq) in enumerate(MTKBase.equations(sys))
infer_equation(ieq, eq, false)
end
Expand All @@ -212,7 +227,9 @@ function infer_clocks!(ci::ClockInference)
"""))
end

clock = partition[only(clockidxs)].:1
clock = Moshi.Match.@match partition[only(clockidxs)] begin
ClockVertex.Clock(clk) => clk
end
for vert in partition
Moshi.Match.@match vert begin
ClockVertex.Variable(i) => (var_domain[i] = clock)
Expand Down Expand Up @@ -275,19 +292,15 @@ function split_system(ci::ClockInference{S}) where {S}
# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id

# Fill the clock_to_id dict as you go,
# ContinuousClock() => 1, ...
get!(clock_to_id, d) do
cid = (cid_counter[] += 1)
push!(id_to_clock, d)
if d == SciMLBase.ContinuousClock()
continuous_id[] = cid
end
cid
# We don't use `get!` here because that desperately wants to box things
cid = get(clock_to_id, d, 0)
if iszero(cid)
cid = (cid_counter[] += 1)
push!(id_to_clock, d)
if d === SciMLBase.ContinuousClock()
continuous_id[] = cid
end
clock_to_id[d] = cid
end
eq_to_cid[i] = cid
resize_or_push!(cid_to_eq, i, cid)
Expand Down
63 changes: 42 additions & 21 deletions lib/ModelingToolkitTearing/src/reassemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
dummy_sub = Dict{SymbolicT, SymbolicT}()

if MTKBase.has_iv(state.sys) && MTKBase.get_iv(state.sys) !== nothing
iv = MTKBase.get_iv(state.sys)
iv = MTKBase.get_iv(state.sys)::SymbolicT
if !StateSelection.is_only_discrete(state.structure)
D = Differential(iv)
else
Expand All @@ -1089,29 +1089,50 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
extra_unknowns = state.fullvars[extra_eqs_vars[2]]
if StateSelection.is_only_discrete(state.structure)
var_sccs = add_additional_history!(
state, var_eq_matching, full_var_eq_matching, var_sccs, iv)
state, var_eq_matching, full_var_eq_matching, var_sccs, iv::SymbolicT)
end

# Structural simplification
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D)

var_sccs = generate_derivative_variables!(
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv)
neweqs, solved_eqs,
eq_ordering,
var_ordering,
nelim_eq,
nelim_var = generate_system_equations!(
state, neweqs, var_eq_matching, full_var_eq_matching,
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
analytical_linear_scc_limit)

state = reorder_vars!(
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
# var_eq_matching and full_var_eq_matching are now invalidated

sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
extra_unknowns, iv, D; array_hack)
if iv isa SymbolicT # Without iv we don't have derivatives
D = D::Union{Differential, Shift}
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D)

var_sccs = generate_derivative_variables!(
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv)
end
if iv isa SymbolicT
D = D::Union{Differential, Shift}
neweqs, solved_eqs,
eq_ordering,
var_ordering,
nelim_eq,
nelim_var = generate_system_equations!(
state, neweqs, var_eq_matching, full_var_eq_matching,
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
analytical_linear_scc_limit)
state = reorder_vars!(
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
# var_eq_matching and full_var_eq_matching are now invalidated

sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
extra_unknowns, iv, D; array_hack)
else
D = D::Nothing
neweqs, solved_eqs,
eq_ordering,
var_ordering,
nelim_eq,
nelim_var = generate_system_equations!(
state, neweqs, var_eq_matching, full_var_eq_matching,
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
analytical_linear_scc_limit)
state = reorder_vars!(
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
# var_eq_matching and full_var_eq_matching are now invalidated

sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
extra_unknowns, iv, D; array_hack)
end

@set! state.sys = sys
@set! sys.tearing_state = state
Expand Down
Loading
Loading