From ab023f6e449241e21782198256f32fe671ea3509 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 8 May 2025 11:37:53 +0530 Subject: [PATCH 1/2] fix: fix breaking change to `generate_control_function` --- docs/src/basics/InputOutput.md | 2 +- docs/src/tutorials/disturbance_modeling.md | 6 +++--- src/inputoutput.jl | 13 +++++++------ src/systems/optimal_control_interface.jl | 1 + test/downstream/test_disturbance_model.jl | 8 ++++---- test/extensions/test_infiniteopt.jl | 2 +- test/input_output_handling.jl | 10 +++++----- 7 files changed, 22 insertions(+), 20 deletions(-) diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index b01d093980..4dc5a3d50f 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -70,7 +70,7 @@ Now we can test the generated function `f` with random input and state values p = [1] x = [rand()] u = [rand()] -@test f(x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u) +@test f[1](x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u) ``` ## Generating an output function, ``g`` diff --git a/docs/src/tutorials/disturbance_modeling.md b/docs/src/tutorials/disturbance_modeling.md index 0d85299744..db8d926498 100644 --- a/docs/src/tutorials/disturbance_modeling.md +++ b/docs/src/tutorials/disturbance_modeling.md @@ -184,7 +184,7 @@ disturbance_inputs = [ssys.d1, ssys.d2] P = ssys.system_model outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w] -f, x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function( +(f_oop, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function( model_with_disturbance, inputs, disturbance_inputs; disturbance_argument = true) g = ModelingToolkit.build_explicit_observed_function( @@ -195,12 +195,12 @@ x0, _ = ModelingToolkit.get_u0_p(io_sys, op, op) p = MTKParameters(io_sys, op) u = zeros(1) # Control input w = zeros(length(disturbance_inputs)) # Disturbance input -@test f(x0, u, p, t, w) == zeros(5) +@test f_oop(x0, u, p, t, w) == zeros(5) @test g(x0, u, p, 0.0) == [0, 0, 0, 0] # Non-zero disturbance inputs should result in non-zero state derivatives. We call `sort` since we do not generally know the order of the state variables w = [1.0, 2.0] -@test sort(f(x0, u, p, t, w)) == [0, 0, 0, 1, 2] +@test sort(f_oop(x0, u, p, t, w)) == [0, 0, 0, 1, 2] ``` ## Input signal library diff --git a/src/inputoutput.jl b/src/inputoutput.jl index beaff42a0b..739df14ae0 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -160,7 +160,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex)) # Build control function """ - f, x_sym, p_sym, io_sys = generate_control_function( + (f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function( sys::AbstractODESystem, inputs = unbound_inputs(sys), disturbance_inputs = nothing; @@ -168,9 +168,9 @@ has_var(ex, x) = x ∈ Set(get_variables(ex)) simplify = false, ) -For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `u` +For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate functions with additional input argument `u` -The returned function `f` can be called in the out-of-place or in-place form: +The returned functions are the out-of-place (`f_oop`) and in-place (`f_ip`) forms: ``` f_oop : (x,u,p,t) -> rhs f_ip : (xout,x,u,p,t) -> nothing @@ -191,7 +191,7 @@ f, x_sym, ps = generate_control_function(sys, expression=Val{false}, simplify=fa p = varmap_to_vars(defaults(sys), ps) x = varmap_to_vars(defaults(sys), x_sym) t = 0 -f(x, inputs, p, t) +f[1](x, inputs, p, t) ``` """ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys), @@ -253,9 +253,10 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae, p_end = length(p) + 2 + implicit_dae, kwargs...) f = eval_or_rgf.(f; eval_expression, eval_module) - f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...) + f = GeneratedFunctionWrapper{( + 3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...) ps = setdiff(parameters(sys), inputs, disturbance_inputs) - (; f, dvs, ps, io_sys = sys) + (; f = (f, f), dvs, ps, io_sys = sys) end function inputs_to_parameters!(state::TransformationState, io) diff --git a/src/systems/optimal_control_interface.jl b/src/systems/optimal_control_interface.jl index beccbf8b41..eb573da810 100644 --- a/src/systems/optimal_control_interface.jl +++ b/src/systems/optimal_control_interface.jl @@ -52,6 +52,7 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::ODESystem, kwargs...) where {iip, specialize} f, _, _ = generate_control_function( sys, inputs, disturbance_inputs; eval_module, cse, kwargs...) + f = f[1] if tgrad tgrad_gen = generate_tgrad(sys, dvs, ps; diff --git a/test/downstream/test_disturbance_model.jl b/test/downstream/test_disturbance_model.jl index cd9d769e99..0e04200237 100644 --- a/test/downstream/test_disturbance_model.jl +++ b/test/downstream/test_disturbance_model.jl @@ -168,22 +168,22 @@ x0, p = ModelingToolkit.get_u0_p(io_sys, op, op) x = zeros(5) u = zeros(1) d = zeros(3) -@test f(x, u, p, t, d) == zeros(5) +@test f[1](x, u, p, t, d) == zeros(5) @test measurement(x, u, p, 0.0) == [0, 0, 0, 0] @test measurement2(x, u, p, 0.0, d) == [0] # Add to the integrating disturbance input d = [1, 0, 0] -@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 1, 1] # Affects disturbance state and one velocity +@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 1, 1] # Affects disturbance state and one velocity @test measurement2(x, u, p, 0.0, d) == [0] d = [0, 1, 0] -@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 0, 1] # Affects one velocity +@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 0, 1] # Affects one velocity @test measurement(x, u, p, 0.0) == [0, 0, 0, 0] @test measurement2(x, u, p, 0.0, d) == [0] d = [0, 0, 1] -@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 0, 0] # Affects nothing +@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 0, 0] # Affects nothing @test measurement(x, u, p, 0.0) == [0, 0, 0, 0] @test measurement2(x, u, p, 0.0, d) == [1] # We have now disturbed the output diff --git a/test/extensions/test_infiniteopt.jl b/test/extensions/test_infiniteopt.jl index 833b9f3275..eb734358c5 100644 --- a/test/extensions/test_infiniteopt.jl +++ b/test/extensions/test_infiniteopt.jl @@ -65,7 +65,7 @@ InfiniteOpt.@variables(m, # Trace the dynamics x0, p = ModelingToolkit.get_u0_p(io_sys, [model.θ => 0, model.ω => 0], [model.L => L]) -xp = f(x, u, p, τ) +xp = f[1](x, u, p, τ) cp = f_obs(x, u, p, τ) # Test that it's possible to trace through an observed function @objective(m, Min, tf) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 2c16c89320..115426444e 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -173,7 +173,7 @@ end p = [rand()] x = [rand()] u = [rand()] - @test f(x, u, p, 1) ≈ -x + u + @test f[1](x, u, p, 1) ≈ -x + u # With disturbance inputs @variables x(t)=0 u(t)=0 [input = true] d(t)=0 @@ -191,7 +191,7 @@ end p = [rand()] x = [rand()] u = [rand()] - @test f(x, u, p, 1) ≈ -x + u + @test f[1](x, u, p, 1) ≈ -x + u ## With added d argument @variables x(t)=0 u(t)=0 [input = true] d(t)=0 @@ -210,7 +210,7 @@ end x = [rand()] u = [rand()] d = [rand()] - @test f(x, u, p, t, d) ≈ -x + u + [d[]^2] + @test f[1](x, u, p, t, d) ≈ -x + u + [d[]^2] end end @@ -273,7 +273,7 @@ x = ModelingToolkit.varmap_to_vars( merge(ModelingToolkit.defaults(model), Dict(D.(unknowns(model)) .=> 0.0)), dvs) u = [rand()] -out = f(x, u, p, 1) +out = f[1](x, u, p, 1) i = findfirst(isequal(u[1]), out) @test i isa Int @test iszero(out[[1:(i - 1); (i + 1):end]]) @@ -447,7 +447,7 @@ end @named sys = ODESystem(eqs, t, [x], []) f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true) - @test f([0.5], nothing, MTKParameters(io_sys, []), 0.0) ≈ [1.0] + @test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) ≈ [1.0] end @testset "With callable symbolic" begin From 948f38ca71297e6e1f4f7c0cea2427df9cdc1c79 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 8 May 2025 15:00:08 +0530 Subject: [PATCH 2/2] test: fix input output tests --- test/input_output_handling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 115426444e..67f60a04cd 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -348,8 +348,8 @@ x0 = randn(5) x1 = copy(x0) + x_add # add disturbance state perturbation u = randn(1) pn = MTKParameters(io_sys, []) -xp0 = f(x0, u, pn, 0) -xp1 = f(x1, u, pn, 0) +xp0 = f[1](x0, u, pn, 0) +xp1 = f[1](x1, u, pn, 0) @test xp0 ≈ matrices.A * x0 + matrices.B * [u; 0] @test xp1 ≈ matrices.A * x1 + matrices.B * [u; 0] @@ -459,5 +459,5 @@ end p = MTKParameters(io_sys, []) u = [1.0] x = [1.0] - @test_nowarn f(x, u, p, 0.0) + @test_nowarn f[1](x, u, p, 0.0) end