Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions ext/MTKFMIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,

# instance management callback which deallocates the instance when
# necessary and notifies the FMU of completed integrator steps
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
step_affect = MTK.FunctionalAffect(Returns(nothing), [], [], [])
finalize_affect = MTK.ImperativeAffect(fmiFinalize!; observed = (; wrapper))
step_affect = MTK.ImperativeAffect(Returns((;)))
instance_management_callback = MTK.SymbolicDiscreteCallback(
(t == t - 1), step_affect; finalize = finalize_affect, reinitializealg = SciMLBase.NoInit())

Expand Down Expand Up @@ -273,7 +273,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
end
initialize_affect = MTK.ImperativeAffect(fmiCSInitialize!; observed = cb_observed,
modified = cb_modified, ctx = _functor)
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
finalize_affect = MTK.ImperativeAffect(fmiFinalize!; observed = (; wrapper))
# the callback affect performs the stepping
step_affect = MTK.ImperativeAffect(
fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor)
Expand Down Expand Up @@ -708,15 +708,15 @@ end
"""
$(TYPEDSIGNATURES)

An affect function for use inside a `FunctionalAffect`. This should be triggered at the
An affect function for use inside an `ImperativeAffect`. This should be triggered at the
end of the solve, regardless of whether it succeeded or failed. Expects `p` to be a
1-length array containing the index of the instance wrapper (`FMI2InstanceWrapper` or
`FMI3InstanceWrapper`) in the parameter object.
"""
function fmiFinalize!(integrator, u, p, ctx)
wrapper_idx = p[1]
wrapper = integrator.ps[wrapper_idx]
function fmiFinalize!(m, o, ctx, integrator)
wrapper = o.wrapper
reset_instance!(wrapper)
return (;)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true
end
(all(values(outputset)) || error(
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
outputset))
outputset))
end
state, orig_inputs
end
Expand Down
112 changes: 12 additions & 100 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,7 @@
abstract type AbstractCallback end

struct FunctionalAffect
f::Any
sts::Vector
sts_syms::Vector{Symbol}
pars::Vector
pars_syms::Vector{Symbol}
discretes::Vector
ctx::Any
end

function FunctionalAffect(f, sts, pars, discretes, ctx = nothing)
# sts & pars contain either pairs: resistor.R => R, or Syms: R
vs = [x isa Pair ? x.first : x for x in sts]
vs_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
length(vs_syms) == length(unique(vs_syms)) || error("Variables are not unique")

ps = [x isa Pair ? x.first : x for x in pars]
ps_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique")

FunctionalAffect(f, vs, vs_syms, ps, ps_syms, discretes, ctx)
end

function FunctionalAffect(; f, sts, pars, discretes, ctx = nothing)
FunctionalAffect(f, sts, pars, discretes, ctx)
end

func(a::FunctionalAffect) = a.f
context(a::FunctionalAffect) = a.ctx
parameters(a::FunctionalAffect) = a.pars
parameters_syms(a::FunctionalAffect) = a.pars_syms
unknowns(a::FunctionalAffect) = a.sts
unknowns_syms(a::FunctionalAffect) = a.sts_syms
discretes(a::FunctionalAffect) = a.discretes

function Base.:(==)(a1::FunctionalAffect, a2::FunctionalAffect)
isequal(a1.f, a2.f) && isequal(a1.sts, a2.sts) && isequal(a1.pars, a2.pars) &&
isequal(a1.sts_syms, a2.sts_syms) && isequal(a1.pars_syms, a2.pars_syms) &&
isequal(a1.ctx, a2.ctx)
end

function Base.hash(a::FunctionalAffect, s::UInt)
s = hash(a.f, s)
s = hash(a.sts, s)
s = hash(a.sts_syms, s)
s = hash(a.pars, s)
s = hash(a.pars_syms, s)
s = hash(a.discretes, s)
hash(a.ctx, s)
end

function has_functional_affect(cb)
(affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect)
affects(cb) isa ImperativeAffect
end

struct AffectSystem
Expand Down Expand Up @@ -97,7 +46,7 @@ function Base.hash(a::AffectSystem, s::UInt)
hash(aff_to_sys(a), s)
end

function vars!(vars, aff::Union{FunctionalAffect, AffectSystem}; op = Differential)
function vars!(vars, aff::AffectSystem; op = Differential)
for var in Iterators.flatten((unknowns(aff), parameters(aff), discretes(aff)))
vars!(vars, var)
end
Expand Down Expand Up @@ -161,7 +110,7 @@ end
###############################
###### Continuous events ######
###############################
const Affect = Union{AffectSystem, FunctionalAffect, ImperativeAffect}
const Affect = Union{AffectSystem, ImperativeAffect}

"""
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = nothing, iv = nothing;
Expand Down Expand Up @@ -233,7 +182,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

if isnothing(reinitializealg)
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
if any(a -> a isa ImperativeAffect,
[affect, affect_neg, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
Expand Down Expand Up @@ -263,8 +212,8 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
end

make_affect(affect::Nothing; kwargs...) = nothing
make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
make_affect(affect::Affect; kwargs...) = affect

function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
Expand Down Expand Up @@ -446,7 +395,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
c = is_timed_condition(condition) ? condition : value(scalarize(condition))

if isnothing(reinitializealg)
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
if any(a -> a isa ImperativeAffect,
[affect, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
Expand Down Expand Up @@ -498,16 +447,6 @@ end
############################################
########## Namespacing Utilities ###########
############################################
function namespace_affects(affect::FunctionalAffect, s)
FunctionalAffect(func(affect),
renamespace.((s,), unknowns(affect)),
unknowns_syms(affect),
renamespace.((s,), parameters(affect)),
parameters_syms(affect),
renamespace.((s,), discretes(affect)),
context(affect))
end

function namespace_affects(affect::AffectSystem, s)
AffectSystem(renamespace(s, system(affect)),
renamespace.((s,), unknowns(affect)),
Expand Down Expand Up @@ -652,36 +591,6 @@ function compile_condition(
return CompiledCondition{is_discrete(cbs)}(fs)
end

"""
Compile user-defined functional affect.
"""
function compile_functional_affect(affect::FunctionalAffect, sys; kwargs...)
dvs = unknowns(sys)
ps = parameters(sys)
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
p_inds = [(pind = parameter_index(sys, sym)) === nothing ? sym : pind
for sym in parameters(affect)]
else
ps_ind = Dict(reverse(en) for en in enumerate(ps))
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
end
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>
NamedTuple
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
NamedTuple

let u = u, p = p, user_affect = func(affect), ctx = context(affect)
(integ) -> begin
user_affect(integ, u, p, ctx)
end
end
end

is_discrete(cb::AbstractCallback) = cb isa SymbolicDiscreteCallback
is_discrete(cb::Vector{<:AbstractCallback}) = eltype(cb) isa SymbolicDiscreteCallback

Expand Down Expand Up @@ -837,7 +746,7 @@ function compile_affect(
elseif aff isa AffectSystem
f = compile_equational_affect(aff, sys; kwargs...)
wrap_save_discretes(f, save_idxs)
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
elseif aff isa ImperativeAffect
f = compile_functional_affect(aff, sys; kwargs...)
wrap_save_discretes(f, save_idxs)
end
Expand Down Expand Up @@ -946,7 +855,8 @@ function compile_equational_affect(
end
else
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map,
affsys = affsys, ps_to_update = ps_to_update, aff = aff, sys = sys
affsys = affsys, ps_to_update = ps_to_update, aff = aff, sys = sys,
reset_jumps = reset_jumps

dvs_to_access = [aff_map[u] for u in unknowns(affsys)]
ps_to_access = [unPre(p) for p in parameters(affsys)]
Expand Down Expand Up @@ -979,6 +889,8 @@ function compile_equational_affect(

u_setter!(integ, u_getter(affsol))
p_setter!(integ, p_getter(affsol))

reset_jumps && reset_aggregated_jumps!(integ)
end
end
end
Expand Down
12 changes: 9 additions & 3 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Where we use Setfield to copy the tuple `m` with a new value for `x`, then retur
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
in the returned tuple, in which case the associated field will not be updated.
"""
@kwdef struct ImperativeAffect
struct ImperativeAffect
f::Any
obs::Vector
obs_syms::Vector{Symbol}
Expand Down Expand Up @@ -63,6 +63,9 @@ function ImperativeAffect(
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
function ImperativeAffect(; f, kwargs...)
ImperativeAffect(f; kwargs...)
end

function Base.show(io::IO, mfa::ImperativeAffect)
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
Expand Down Expand Up @@ -164,7 +167,8 @@ function check_assignable(sys, sym)
end
end

function compile_functional_affect(affect::ImperativeAffect, sys; kwargs...)
function compile_functional_affect(
affect::ImperativeAffect, sys; reset_jumps = false, kwargs...)
#=
Implementation sketch:
generate observed function (oop), should save to a component array under obs_syms
Expand Down Expand Up @@ -244,7 +248,7 @@ function compile_functional_affect(affect::ImperativeAffect, sys; kwargs...)

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

let user_affect = func(affect), ctx = context(affect)
let user_affect = func(affect), ctx = context(affect), reset_jumps = reset_jumps
@inline function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
Expand All @@ -259,6 +263,8 @@ function compile_functional_affect(affect::ImperativeAffect, sys; kwargs...)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)

reset_jumps && reset_aggregated_jumps!(integ)
end
end
end
Expand Down
3 changes: 1 addition & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ function IndexCache(sys::AbstractSystem)
affs = [affs]
end
for affect in affs
if affect isa AffectSystem || affect isa FunctionalAffect ||
affect isa ImperativeAffect
if affect isa AffectSystem || affect isa ImperativeAffect
union!(discs, unwrap.(discretes(affect)))
elseif isnothing(affect)
continue
Expand Down
Loading
Loading