diff --git a/src/systems/optimization/modelingtoolkitize.jl b/src/systems/optimization/modelingtoolkitize.jl index 1ceea795c5..b66f113f6c 100644 --- a/src/systems/optimization/modelingtoolkitize.jl +++ b/src/systems/optimization/modelingtoolkitize.jl @@ -33,6 +33,15 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; end _vars = reshape(_vars, size(prob.u0)) vars = ArrayInterface.restructure(prob.u0, _vars) + if prob.ub !== nothing # lb is also !== nothing + vars = map(vars, prob.lb, prob.ub) do sym, lb, ub + if iszero(lb) && iszero(ub) || isinf(lb) && lb < 0 && isinf(ub) && ub > 0 + sym + else + Symbolics.setmetadata(sym, VariableBounds, (lb, ub)) + end + end + end params = if has_p if p_names === nothing && SciMLBase.has_sys(prob.f) p_names = Dict(parameter_index(prob.f.sys, sym) => sym diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 22488e8bad..271e0073ce 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -103,6 +103,17 @@ function OptimizationSystem(op, unknowns, ps; ps′ = value.(ps) op′ = value(scalarize(op)) + irreducible_subs = Dict() + for i in eachindex(unknowns′) + var = unknowns′[i] + if hasbounds(var) + irreducible_subs[var] = irrvar = setirreducible(var, true) + unknowns′[i] = irrvar + end + end + op′ = substitute(op′, irreducible_subs) + constraints = substitute.(constraints, (irreducible_subs,)) + if !(isempty(default_u0) && isempty(default_p)) Base.depwarn( "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", @@ -113,7 +124,8 @@ function OptimizationSystem(op, unknowns, ps; throw(ArgumentError("System names must be unique.")) end defaults = todict(defaults) - defaults = Dict(value(k) => value(v) + defaults = Dict(substitute(value(k), irreducible_subs) => substitute( + value(v), irreducible_subs) for (k, v) in pairs(defaults) if value(v) !== nothing) var_to_name = Dict() diff --git a/src/variables.jl b/src/variables.jl index c0c875450c..fcfc7f9b1e 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -106,6 +106,7 @@ isoutput(x) = isvarkind(VariableOutput, x) # Before the solvability check, we already have handled IO variables, so # irreducibility is independent from IO. isirreducible(x) = isvarkind(VariableIrreducible, x) +setirreducible(x, v) = setmetadata(x, VariableIrreducible, v) state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64 function default_toterm(x) diff --git a/test/modelingtoolkitize.jl b/test/modelingtoolkitize.jl index 32a9720f47..621bf530b7 100644 --- a/test/modelingtoolkitize.jl +++ b/test/modelingtoolkitize.jl @@ -67,6 +67,16 @@ sol = solve(prob, BFGS()) sol = solve(prob, Newton()) @test sol.objective < 1e-8 +prob = OptimizationProblem(ones(3); lb = [-Inf, 0.0, 1.0], ub = [Inf, 0.0, 2.0]) do u, p + sum(abs2, u) +end + +sys = complete(modelingtoolkitize(prob)) +@test !ModelingToolkit.hasbounds(unknowns(sys)[1]) +@test !ModelingToolkit.hasbounds(unknowns(sys)[2]) +@test ModelingToolkit.hasbounds(unknowns(sys)[3]) +@test ModelingToolkit.getbounds(unknowns(sys)[3]) == (1.0, 2.0) + ## SIR System Regression Test β = 0.01# infection rate diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index 426d6d5de0..dfa11fca37 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -1,5 +1,5 @@ using ModelingToolkit, SparseArrays, Test, Optimization, OptimizationOptimJL, - OptimizationMOI, Ipopt, AmplNLWriter, Ipopt_jll + OptimizationMOI, Ipopt, AmplNLWriter, Ipopt_jll, SymbolicIndexingInterface using ModelingToolkit: get_metadata @testset "basic" begin @@ -347,3 +347,15 @@ end prob = @test_nowarn OptimizationProblem(sys, nothing) @test_nowarn solve(prob, NelderMead()) end + +@testset "Bounded unknowns are irreducible" begin + @variables x + @variables y [bounds = (-Inf, Inf)] + @variables z [bounds = (1.0, 2.0)] + obj = x^2 + y^2 + z^2 + cons = [y ~ 2x + z ~ 2y] + @mtkbuild sys = OptimizationSystem(obj, [x, y, z], []; constraints = cons) + @test is_variable(sys, z) + @test !is_variable(sys, y) +end