diff --git a/.gitignore b/.gitignore index 759fe810..285f6a72 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ Manifest.toml .vscode/ # End of https://www.toptal.com/developers/gitignore/api/julia + +.ipynb_checkpoints diff --git a/Project.toml b/Project.toml index 2cee6434..5b52bdbd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StructuralIdentifiability" uuid = "220ca800-aa68-49bb-acd8-6037fa93a544" -authors = ["Alexander Demin, Ruiwen Dong, Christian Goodbrake, Heather Harrington, Gleb Pogudin "] version = "0.5.17" +authors = ["Alexander Demin, Ruiwen Dong, Christian Goodbrake, Heather Harrington, Gleb Pogudin "] [deps] AbstractAlgebra = "c3fe647b-3220-5bb0-a1ea-a7954cac585d" @@ -17,6 +17,7 @@ Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a" ParamPunPam = "3e851597-e36f-45a9-af0a-b7781937992f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" +REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RationalFunctionFields = "73480bc8-48a2-41cc-880f-208b490ccf65" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" @@ -47,6 +48,7 @@ ParamPunPam = "0.5.5" Pkg = "1.10, 1.11" PrecompileTools = "1.2" Primes = "0.5" +REPL = "1.10, 1.11" Random = "1.10, 1.11" RationalFunctionFields = "0.2.2" SpecialFunctions = "2" diff --git a/src/StructuralIdentifiability.jl b/src/StructuralIdentifiability.jl index 00df16eb..5db05df6 100644 --- a/src/StructuralIdentifiability.jl +++ b/src/StructuralIdentifiability.jl @@ -11,6 +11,7 @@ using MacroTools using Primes using Random using TimerOutputs +using REPL, REPL.TerminalMenus # Algebra packages using AbstractAlgebra @@ -60,7 +61,7 @@ export PBRepresentation, diffreduce, io_switch!, pseudodivision export find_submodels # finding identifiabile reparametrizations -export reparametrize_global +export reparametrize_global, reparametrize_interactive ExtendedFraction{P} = Union{P, Generic.FracFieldElem{P}} diff --git a/src/identifiable_functions.jl b/src/identifiable_functions.jl index 0ef38f18..d4544d9f 100644 --- a/src/identifiable_functions.jl +++ b/src/identifiable_functions.jl @@ -55,6 +55,7 @@ function find_identifiable_functions( simplify = :standard, rational_interpolator = :VanDerHoevenLecerf, loglevel = Logging.Info, + return_all = false, ) where {T <: MPolyRingElem{QQFieldElem}} restart_logging(loglevel = loglevel) reset_timings() @@ -67,6 +68,7 @@ function find_identifiable_functions( with_states = with_states, simplify = simplify, rational_interpolator = rational_interpolator, + return_all = return_all, ) else id_funcs = _find_identifiable_functions_kic( @@ -76,6 +78,7 @@ function find_identifiable_functions( seed = seed, simplify = simplify, rational_interpolator = rational_interpolator, + return_all = return_all, ) # renaming variables from `x(t)` to `x(0)` return replace_with_ic(ode, id_funcs) @@ -90,6 +93,7 @@ function _find_identifiable_functions( with_states = false, simplify = :standard, rational_interpolator = :VanDerHoevenLecerf, + return_all = false, ) where {T <: MPolyRingElem{QQFieldElem}} Random.seed!(seed) @assert simplify in (:standard, :weak, :strong, :absent) @@ -124,6 +128,7 @@ function _find_identifiable_functions( simplify = simplify, rational_interpolator = rational_interpolator, priority_variables = [parent_ring_change(p, bring) for p in ode.parameters], + return_all = return_all, ) else id_funcs_fracs = dennums_to_fractions(id_funcs) diff --git a/src/logging.jl b/src/logging.jl index b5b2d84d..3a4f27cb 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -30,9 +30,13 @@ const _runtime_logger = Dict( const _si_logger = Ref{Logging.ConsoleLogger}(Logging.ConsoleLogger(Logging.Info, show_limited = false)) -function restart_logging(; loglevel = Logging.Info) +function restart_logging(; loglevel = Logging.Info, stream = nothing) @assert loglevel isa Base.CoreLogging.LogLevel - _si_logger[] = Logging.ConsoleLogger(loglevel, show_limited = false) + if stream !== nothing + _si_logger[] = Logging.ConsoleLogger(stream, loglevel, show_limited = false) + else + _si_logger[] = Logging.ConsoleLogger(loglevel, show_limited = false) + end for r in _runtime_rubrics _runtime_logger[r] = 0 end diff --git a/src/parametrizations.jl b/src/parametrizations.jl index ccdd8286..a9d6caf9 100644 --- a/src/parametrizations.jl +++ b/src/parametrizations.jl @@ -16,6 +16,13 @@ function vector_field_along(derivation::Dict{T, U}, directions::AbstractVector) return new_vector_field end +function default_variable_names(new_states, new_params) + ( + states = map(i -> "X$i", 1:length(new_states)), + params = map(i -> "a$i", 1:length(new_params)), + ) +end + """ reparametrize_with_respect_to(ode, new_states, new_params) @@ -27,7 +34,12 @@ Reparametrizes the `ode` using the given fractional states and parameters. - `new_states`: a vector of new states as fractions in `parent(ode)`. - `new_params`: a vector of new parameters as fractions in `parent(ode)`. """ -function reparametrize_with_respect_to(ode, new_states, new_params) +function reparametrize_with_respect_to( + ode::ODE{P}, + new_states, + new_params; + new_variable_names = default_variable_names(new_states, new_params), +) where {P} @assert length(new_states) > 0 poly_ring = base_ring(parent(first(new_states))) # Compute the new dynamics in terms of the original variables. @@ -52,7 +64,7 @@ function reparametrize_with_respect_to(ode, new_states, new_params) gen_tag_names(length(ode.u_vars), "Input"), gen_tag_names(length(ode.y_vars), "Output"), ) - @info """ + @debug """ Tag names: $tag_names Generating functions: @@ -81,7 +93,7 @@ function reparametrize_with_respect_to(ode, new_states, new_params) if !isempty(ode.u_vars) new_inputs = tag_inputs end - @info """ + @debug """ New state dynamics: $new_dynamics_states New output dynamics: @@ -94,14 +106,14 @@ function reparametrize_with_respect_to(ode, new_states, new_params) state = tags[i] new_vars_vector_field[state] = new_dynamics_states[i] end - @info "Converting variable names to human-readable ones" - internal_variable_names = map(i -> "X$i(t)", 1:length(new_states)) - parameter_variable_names = map(i -> "a$i", 1:length(new_params)) + @debug "Converting variable names to human-readable ones" + @assert length(new_variable_names.states) == length(new_states) + @assert length(new_variable_names.params) == length(new_params) input_variable_names = map(i -> "u$i(t)", 1:length(tag_inputs)) output_variable_names = map(i -> "y$i(t)", 1:length(tag_outputs)) all_variable_names = vcat( - internal_variable_names, - parameter_variable_names, + new_variable_names.states, + new_variable_names.params, input_variable_names, output_variable_names, ) @@ -237,3 +249,214 @@ function _reparametrize_global(ode::ODE{P}; prob_threshold = 0.99, seed = 42) wh new_ode = ODE{P}(new_vector_field, new_outputs, new_inputs) return (new_ode = new_ode, new_vars = new_vars, implicit_relations = implicit_relations) end + +function reparametrize_interactive( + ode::ODE{P}; + prob_threshold = 0.99, + seed = 42, + loglevel = Logging.Info, + input::IO = stdin, + output::IO = stdout, +) where {P} + restart_logging(loglevel = loglevel, stream = output) + with_logger(_si_logger[]) do + return _reparametrize_interactive(ode, prob_threshold, seed, input, output) + end +end + +function _reparametrize_interactive( + ode::ODE{P}, + prob_threshold, + seed, + input, + output, +) where {P} + Random.seed!(seed) + id_funcs = find_identifiable_functions( + ode, + with_states = true, + simplify = :strong, + prob_threshold = prob_threshold, + return_all = true, + loglevel = Logging.Warn, + ) + id_funcs_simple = find_identifiable_functions( + ode, + with_states = true, + simplify = :strong, + prob_threshold = prob_threshold, + loglevel = Logging.Warn, + ) + state = Dict( + :counter => 0, + :id_funcs => id_funcs, + :chosen_funcs => (states = empty(id_funcs), params = empty(id_funcs)), + :variable_names => (states = Vector{String}(), params = Vector{String}()), + ) + contains_states(poly::MPolyRingElem) = any(x -> degree(poly, x) > 0, ode.x_vars) + contains_states(func) = + contains_states(numerator(func)) || contains_states(denominator(func)) + function print_header(state) + println(output, "\n$(state[:counter]). Info.") + println(output, " Original states: $(join(string.(ode.x_vars), ", "))") + println(output, " Original parameters: $(join(string.(ode.parameters), ", "))") + println(output, " Identifiable functions: $(join(string.(id_funcs_simple), ", "))") + end + function print_state(state) + counter, chosen_funcs, variable_names = + state[:counter], state[:chosen_funcs], state[:variable_names] + if isempty(chosen_funcs.states) && isempty(chosen_funcs.params) + return + end + println(output, "\n$counter. Current selection:") + for (names, funcs) in [ + (variable_names.states, chosen_funcs.states), + (variable_names.params, chosen_funcs.params), + ] + for (name, func) in zip(names, funcs) + println(output, " ", name, " := ", func) + end + end + end + function make_choice(state) + counter, id_funcs = state[:counter], state[:id_funcs] + terminal = + REPL.TerminalMenus.default_terminal(in = input, out = output, err = output) + menu = MultiSelectMenu(vcat("Enter a custom function", map(string, id_funcs))) + choice = request( + terminal, + "\n$counter. Select identifiable function(s) for reparametrization.", + menu, + ) + if 1 in choice # a custom function + funcs = empty(id_funcs) + while true + varnames = map( + f -> chopsuffix(f, "(t)"), + string.(vcat(ode.x_vars, ode.parameters)), + ) + res = Base.prompt( + input, + output, + "\n$counter. Enter a rational function in the variables: $(join(varnames, ", "))\n", + ) + func = nothing + try + func = myeval( + Meta.parse(res), + Dict(Symbol.(varnames) .=> vcat(ode.x_vars, ode.parameters)), + ) + catch e + @info "" e + printstyled( + output, + "\n ==> Error when parsing $res. Trying again..\n", + bold = true, + ) + continue + end + ffring = fraction_field(parent(ode)) + funcs = [ffring(func)] + if all( + field_contains( + RationalFunctionField(id_funcs_simple), + funcs, + prob_threshold, + ), + ) + break + else + printstyled( + output, + "\n ==> The given function $(funcs[1]) is not identifiable. Trying again..\n", + bold = true, + ) + continue + end + end + else + funcs = id_funcs[sort(collect(choice)) .- 1] + end + funcs + end + function query_names(state, funcs) + counter, chosen_funcs, variable_names = + state[:counter], state[:chosen_funcs], state[:variable_names] + new_states = filter(contains_states, funcs) + new_params = setdiff(funcs, new_states) + idx_states, idx_params = length(chosen_funcs.states), length(chosen_funcs.params) + append!(chosen_funcs.states, new_states) + append!(chosen_funcs.params, new_params) + default_names = default_variable_names(chosen_funcs.states, chosen_funcs.params) + for (kind, vars, defaults, new_vars, idx) in [ + ("state", new_states, default_names.states, variable_names.states, idx_states), + ( + "parameter", + new_params, + default_names.params, + variable_names.params, + idx_params, + ), + ] + for (i, var) in enumerate(vars) + default = defaults[idx + i] + res = Base.prompt( + input, + output, + "\n$counter. Enter a name for the new $kind: $var. Leave empty for default: $default.\n", + ) + if isnothing(res) || (res isa String && isempty(strip(res))) + res = default + end + push!(new_vars, res) + printstyled(output, " ==> New variable: $res := $var\n", bold = true) + end + end + end + function try_to_reparametrize(state) + counter, chosen_funcs, variable_names, id_funcs = + state[:counter], state[:chosen_funcs], state[:variable_names], state[:id_funcs] + rff = RationalFunctionField(vcat(chosen_funcs.states, chosen_funcs.params)) + state[:id_funcs] = id_funcs[.! field_contains(rff, id_funcs, prob_threshold)] + if isempty(chosen_funcs.states) + printstyled( + output, + "\n ==> Please select at least one new state in order to reparametrize.\n", + bold = true, + ) + return nothing + end + try + new_vector_field, new_inputs, new_outputs, new_vars, implicit_relations = + reparametrize_with_respect_to( + ode, + chosen_funcs.states, + chosen_funcs.params, + new_variable_names = variable_names, + ) + new_ode = ODE{P}(new_vector_field, new_outputs, new_inputs) + return ( + new_ode = new_ode, + new_vars = new_vars, + implicit_relations = implicit_relations, + ) + catch e + printstyled( + output, + "\n ==> Selected functions is not enough to reparametrize. Please select more.\n", + bold = true, + ) + end + return nothing + end + + print_header(state) + while true + state[:counter] += 1 + print_state(state) + funcs = make_choice(state) + query_names(state, funcs) + result = try_to_reparametrize(state) + !(result == nothing) && return result + end +end diff --git a/src/util.jl b/src/util.jl index eaae90b4..536c66eb 100644 --- a/src/util.jl +++ b/src/util.jl @@ -301,3 +301,130 @@ function replace_with_ic(ode, funcs) end # ----------------------------------------------------------------------------- +# Parsing polynomials +# +# Adapted from https://discourse.julialang.org/t/expression-parser/41880/7 +# code by Alan R. Rogers, Professor of Anthropology, University of Utah +# +# Adapted from https://github.com/x3042/ExactODEReduction.jl/blob/5539e2d81cd7a223b814ae7d3213f382fa650ab4/src/parser/myeval.jl +# code by Elizaveta Demitraki, Alexander Demin, Gleb Pogudin + +function myeval(e::Union{Expr, Symbol, Number}, map::Dict{Symbol, P}) where {P} + try + return _myeval(e, map) + catch ex + println("Can't parse \"$e\"") + rethrow(ex) + end +end + +function _myeval(s::Symbol, map::Dict{Symbol, P}) where {P} + if haskey(map, s) + return map[s] + else + @info "Can not find $s in $map while parsing.." + throw("$s") + end +end + +function _myeval(x::Number, map::Dict{Symbol, P}) where {P} + k = base_ring(first(values(map))) + k(x) +end + +# a helper definition for floats +function _myeval(x::Float64, map::Dict{Symbol, P}) where {P} + k = base_ring(first(values(map))) + result = k(0) + + # Getting the result from the string representation in order + # to avoid approximations caused by the float representation + s = string(x) + denom = 1 + extra_num = 1 + if occursin(r"[eE]", s) + s, exp = split(s, r"[eE]") + if exp[1] == "-" + denom = k(10)^(-parse(Int, exp)) + else + extra_num = k(10)^(parse(Int, exp)) + end + end + frac = split(s, ".") + if length(frac) == 1 + result = k(parse(fmpz, s)) * extra_num // denom + else + result = + k(parse(fmpz, frac[1] * frac[2])) * extra_num // (denom * 10^(length(frac[2]))) + end + + # too verbose for now + # @warn "a possibility of inexact float conversion" from=x to=result + return result +end + +# To parse an expression, convert the head to a singleton +# type, so that Julia can dispatch on that type. +function _myeval(e::Expr, map::Dict{Symbol, P}) where {P} + return _myeval(Val(e.head), e.args, map) +end + +# Call the function named in args[1] +function _myeval(::Val{:call}, args, map::Dict{Symbol, P}) where {P} + return _myeval(Val(args[1]), args[2:end], map) +end + +# Addition +function _myeval(::Val{:+}, args, map::Dict{Symbol, P}) where {P} + x = 0 + for arg in args + x += _myeval(arg, map) + end + return x +end + +# Subtraction and negation +function _myeval(::Val{:-}, args, map::Dict{Symbol, P}) where {P} + len = length(args) + if len == 1 + return -_myeval(args[1], map) + else + return _myeval(args[1], map) - _myeval(args[2], map) + end +end + +# Multiplication +function _myeval(::Val{:*}, args, map::Dict{Symbol, P}) where {P} + x = 1 + for arg in args + x *= _myeval(arg, map) + end + return x +end + +# Division +function _myeval(::Val{:/}, args, map::Dict{Symbol, P}) where {P} + # note // instead of / + return _myeval(args[1], map) // _myeval(args[2], map) +end + +function _myeval(::Val{://}, args, map::Dict{Symbol, P}) where {P} + return _myeval(args[1], map) // _myeval(args[2], map) +end + +# Exponentiation +function _myeval(::Val{:^}, args, map::Dict{Symbol, P}) where {P} + if typeof(_myeval(args[2], map)) <: P + @warn "We can not parse polynomial fractions, sorry" + throw(ParseException("Polynomial fractions are not supported")) + end + + if _myeval(args[2], map) < 0 + @warn "Negative exponent encountered while parsing" + throw(ParseException("Negative exponents are not supported")) + end + + return _myeval(args[1], map) ^ Int(numerator(_myeval(args[2], map))) +end + +# ----------------------------------------------------------------------------- diff --git a/test/interactive.jl b/test/interactive.jl new file mode 100644 index 00000000..d70147a0 --- /dev/null +++ b/test/interactive.jl @@ -0,0 +1,69 @@ +using Logging + +# Adapted from https://github.com/JuliaLang/julia/blob/5023ee21d70b734edf206aab3cac7c202ee0235a/stdlib/REPL/test/TerminalMenus/runtests.jl#L7 +# The licence is MIT. +function simulate_input(keys...; kwargs...) + keydict = Dict(:up => "\e[A", :down => "\e[B", :enter => "\r", :newline => "\n") + + new_stdin = Base.BufferStream() + for key in keys + if isa(key, Symbol) + write(new_stdin, keydict[key]) + else + write(new_stdin, "$key") + end + end + + return new_stdin +end + +@testset "Interactive reparametrizations" begin + ode = @ODEmodel(x'(t) = x*u(t), y(t) = x*a) + new_stdin = simulate_input(:down, :enter, "d", "X", :enter, :newline) + # Sasha: simulating input this way generates warnings in the Julia + # standard library (which we do not see because of `output = devnull`). + # We still use this method of testing, because it is used in the Julia + # standard library. + res = reparametrize_interactive( + ode, + input = new_stdin, + loglevel = Logging.Error, + output = devnull, + ) + @test res[1] isa ODE + @test string.(res[1].x_vars) == ["X"] + @test isempty(res[1].parameters) + old_var_to_new = Dict(v => k for (k, v) in res[2]) + @test old_var_to_new[x * a // 1] == res[1].x_vars[1] + + ode = @ODEmodel(x'(t) = x + b^2 + 1, y(t) = x) + new_stdin = simulate_input( + :enter, + "d", + "b", + :enter, + :newline, + "bb", + :enter, + :newline, + "b^2 + 1", + :enter, + :newline, + "B", + :enter, + :newline, + :down, + :enter, + "d", + :enter, + :newline, + ) + res = reparametrize_interactive( + ode, + input = new_stdin, + loglevel = Logging.Error, + output = devnull, + ) + @test string.(res[1].x_vars) == ["X1"] + @test string.(res[1].parameters) == ["B"] +end