From b40c2a1dbda36d481067ffa5450599540cefdc40 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 24 Feb 2025 11:58:05 -0500 Subject: [PATCH 1/2] fix Complex typecheck --- src/utils.jl | 16 ++++++++++++---- test/odesystem.jl | 4 ++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 962801622a..4685e76ebc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -222,7 +222,8 @@ function collect_ivs_from_nested_operator!(ivs, x, target_op) end function iv_from_nested_derivative(x, op = Differential) - if iscall(x) && operation(x) == getindex + if iscall(x) && + (operation(x) == getindex || operation(x) == real || operation(x) == imag) iv_from_nested_derivative(arguments(x)[1], op) elseif iscall(x) operation(x) isa op ? iv_from_nested_derivative(arguments(x)[1], op) : @@ -1204,7 +1205,7 @@ end Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters. """ function process_equations(eqs, iv) - eqs = collect(eqs) + eqs = collect(Iterators.flatten(eqs)) diffvars = OrderedSet() allunknowns = OrderedSet() @@ -1237,8 +1238,8 @@ function process_equations(eqs, iv) throw(ArgumentError("An ODESystem can only have one independent variable.")) diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations.")) - !(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && - throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed.")) + !has_diffvar_type(diffvar) && + throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should be of a continuous, non-concrete number type: Real, Complex, AbstractFloat, or Number.")) push!(diffvars, diffvar) end push!(diffeq, eq) @@ -1250,6 +1251,13 @@ function process_equations(eqs, iv) diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs] end +function has_diffvar_type(diffvar) + st = symtype(diffvar) + st === Real || eltype(st) === Real || st === Complex || eltype(st) === Complex || + st === Number || eltype(st) === Number || st === AbstractFloat || + eltype(st) === AbstractFloat +end + """ $(TYPEDSIGNATURES) diff --git a/test/odesystem.jl b/test/odesystem.jl index 01df23ede8..b9a41e1c3d 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1589,6 +1589,10 @@ end @variables Y(t)[1:3]::String eq = D(Y) ~ [p, p, p] @test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t) + + @variables X(t)::Complex + eq = D(X) ~ p - d * X + @test_nowarn @named osys = ODESystem([eq], t) end # Test `isequal` From 956079768e8a8c1785440a8b370b3acb5b1fb756 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 24 Feb 2025 12:27:10 -0500 Subject: [PATCH 2/2] up --- src/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 4685e76ebc..950af5a1e2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1205,7 +1205,8 @@ end Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters. """ function process_equations(eqs, iv) - eqs = collect(Iterators.flatten(eqs)) + eltype(eqs) <: Vector && (eqs = vcat(eqs...)) + eqs = collect(eqs) diffvars = OrderedSet() allunknowns = OrderedSet()