From 894c135f797acb8384408bc3e4b9b23231adef90 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 8 Oct 2024 15:14:04 +0530 Subject: [PATCH 1/4] fix: detect observed variables and dependent parameters dependent on discrete parameters --- src/systems/abstractsystem.jl | 13 ++++-- src/systems/diffeqs/odesystem.jl | 11 ++++- src/systems/index_cache.jl | 76 ++++++++++++++++++++------------ test/odesystem.jl | 11 +++++ 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f396a9830a..15fc54e9e9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -731,7 +731,7 @@ end function has_observed_with_lhs(sys, sym) has_observed(sys) || return false if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return any(isequal(sym), ic.observed_syms) + return haskey(ic.observed_syms_to_timeseries, sym) else return any(isequal(sym), [eq.lhs for eq in observed(sys)]) end @@ -740,7 +740,7 @@ end function has_parameter_dependency_with_lhs(sys, sym) has_parameter_dependencies(sys) || return false if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return any(isequal(sym), ic.dependent_pars) + return haskey(ic.dependent_pars_to_timeseries, unwrap(sym)) else return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)]) end @@ -762,11 +762,16 @@ for traitT in [ allsyms = vars(sym; op = Symbolics.Operator) for s in allsyms s = unwrap(s) - if is_variable(sys, s) || is_independent_variable(sys, s) || - has_observed_with_lhs(sys, s) + if is_variable(sys, s) || is_independent_variable(sys, s) push!(ts_idxs, ContinuousTimeseries()) elseif is_timeseries_parameter(sys, s) push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing + union!(ts_idxs, ts) + elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing + union!(ts_idxs, ts) + end end end end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 5ad275906d..b8dee2bac7 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -490,7 +490,16 @@ function build_explicit_observed_function(sys, ts; ivs = independent_variables(sys) dep_vars = scalarize(setdiff(vars, ivs)) - obs = param_only ? Equation[] : observed(sys) + obs = observed(sys) + if param_only + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + obs = filter(obs) do eq + !(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs]) + end + else + obs = Equation[] + end + end cs = collect_constants(obs) if !isempty(cs) > 0 diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 00f7837407..7a376aa1f2 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{ BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} const TunableIndexMap = Dict{BasicSymbolic, Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} +const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} struct IndexCache unknown_idx::UnknownIndexMap @@ -48,8 +49,9 @@ struct IndexCache tunable_idx::TunableIndexMap constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap - observed_syms::Set{BasicSymbolic} - dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}} + observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} + dependent_pars_to_timeseries::Dict{ + Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate constant_buffer_sizes::Vector{BufferTemplate} @@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem) end end - observed_syms = Set{BasicSymbolic}() - for eq in observed(sys) - if symbolic_type(eq.lhs) != NotSymbolic() - sym = eq.lhs - ttsym = default_toterm(sym) - rsym = renamespace(sys, sym) - rttsym = renamespace(sys, ttsym) - push!(observed_syms, sym) - push!(observed_syms, ttsym) - push!(observed_syms, rsym) - push!(observed_syms, rttsym) - end - end - tunable_buffers = Dict{Any, Set{BasicSymbolic}}() constant_buffers = Dict{Any, Set{BasicSymbolic}}() nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}() @@ -267,29 +255,59 @@ function IndexCache(sys::AbstractSystem) end end - for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), - keys(const_idxs), keys(nonnumeric_idxs), - observed_syms, independent_variable_symbols(sys))) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = sym - end - end - - dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}() + dependent_pars_to_timeseries = Dict{ + Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}() for eq in parameter_dependencies(sys) sym = eq.lhs + vs = vars(eq.rhs) + timeseries = TimeseriesSetType() + for v in vs + if (idx = get(disc_idxs, v, nothing)) !== nothing + push!(timeseries, idx.clock_idx) + end + end ttsym = default_toterm(sym) rsym = renamespace(sys, sym) rttsym = renamespace(sys, ttsym) - for s in [sym, ttsym, rsym, rttsym] - push!(dependent_pars, s) + for s in (sym, ttsym, rsym, rttsym) + dependent_pars_to_timeseries[s] = timeseries if hasname(s) && (!iscall(s) || operation(s) != getindex) symbol_to_variable[getname(s)] = sym end end end + observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}() + for eq in observed(sys) + if symbolic_type(eq.lhs) != NotSymbolic() + sym = eq.lhs + vs = vars(eq.rhs) + timeseries = TimeseriesSetType() + for v in vs + if (idx = get(disc_idxs, v, nothing)) !== nothing + push!(timeseries, idx.clock_idx) + elseif haskey(unk_idxs, v) || haskey(observed_syms_to_timeseries, v) + push!(timeseries, ContinuousTimeseries()) + end + end + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + for s in (sym, ttsym, rsym, rttsym) + observed_syms_to_timeseries[s] = timeseries + end + end + end + + for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), + keys(const_idxs), keys(nonnumeric_idxs), + keys(observed_syms_to_timeseries), independent_variable_symbols(sys))) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end + return IndexCache( unk_idxs, disc_idxs, @@ -297,8 +315,8 @@ function IndexCache(sys::AbstractSystem) tunable_idxs, const_idxs, nonnumeric_idxs, - observed_syms, - dependent_pars, + observed_syms_to_timeseries, + dependent_pars_to_timeseries, disc_buffer_templates, BufferTemplate(Real, tunable_buffer_size), const_buffer_sizes, diff --git a/test/odesystem.jl b/test/odesystem.jl index d7a2658f63..fb185fdcb0 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1504,3 +1504,14 @@ end sys2 = complete(sys; split = false) @test ModelingToolkit.get_index_cache(sys2) === nothing end + +# https://github.com/SciML/SciMLBase.jl/issues/786 +@testset "Observed variables dependent on discrete parameters" begin + @variables x(t) obs(t) + @parameters c(t) + @mtkbuild sys = ODESystem( + [D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]]) + prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0]) + sol = solve(prob, Tsit5()) + @test sol[obs] ≈ 1:7 +end From bcd0edb1de1790c63200cd6f4229ebf0c183fdac Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 28 Oct 2024 18:04:21 +0530 Subject: [PATCH 2/4] fix: fix timeseries detection for shifted variables in DDEs --- src/systems/abstractsystem.jl | 4 ++++ src/systems/index_cache.jl | 27 +++++++++++++++++++-------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 15fc54e9e9..db2fcb4f10 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -766,6 +766,10 @@ for traitT in [ push!(ts_idxs, ContinuousTimeseries()) elseif is_timeseries_parameter(sys, s) push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + elseif is_time_dependent(sys) && iscall(s) && issym(operation(s)) && + is_variable(sys, operation(s)(get_iv(sys))) + # DDEs case, to detect x(t - k) + push!(ts_idxs, ContinuousTimeseries()) elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing union!(ts_idxs, ts) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 7a376aa1f2..307ac71a56 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -262,9 +262,11 @@ function IndexCache(sys::AbstractSystem) sym = eq.lhs vs = vars(eq.rhs) timeseries = TimeseriesSetType() - for v in vs - if (idx = get(disc_idxs, v, nothing)) !== nothing - push!(timeseries, idx.clock_idx) + if is_time_dependent(sys) + for v in vs + if (idx = get(disc_idxs, v, nothing)) !== nothing + push!(timeseries, idx.clock_idx) + end end end ttsym = default_toterm(sym) @@ -284,11 +286,20 @@ function IndexCache(sys::AbstractSystem) sym = eq.lhs vs = vars(eq.rhs) timeseries = TimeseriesSetType() - for v in vs - if (idx = get(disc_idxs, v, nothing)) !== nothing - push!(timeseries, idx.clock_idx) - elseif haskey(unk_idxs, v) || haskey(observed_syms_to_timeseries, v) - push!(timeseries, ContinuousTimeseries()) + if is_time_dependent(sys) + for v in vs + if (idx = get(disc_idxs, v, nothing)) !== nothing + push!(timeseries, idx.clock_idx) + elseif haskey(unk_idxs, v) + push!(timeseries, ContinuousTimeseries()) + elseif haskey(observed_syms_to_timeseries, v) + union!(timeseries, observed_syms_to_timeseries[v]) + elseif haskey(dependent_pars_to_timeseries, v) + union!(timeseries, dependent_pars_to_timeseries[v]) + elseif iscall(v) && issym(operation(v)) && + is_variable(sys, operation(v)(get_iv(sys))) + push!(timeseries, ContinuousTimeseries()) + end end end ttsym = default_toterm(sym) From 87c16598cf244f55c5ea4d483ae6ec553ef7da87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 29 Oct 2024 15:36:54 +0530 Subject: [PATCH 3/4] fix: handle timeseries detection for constant observed --- src/systems/index_cache.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 307ac71a56..1e3490ee69 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -284,23 +284,21 @@ function IndexCache(sys::AbstractSystem) for eq in observed(sys) if symbolic_type(eq.lhs) != NotSymbolic() sym = eq.lhs - vs = vars(eq.rhs) + vs = vars(eq.rhs; op = Nothing) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs if (idx = get(disc_idxs, v, nothing)) !== nothing push!(timeseries, idx.clock_idx) - elseif haskey(unk_idxs, v) - push!(timeseries, ContinuousTimeseries()) elseif haskey(observed_syms_to_timeseries, v) union!(timeseries, observed_syms_to_timeseries[v]) elseif haskey(dependent_pars_to_timeseries, v) union!(timeseries, dependent_pars_to_timeseries[v]) - elseif iscall(v) && issym(operation(v)) && - is_variable(sys, operation(v)(get_iv(sys))) - push!(timeseries, ContinuousTimeseries()) end end + if isempty(timeseries) + push!(timeseries, ContinuousTimeseries()) + end end ttsym = default_toterm(sym) rsym = renamespace(sys, sym) From 55c52175830b91a83d70939e95941d359b70ed03 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 Nov 2024 15:46:53 +0530 Subject: [PATCH 4/4] ci: prevent codecov from failing CI --- .github/workflows/Downstream.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 7a7556efa3..a6d5e84656 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -72,4 +72,4 @@ jobs: with: file: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + fail_ci_if_error: false