diff --git a/Project.toml b/Project.toml index 2bd55b6..8a01db9 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" @@ -29,7 +28,6 @@ LinearAlgebra = "1.11.0" OrderedCollections = "1" Setfield = "1.1.1" SparseArrays = "1.11.0" -UnPack = "1.0.2" julia = "1.9" [extras] diff --git a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl index 450ffc8..5a802b9 100644 --- a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl +++ b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl @@ -58,17 +58,136 @@ end struct NotInferredTimeDomain end +struct InferEquationClosure + varsbuf::Set{SymbolicT} + # variables in each argument to an operator + arg_varsbuf::Set{SymbolicT} + # hyperedge for each equation + hyperedge::Set{ClockVertex.Type} + # hyperedge for each argument to an operator + arg_hyperedge::Set{ClockVertex.Type} + # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition + relative_hyperedges::Dict{Int, Set{ClockVertex.Type}} + var_to_idx::Dict{SymbolicT, Int} + inference_graph::HyperGraph{ClockVertex.Type} +end + +function InferEquationClosure(var_to_idx, inference_graph) + InferEquationClosure(Set{SymbolicT}(), Set{SymbolicT}(), Set{ClockVertex.Type}(), + Set{ClockVertex.Type}(), Dict{Int, Set{ClockVertex.Type}}(), + var_to_idx, inference_graph) +end + +function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_equation::Bool) + (; varsbuf, arg_varsbuf, hyperedge, arg_hyperedge, relative_hyperedges) = iec + (; var_to_idx, inference_graph) = iec + empty!(varsbuf) + empty!(hyperedge) + # get variables in equation + SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}()) + # add the equation to the hyperedge + eq_node = if is_initialization_equation + ClockVertex.InitEquation(ieq) + else + ClockVertex.Equation(ieq) + end + push!(hyperedge, eq_node) + for var in varsbuf + idx = get(var_to_idx, var, nothing) + # if this is just a single variable, add it to the hyperedge + if idx isa Int + push!(hyperedge, ClockVertex.Variable(idx)) + # we don't immediately `continue` here because this variable might be a + # `Sample` or similar and we want the clock information from it if it is. + end + # now we only care about synchronous operators + op, args = @match var begin + BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args) + _ => continue + end + + # arguments and corresponding time domains + tdomains = input_timedomain(op)::Vector{InputTimeDomainElT} + nargs = length(args) + ndoms = length(tdomains) + if nargs != ndoms + throw(ArgumentError(""" + Operator $op applied to $nargs arguments $args but only returns $ndoms \ + domains $tdomains from `input_timedomain`. + """)) + end + + # each relative clock mapping is only valid per operator application + empty!(relative_hyperedges) + for (arg, domain) in zip(args, tdomains) + empty!(arg_varsbuf) + empty!(arg_hyperedge) + # get variables in argument + SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}()) + # get hyperedge for involved variables + for v in arg_varsbuf + vidx = get(var_to_idx, v, nothing) + vidx === nothing && continue + push!(arg_hyperedge, ClockVertex.Variable(vidx)) + end + + @match domain begin + # If the time domain for this argument is a clock, then all variables in this edge have that clock. + x::SciMLBase.AbstractClock => begin + # add the clock to the edge + push!(arg_hyperedge, ClockVertex.Clock(x)) + # add the edge to the graph + add_edge!(inference_graph, arg_hyperedge) + end + # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the + # involved variables have the same clock. + InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge) + # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't + # add the edge, and instead add this to the `relative_hyperedges` mapping. + InferredClock.InferredDiscrete(i) => begin + relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i) + union!(relative_edge, arg_hyperedge) + end + end + end + + outdomain = output_timedomain(op) + @match outdomain begin + x::SciMLBase.AbstractClock => begin + push!(hyperedge, ClockVertex.Clock(x)) + end + InferredClock.Inferred() => nothing + InferredClock.InferredDiscrete(i) => begin + buffer = get(relative_hyperedges, i, nothing) + if buffer !== nothing + union!(hyperedge, buffer) + delete!(relative_hyperedges, i) + end + end + end + + for (_, relative_edge) in relative_hyperedges + add_edge!(inference_graph, relative_edge) + end + end + + add_edge!(inference_graph, hyperedge) +end + """ Update the equation-to-time domain mapping by inferring the time domain from the variables. """ function infer_clocks!(ci::ClockInference) (; ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph) = ci - (; var_to_diff, graph) = ts.structure sys = get_sys(ts) fullvars = StateSelection.get_fullvars(ts) isempty(inferred) && return ci - var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars)) + var_to_idx = Dict{SymbolicT, Int}() + sizehint!(var_to_idx, length(fullvars)) + for (i, v) in enumerate(fullvars) + var_to_idx[v] = i + end # all shifted variables have the same clock as the unshifted variant for (i, v) in enumerate(fullvars) @@ -81,112 +200,8 @@ function infer_clocks!(ci::ClockInference) _ => nothing end end + infer_equation = InferEquationClosure(var_to_idx, inference_graph) - # preallocated buffers: - # variables in each equation - varsbuf = Set{SymbolicT}() - # variables in each argument to an operator - arg_varsbuf = Set{SymbolicT}() - # hyperedge for each equation - hyperedge = Set{ClockVertex.Type}() - # hyperedge for each argument to an operator - arg_hyperedge = Set{ClockVertex.Type}() - # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition - relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}() - - function infer_equation(ieq, eq, is_initialization_equation) - empty!(varsbuf) - empty!(hyperedge) - # get variables in equation - SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}()) - # add the equation to the hyperedge - eq_node = if is_initialization_equation - ClockVertex.InitEquation(ieq) - else - ClockVertex.Equation(ieq) - end - push!(hyperedge, eq_node) - for var in varsbuf - idx = get(var_to_idx, var, nothing) - # if this is just a single variable, add it to the hyperedge - if idx isa Int - push!(hyperedge, ClockVertex.Variable(idx)) - # we don't immediately `continue` here because this variable might be a - # `Sample` or similar and we want the clock information from it if it is. - end - # now we only care about synchronous operators - op, args = @match var begin - BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args) - _ => continue - end - - # arguments and corresponding time domains - tdomains = input_timedomain(op)::Vector{InputTimeDomainElT} - nargs = length(args) - ndoms = length(tdomains) - if nargs != ndoms - throw(ArgumentError(""" - Operator $op applied to $nargs arguments $args but only returns $ndoms \ - domains $tdomains from `input_timedomain`. - """)) - end - - # each relative clock mapping is only valid per operator application - empty!(relative_hyperedges) - for (arg, domain) in zip(args, tdomains) - empty!(arg_varsbuf) - empty!(arg_hyperedge) - # get variables in argument - SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}()) - # get hyperedge for involved variables - for v in arg_varsbuf - vidx = get(var_to_idx, v, nothing) - vidx === nothing && continue - push!(arg_hyperedge, ClockVertex.Variable(vidx)) - end - - @match domain begin - # If the time domain for this argument is a clock, then all variables in this edge have that clock. - x::SciMLBase.AbstractClock => begin - # add the clock to the edge - push!(arg_hyperedge, ClockVertex.Clock(x)) - # add the edge to the graph - add_edge!(inference_graph, arg_hyperedge) - end - # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the - # involved variables have the same clock. - InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge) - # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't - # add the edge, and instead add this to the `relative_hyperedges` mapping. - InferredClock.InferredDiscrete(i) => begin - relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i) - union!(relative_edge, arg_hyperedge) - end - end - end - - outdomain = output_timedomain(op) - @match outdomain begin - x::SciMLBase.AbstractClock => begin - push!(hyperedge, ClockVertex.Clock(x)) - end - InferredClock.Inferred() => nothing - InferredClock.InferredDiscrete(i) => begin - buffer = get(relative_hyperedges, i, nothing) - if buffer !== nothing - union!(hyperedge, buffer) - delete!(relative_hyperedges, i) - end - end - end - - for (_, relative_edge) in relative_hyperedges - add_edge!(inference_graph, relative_edge) - end - end - - add_edge!(inference_graph, hyperedge) - end for (ieq, eq) in enumerate(MTKBase.equations(sys)) infer_equation(ieq, eq, false) end @@ -212,7 +227,9 @@ function infer_clocks!(ci::ClockInference) """)) end - clock = partition[only(clockidxs)].:1 + clock = Moshi.Match.@match partition[only(clockidxs)] begin + ClockVertex.Clock(clk) => clk + end for vert in partition Moshi.Match.@match vert begin ClockVertex.Variable(i) => (var_domain[i] = clock) @@ -275,19 +292,15 @@ function split_system(ci::ClockInference{S}) where {S} # populates clock_to_id and id_to_clock # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) - cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, - continuous_id = continuous_id - - # Fill the clock_to_id dict as you go, - # ContinuousClock() => 1, ... - get!(clock_to_id, d) do - cid = (cid_counter[] += 1) - push!(id_to_clock, d) - if d == SciMLBase.ContinuousClock() - continuous_id[] = cid - end - cid + # We don't use `get!` here because that desperately wants to box things + cid = get(clock_to_id, d, 0) + if iszero(cid) + cid = (cid_counter[] += 1) + push!(id_to_clock, d) + if d === SciMLBase.ContinuousClock() + continuous_id[] = cid end + clock_to_id[d] = cid end eq_to_cid[i] = cid resize_or_push!(cid_to_eq, i, cid) diff --git a/lib/ModelingToolkitTearing/src/reassemble.jl b/lib/ModelingToolkitTearing/src/reassemble.jl index 1cf0131..b316ce6 100644 --- a/lib/ModelingToolkitTearing/src/reassemble.jl +++ b/lib/ModelingToolkitTearing/src/reassemble.jl @@ -1076,7 +1076,7 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState, dummy_sub = Dict{SymbolicT, SymbolicT}() if MTKBase.has_iv(state.sys) && MTKBase.get_iv(state.sys) !== nothing - iv = MTKBase.get_iv(state.sys) + iv = MTKBase.get_iv(state.sys)::SymbolicT if !StateSelection.is_only_discrete(state.structure) D = Differential(iv) else @@ -1089,29 +1089,50 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState, extra_unknowns = state.fullvars[extra_eqs_vars[2]] if StateSelection.is_only_discrete(state.structure) var_sccs = add_additional_history!( - state, var_eq_matching, full_var_eq_matching, var_sccs, iv) + state, var_eq_matching, full_var_eq_matching, var_sccs, iv::SymbolicT) end # Structural simplification - substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D) - - var_sccs = generate_derivative_variables!( - state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv) - neweqs, solved_eqs, - eq_ordering, - var_ordering, - nelim_eq, - nelim_var = generate_system_equations!( - state, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs, - analytical_linear_scc_limit) - - state = reorder_vars!( - state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var) - # var_eq_matching and full_var_eq_matching are now invalidated - - sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, - extra_unknowns, iv, D; array_hack) + if iv isa SymbolicT # Without iv we don't have derivatives + D = D::Union{Differential, Shift} + substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D) + + var_sccs = generate_derivative_variables!( + state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv) + end + if iv isa SymbolicT + D = D::Union{Differential, Shift} + neweqs, solved_eqs, + eq_ordering, + var_ordering, + nelim_eq, + nelim_var = generate_system_equations!( + state, neweqs, var_eq_matching, full_var_eq_matching, + var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs, + analytical_linear_scc_limit) + state = reorder_vars!( + state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var) + # var_eq_matching and full_var_eq_matching are now invalidated + + sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, + extra_unknowns, iv, D; array_hack) + else + D = D::Nothing + neweqs, solved_eqs, + eq_ordering, + var_ordering, + nelim_eq, + nelim_var = generate_system_equations!( + state, neweqs, var_eq_matching, full_var_eq_matching, + var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs, + analytical_linear_scc_limit) + state = reorder_vars!( + state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var) + # var_eq_matching and full_var_eq_matching are now invalidated + + sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, + extra_unknowns, iv, D; array_hack) + end @set! state.sys = sys @set! sys.tearing_state = state diff --git a/lib/ModelingToolkitTearing/src/stateselection_interface.jl b/lib/ModelingToolkitTearing/src/stateselection_interface.jl index bab72ef..60385b0 100644 --- a/lib/ModelingToolkitTearing/src/stateselection_interface.jl +++ b/lib/ModelingToolkitTearing/src/stateselection_interface.jl @@ -134,8 +134,9 @@ function StateSelection.find_eq_solvables!(state::TearingState, ieq, to_rm = Int # When the expression is linear with numeric `a`, then we can safely # only consider `b` for the following iterations. term = b - if SU._isone(abs(unwrap_const(a))) - coeffs === nothing || push!(coeffs, convert(Int, unwrap_const(a))) + a_is_one = SU._isone(a) + if a_is_one || manual_dispatch_isabsone(unwrap_const(a)) + coeffs === nothing || push!(coeffs, a_is_one ? 1 : -1) else all_int_vars = false conservative && continue @@ -156,3 +157,22 @@ function StateSelection.find_eq_solvables!(state::TearingState, ieq, to_rm = Int all_int_vars, term end +function manual_dispatch_isabsone(@nospecialize(x)) + if x isa Int + return isone(abs(x)) + elseif x isa BigInt + return isone(abs(x)) + elseif x isa Float64 + return isone(abs(x)) + elseif x isa Float32 + return isone(abs(x)) + elseif x isa BigFloat + return isone(abs(x)) + elseif x isa Rational{Int} + return isone(abs(x)) + elseif x isa Rational{BigInt} + return isone(abs(x)) + else + return isone(abs(x))::Bool + end +end diff --git a/lib/ModelingToolkitTearing/src/tearingstate.jl b/lib/ModelingToolkitTearing/src/tearingstate.jl index b119233..f677a71 100644 --- a/lib/ModelingToolkitTearing/src/tearingstate.jl +++ b/lib/ModelingToolkitTearing/src/tearingstate.jl @@ -89,8 +89,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) sys = MTKBase.discrete_unknowns_to_parameters(sys) sys = MTKBase.discover_globalscoped(sys) MTKBase.check_no_parameter_equations(sys) - ivs = independent_variables(sys) - iv = length(ivs) == 1 ? ivs[1] : nothing + iv = MTKBase.get_iv(sys) # flatten array equations eqs = MTKBase.flatten_equations(equations(sys)) original_eqs = copy(eqs) @@ -134,7 +133,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) # TODO: Can we handle this without `isparameter`? if v in ps - if is_time_dependent_parameter(v, ps, iv) && + if iv isa SymbolicT && is_time_dependent_parameter(v, ps, iv) && !haskey(param_derivative_map, Differential(iv)(v)) && !(Differential(iv)(v) in no_deriv_params) # Parameter derivatives default to zero - they stay constant # between callbacks @@ -143,13 +142,10 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) continue end - isequal(v, iv) && continue - MTKBase.isdelay(v, iv) && continue + iv isa SymbolicT && isequal(v, iv) && continue + iv isa SymbolicT && MTKBase.isdelay(v, iv) && continue if !in(v, dvs) - isvalid = iscall(v) && - (operation(v) isa Shift || isempty(arguments(v)) || - is_transparent_operator(operation(v))) isvalid = @match v begin BSImpl.Term(; f, args) => f isa Shift || isempty(args) || f isa SU.Operator && is_transparent_operator(f)::Bool _ => false @@ -180,9 +176,11 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) addvar!(v, VARIABLE) @match v begin BSImpl.Term(; f, args) && if f isa SU.Operator && - !(f isa Differential) && (it = input_timedomain(v)::Vector{InputTimeDomainElT}) !== nothing + !(f isa Differential) end => begin - for (v′, td) in zip(args, it) + it = input_timedomain(v)::Vector{InputTimeDomainElT} + for (i, td) in enumerate(it) + v′ = args[i] addvar!(setmetadata(v′, MTKBase.VariableTimeDomain, td), VARIABLE) end end @@ -209,8 +207,8 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true) addvar!(vi, VARIABLE) end else - vv = collect(v) - union!(incidence, vv)::Array{SymbolicT} + vv = collect(v)::Array{SymbolicT} + union!(incidence, vv) for vi in vv addvar!(vi, VARIABLE) end diff --git a/lib/ModelingToolkitTearing/src/utils.jl b/lib/ModelingToolkitTearing/src/utils.jl index a455f47..8c86347 100644 --- a/lib/ModelingToolkitTearing/src/utils.jl +++ b/lib/ModelingToolkitTearing/src/utils.jl @@ -31,10 +31,50 @@ function descend_lower_shift_varname(var, iv) end end -function is_time_dependent_parameter(p, allps, iv) - return iv !== nothing && p in allps && iscall(p) && - (operation(p) === getindex && - is_time_dependent_parameter(arguments(p)[1], allps, iv) || - (args = arguments(p); length(args)) == 1 && isequal(only(args), iv)) +function is_time_dependent_parameter(p::SymbolicT, allps::Set{SymbolicT}, iv::SymbolicT) + return p in allps && @match p begin + BSImpl.Term(; f, args) => begin + farg = args[1] + f === getindex && is_time_dependent_parameter(farg, allps, iv) || + length(args) == 1 && isequal(farg, iv) + end + _ => false + end end +const UNION_SPLIT_VAR_FIRST_ERR = """ +The first argument to `@union_split_var` must be of the form `var::Union{T1, T2}` where \ +`var` is a single variable (not an expression). +""" + +""" + @union_split_var var::Union{T1, T2} begin; #= ... =#; end + +Manually dispatch the `begin..end` block based on the given type-annotation for `var`. +`var` cannot be an expression. +""" +macro union_split_var(annotated_var::Expr, block::Expr) + @assert Meta.isexpr(annotated_var, :(::)) UNION_SPLIT_VAR_FIRST_ERR + @assert length(annotated_var.args) == 2 UNION_SPLIT_VAR_FIRST_ERR + var, type = annotated_var.args + @assert var isa Symbol UNION_SPLIT_VAR_FIRST_ERR + var = var::Symbol + @assert Meta.isexpr(type, :curly) UNION_SPLIT_VAR_FIRST_ERR + @assert type[1] == :Union UNION_SPLIT_VAR_FIRST_ERR + + variants = @view type.args[2:end] + N = length(variants) + expr = Expr(:if) + cur_expr = expr + for (i, variant) in enumerate(variants) + push!(cur_expr.args, :($var isa $variant)) + push!(cur_expr.args, block) + i == N && continue + new_expr = Expr(:elseif) + push!(cur_expr.args, new_expr) + cur_expr = new_expr + end + push!(cur_expr.args, :(error("Unexpected type $(typeof($var)) for variable $($var)"))) + + return esc(expr) +end diff --git a/src/StateSelection.jl b/src/StateSelection.jl index 6f7bcc0..66a750c 100644 --- a/src/StateSelection.jl +++ b/src/StateSelection.jl @@ -2,7 +2,6 @@ module StateSelection using DocStringExtensions using Setfield: @set!, @set -using UnPack: @unpack using Graphs import SparseArrays import OrderedCollections: OrderedSet diff --git a/src/debug.jl b/src/debug.jl index d9b5a78..93caa00 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -98,7 +98,7 @@ end Base.show(io::IO, inc::IncidenceMarker) = print(io, inc.active ? "x" : " ") function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure) - @unpack graph, solvable_graph, var_to_diff, eq_to_diff = s + (; graph, solvable_graph, var_to_diff, eq_to_diff) = s if !get(io, :limit, true) || !get(io, :mtk_limit, true) print(io, "SystemStructure with ", length(s.graph.fadjlist), " equations and ", isa(s.graph.badjlist, Int) ? s.graph.badjlist : length(s.graph.badjlist), @@ -139,7 +139,7 @@ end function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) s = ms.structure - @unpack graph, solvable_graph, var_to_diff, eq_to_diff = s + (; graph, solvable_graph, var_to_diff, eq_to_diff) = s print(io, "Matched SystemStructure with ", length(graph.fadjlist), " equations and ", isa(graph.badjlist, Int) ? graph.badjlist : length(graph.badjlist), " variables\n") diff --git a/src/modia_tearing.jl b/src/modia_tearing.jl index aa6b17f..f49a076 100644 --- a/src/modia_tearing.jl +++ b/src/modia_tearing.jl @@ -91,7 +91,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, # to have optimal solutions that cannot be found by this process. We will not # find them here [TODO: It would be good to have an explicit example of this.] - @unpack graph, solvable_graph = structure + (; graph, solvable_graph) = structure var_eq_matching = maximal_matching(graph, U, srcfilter=eqfilter, dstfilter=varfilter) @@ -101,7 +101,9 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, full_var_eq_matching = copy(var_eq_matching) var_sccs = find_var_sccs(graph, var_eq_matching) vargraph = DiCMOBiGraph{true}(graph) - ict = IncrementalCycleTracker(vargraph; dir = :in) + # Inlining `IncrementalCycleTracker(vargraph; dir = :in)` because calling it + # directly doesn't infer. + ict = Graphs.DenseGraphICT_BFGT_N{:in}(vargraph) ieqs = Int[] filtered_vars = BitSet() diff --git a/src/pantelides.jl b/src/pantelides.jl index 0e9c21f..0d7e0f1 100644 --- a/src/pantelides.jl +++ b/src/pantelides.jl @@ -19,7 +19,7 @@ for every variable, indicating whether it is considered "highest-differentiated" determines whether it should be included in the list. """ function computed_highest_diff_variables(structure::SystemStructure, varfilter) - @unpack graph, var_to_diff = structure + (; graph, var_to_diff) = structure nvars = length(var_to_diff) varwhitelist = falses(nvars) @@ -67,7 +67,7 @@ _canchoose(diffvars::Nothing, var::Integer) = true Perform Pantelides algorithm. """ function pantelides!(state::TransformationState; finalize = true, maxiters = 8000, eqfilter = eq->true, varfilter = var->true, _...) - @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure + (; graph, solvable_graph, var_to_diff, eq_to_diff) = state.structure neqs = nsrcs(graph) nvars = nv(var_to_diff) vcolor = falses(nvars) diff --git a/src/partial_state_selection.jl b/src/partial_state_selection.jl index 0ac5764..2ee1ce3 100644 --- a/src/partial_state_selection.jl +++ b/src/partial_state_selection.jl @@ -35,7 +35,7 @@ struct DiffData end function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, diff_data::Union{Nothing, DiffData}=nothing) - @unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure + (; eq_to_diff, var_to_diff, graph, solvable_graph) = structure # var_eq_matching is a maximal matching on the top-differentiated variables. # Find Strongly connected components. Note that after pantelides, we expect @@ -134,7 +134,7 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, diff end function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching) - @unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure + (; eq_to_diff, var_to_diff, graph, solvable_graph) = structure eq_to_diff = complete(eq_to_diff) inv_eqlevel = map(1:nsrcs(graph)) do eq @@ -204,7 +204,7 @@ function dummy_derivative_graph!( structure::SystemStructure, var_eq_matching, jac = nothing, state_priority = nothing, ::Val{log} = Val(false); tearing_alg::TearingAlgorithm = DummyDerivativeTearing(), kwargs...) where {log} - @unpack eq_to_diff, var_to_diff, graph = structure + (; eq_to_diff, var_to_diff, graph) = structure diff_to_eq = invview(eq_to_diff) diff_to_var = invview(var_to_diff) invgraph = invview(graph) @@ -368,7 +368,7 @@ function dummy_derivative_graph!( end function is_present(structure, v)::Bool - @unpack var_to_diff, graph = structure + (; var_to_diff, graph) = structure while true # if a higher derivative is present, then it's present isempty(𝑑neighbors(graph, v)) || return true @@ -386,13 +386,13 @@ end # We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with # actually differentiated variables. function isdiffed((structure, dummy_derivatives), v)::Bool - @unpack var_to_diff, graph = structure + (; var_to_diff, graph) = structure diff_to_var = invview(var_to_diff) diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v) end function tearing_with_dummy_derivatives(structure, dummy_derivatives) - @unpack var_to_diff = structure + (; var_to_diff) = structure # We can eliminate variables that are not selected (differential # variables). Selected unknowns are differentiated variables that are not # dummy derivatives. @@ -419,7 +419,7 @@ end struct DummyDerivativeTearing <: TearingAlgorithm end function (::DummyDerivativeTearing)(structure::SystemStructure, dummy_derivatives::Union{BitSet, Tuple{}} = ()) - @unpack var_to_diff = structure + (; var_to_diff) = structure # We can eliminate variables that are not selected (differential # variables). Selected unknowns are differentiated variables that are not # dummy derivatives. diff --git a/src/singularity_removal.jl b/src/singularity_removal.jl index 04546c1..48b759b 100644 --- a/src/singularity_removal.jl +++ b/src/singularity_removal.jl @@ -22,7 +22,7 @@ function structural_singularity_removal!(state::TransformationState; return mm # No linear subsystems end - @unpack graph, var_to_diff, solvable_graph = state.structure + (; graph, var_to_diff, solvable_graph) = state.structure mm = structural_singularity_removal!(state, mm; variable_underconstrained!) s = state.structure for (ei, e) in enumerate(mm.nzrows) @@ -169,7 +169,7 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible end function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti} - @unpack graph, var_to_diff = structure + (; graph, var_to_diff) = structure mm = copy(mm_orig) linear_equations_set = BitSet(mm_orig.nzrows) @@ -259,7 +259,7 @@ function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff) end function force_var_to_zero!(structure::SystemStructure, ils::SparseMatrixCLIL, v::Int) - @unpack graph, solvable_graph, eq_to_diff = structure + (; graph, solvable_graph, eq_to_diff) = structure @set! ils.nparentrows += 1 push!(ils.nzrows, ils.nparentrows) push!(ils.row_cols, [v]) @@ -274,8 +274,8 @@ end function structural_singularity_removal!(state::TransformationState, ils::SparseMatrixCLIL; variable_underconstrained! = force_var_to_zero!) - @unpack structure = state - @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure + (; structure) = state + (; graph, solvable_graph, var_to_diff, eq_to_diff) = state.structure # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear # subsystem of the system we're interested in. # diff --git a/src/utils.jl b/src/utils.jl index 416d49f..40fda28 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -112,7 +112,7 @@ instead of throwing an error. The singular case will print a warning. """ function check_consistency(state::TransformationState, orig_inputs; nothrow = false) neqs = n_concrete_eqs(state) - @unpack graph, var_to_diff = state.structure + (; graph, var_to_diff) = state.structure highest_vars = computed_highest_diff_variables(complete!(state.structure)) n_highest_vars = 0 for (v, h) in enumerate(highest_vars)