Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
collect_vars!(dvs, params, eq, iv)
end
pre_params = filter(haspre ∘ value, params)
discrete_parameters = gather_array_params(OrderedSet(discrete_parameters))
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
discrete_parameters = collect(discrete_parameters)
discretes = map(tovar, discrete_parameters)
dvs = collect(dvs)
_dvs = map(default_toterm, dvs)
Expand Down Expand Up @@ -904,7 +906,12 @@ function compile_equational_affect(
obseqs, Dict([p => unPre(p) for p in parameters(affsys)]))
rhss = map(x -> x.rhs, update_eqs)
lhss = map(x -> x.lhs, update_eqs)
is_p = [lhs in Set(ps_to_update) for lhs in lhss]
update_ps_set = Set(ps_to_update)
is_p = map(lhss) do lhs
lhs in update_ps_set ||
iscall(lhs) && operation(lhs) === getindex &&
arguments(lhs)[1] in update_ps_set
end
is_u = [lhs in Set(dvs_to_update) for lhs in lhss]
dvs = unknowns(sys)
ps = parameters(sys)
Expand Down
10 changes: 8 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,14 @@ function IndexCache(sys::AbstractSystem)
end

for sym in discs
is_parameter(sys, sym) ||
error("Expected discrete variable $sym in callback to be a parameter")
if !is_parameter(sys, sym)
if iscall(sym) && operation(sym) === getindex &&
is_parameter(sys, arguments(sym)[1])
sym = arguments(sym)[1]
else
error("Expected discrete variable $sym in callback to be a parameter")
end
end

# Only `foo(t)`-esque parameters can be saved
if iscall(sym) && length(arguments(sym)) == 1 &&
Expand Down
18 changes: 15 additions & 3 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,10 @@ function Base.getindex(ngi::NestedGetIndex, idx::Tuple)
i, j, k... = idx
return ngi.x[i][j][k...]
end
function Base.getindex(ngi::NestedGetIndex, idx::NTuple{2})
i, j = idx
return ngi.x[i][j]
end

# Required for DiffEqArray constructor to work during interpolation
Base.size(::NestedGetIndex) = ()
Expand All @@ -826,19 +830,27 @@ function SciMLBase.create_parameter_timeseries_collection(
isempty(ps.discrete) && return nothing
num_discretes = only(blocksize(ps.discrete[1]))
buffers = []
partition_type = Tuple{(typeof(parent(buf)) for buf in ps.discrete)...}
partition_type = typeof(SciMLBase.get_saveable_values(sys, ps, 1))
for i in 1:num_discretes
ts = eltype(tspan)[]
us = NestedGetIndex{partition_type}[]
us = partition_type[]
push!(buffers, DiffEqArray(us, ts, (1, 1)))
end

return ParameterTimeseriesCollection(Tuple(buffers), copy(ps))
end

@inline __get_blocks(tsidx::Int) = ()
@inline function __get_blocks(tsidx::Int, buffer::BlockedArray, buffers...)
(buffer[Block(tsidx)], __get_blocks(tsidx, buffers...)...)
end
@inline function __get_blocks(tsidx::Int, buffer::BlockedArray{<:AbstractArray}, buffers...)
(copy.(buffer[Block(tsidx)]), __get_blocks(tsidx, buffers...)...)
end

function SciMLBase.get_saveable_values(
sys::AbstractSystem, ps::MTKParameters, timeseries_idx)
return NestedGetIndex(Tuple(buffer[Block(timeseries_idx)] for buffer in ps.discrete))
return NestedGetIndex(__get_blocks(timeseries_idx, ps.discrete...))
end

function save_callback_discretes!(integ::SciMLBase.DEIntegrator, callback)
Expand Down
80 changes: 80 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1455,3 +1455,83 @@ end
Pre(X) +
10.0])
end

@testset "Issue#3990: Scalarized array passed to `discrete_parameters` of symbolic affect" begin
N = 2
@parameters v(t)[1:N]
@parameters M(t)[1:N, 1:N]

@variables x(t)

Mini = rand(N, N) ./ (N^2)
vini = vec(sum(Mini, dims = 1))

v_eq = [D(x) ~ x * Symbolics.scalarize(sum(v))]
M_eq = [D(x) ~ x * Symbolics.scalarize(sum(M))]

v_event = ModelingToolkit.SymbolicDiscreteCallback(
1.0,
[v ~ -Pre(v)],
discrete_parameters = [v]
)

M_event = ModelingToolkit.SymbolicDiscreteCallback(
1.0,
[M ~ -Pre(M)],
discrete_parameters = [M]
)

@mtkcompile v_sys = System(v_eq, t; discrete_events = v_event)
@mtkcompile M_sys = System(M_eq, t; discrete_events = M_event)

u0p0_map = Dict(x => 1.0, M => Mini, v => vini)

v_prob = ODEProblem(v_sys, u0p0_map, (0.0, 2.5))
M_prob = ODEProblem(M_sys, u0p0_map, (0.0, 2.5))

v_sol = solve(v_prob, Tsit5())
M_sol = solve(M_prob, Tsit5())

@test v_sol[v] ≈ [vini, -vini, vini]
@test M_sol[M] ≈ [Mini, -Mini, Mini]
end

@testset "Issue#3990: Scalarized array passed to `discrete_parameters` of symbolic affect" begin
N = 2
@parameters v(t)[1:N]
@parameters M(t)[1:N, 1:N]

@variables x(t)

Mini = rand(N, N) ./ (N^2)
vini = vec(sum(Mini, dims = 1))

v_eq = [D(x) ~ x * Symbolics.scalarize(sum(v))]
M_eq = [D(x) ~ x * Symbolics.scalarize(sum(M))]

v_event = ModelingToolkit.SymbolicDiscreteCallback(
1.0,
[v ~ -Pre(v)],
discrete_parameters = collect(v)
)

M_event = ModelingToolkit.SymbolicDiscreteCallback(
1.0,
[M ~ -Pre(M)],
discrete_parameters = vec(collect(M))
)

@mtkcompile v_sys = System(v_eq, t; discrete_events = v_event)
@mtkcompile M_sys = System(M_eq, t; discrete_events = M_event)

u0p0_map = Dict(x => 1.0, M => Mini, v => vini)

v_prob = ODEProblem(v_sys, u0p0_map, (0.0, 2.5))
M_prob = ODEProblem(M_sys, u0p0_map, (0.0, 2.5))

v_sol = solve(v_prob, Tsit5())
M_sol = solve(M_prob, Tsit5())

@test v_sol[v] ≈ [vini, -vini, vini]
@test M_sol[M] ≈ [Mini, -Mini, Mini]
end
Loading