-
-
Notifications
You must be signed in to change notification settings - Fork 21
Add interactive reparametrization #463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,3 +32,5 @@ Manifest.toml | |
| .vscode/ | ||
|
|
||
| # End of https://www.toptal.com/developers/gitignore/api/julia | ||
|
|
||
| .ipynb_checkpoints | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,13 @@ function vector_field_along(derivation::Dict{T, U}, directions::AbstractVector) | |||||
| return new_vector_field | ||||||
| end | ||||||
|
|
||||||
| function default_variable_names(new_states, new_params) | ||||||
| ( | ||||||
| states = map(i -> "X$i", 1:length(new_states)), | ||||||
| params = map(i -> "a$i", 1:length(new_params)), | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| """ | ||||||
| reparametrize_with_respect_to(ode, new_states, new_params) | ||||||
|
|
||||||
|
|
@@ -27,7 +34,12 @@ Reparametrizes the `ode` using the given fractional states and parameters. | |||||
| - `new_states`: a vector of new states as fractions in `parent(ode)`. | ||||||
| - `new_params`: a vector of new parameters as fractions in `parent(ode)`. | ||||||
| """ | ||||||
| function reparametrize_with_respect_to(ode, new_states, new_params) | ||||||
| function reparametrize_with_respect_to( | ||||||
| ode::ODE{P}, | ||||||
| new_states, | ||||||
| new_params; | ||||||
| new_variable_names = default_variable_names(new_states, new_params), | ||||||
| ) where {P} | ||||||
| @assert length(new_states) > 0 | ||||||
| poly_ring = base_ring(parent(first(new_states))) | ||||||
| # Compute the new dynamics in terms of the original variables. | ||||||
|
|
@@ -52,7 +64,7 @@ function reparametrize_with_respect_to(ode, new_states, new_params) | |||||
| gen_tag_names(length(ode.u_vars), "Input"), | ||||||
| gen_tag_names(length(ode.y_vars), "Output"), | ||||||
| ) | ||||||
| @info """ | ||||||
| @debug """ | ||||||
| Tag names: | ||||||
| $tag_names | ||||||
| Generating functions: | ||||||
|
|
@@ -81,7 +93,7 @@ function reparametrize_with_respect_to(ode, new_states, new_params) | |||||
| if !isempty(ode.u_vars) | ||||||
| new_inputs = tag_inputs | ||||||
| end | ||||||
| @info """ | ||||||
| @debug """ | ||||||
| New state dynamics: | ||||||
| $new_dynamics_states | ||||||
| New output dynamics: | ||||||
|
|
@@ -94,14 +106,14 @@ function reparametrize_with_respect_to(ode, new_states, new_params) | |||||
| state = tags[i] | ||||||
| new_vars_vector_field[state] = new_dynamics_states[i] | ||||||
| end | ||||||
| @info "Converting variable names to human-readable ones" | ||||||
| internal_variable_names = map(i -> "X$i(t)", 1:length(new_states)) | ||||||
| parameter_variable_names = map(i -> "a$i", 1:length(new_params)) | ||||||
| @debug "Converting variable names to human-readable ones" | ||||||
| @assert length(new_variable_names.states) == length(new_states) | ||||||
| @assert length(new_variable_names.params) == length(new_params) | ||||||
| input_variable_names = map(i -> "u$i(t)", 1:length(tag_inputs)) | ||||||
| output_variable_names = map(i -> "y$i(t)", 1:length(tag_outputs)) | ||||||
| all_variable_names = vcat( | ||||||
| internal_variable_names, | ||||||
| parameter_variable_names, | ||||||
| new_variable_names.states, | ||||||
| new_variable_names.params, | ||||||
| input_variable_names, | ||||||
| output_variable_names, | ||||||
| ) | ||||||
|
|
@@ -237,3 +249,214 @@ function _reparametrize_global(ode::ODE{P}; prob_threshold = 0.99, seed = 42) wh | |||||
| new_ode = ODE{P}(new_vector_field, new_outputs, new_inputs) | ||||||
| return (new_ode = new_ode, new_vars = new_vars, implicit_relations = implicit_relations) | ||||||
| end | ||||||
|
|
||||||
| function reparametrize_interactive( | ||||||
| ode::ODE{P}; | ||||||
| prob_threshold = 0.99, | ||||||
| seed = 42, | ||||||
| loglevel = Logging.Info, | ||||||
| input::IO = stdin, | ||||||
| output::IO = stdout, | ||||||
| ) where {P} | ||||||
| restart_logging(loglevel = loglevel, stream = output) | ||||||
| with_logger(_si_logger[]) do | ||||||
| return _reparametrize_interactive(ode, prob_threshold, seed, input, output) | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
| function _reparametrize_interactive( | ||||||
| ode::ODE{P}, | ||||||
| prob_threshold, | ||||||
| seed, | ||||||
| input, | ||||||
| output, | ||||||
| ) where {P} | ||||||
| Random.seed!(seed) | ||||||
| id_funcs = find_identifiable_functions( | ||||||
| ode, | ||||||
| with_states = true, | ||||||
| simplify = :strong, | ||||||
| prob_threshold = prob_threshold, | ||||||
| return_all = true, | ||||||
| loglevel = Logging.Warn, | ||||||
|
||||||
| ) | ||||||
| id_funcs_simple = find_identifiable_functions( | ||||||
| ode, | ||||||
| with_states = true, | ||||||
| simplify = :strong, | ||||||
| prob_threshold = prob_threshold, | ||||||
| loglevel = Logging.Warn, | ||||||
| ) | ||||||
| state = Dict( | ||||||
| :counter => 0, | ||||||
| :id_funcs => id_funcs, | ||||||
| :chosen_funcs => (states = empty(id_funcs), params = empty(id_funcs)), | ||||||
| :variable_names => (states = Vector{String}(), params = Vector{String}()), | ||||||
| ) | ||||||
| contains_states(poly::MPolyRingElem) = any(x -> degree(poly, x) > 0, ode.x_vars) | ||||||
| contains_states(func) = | ||||||
| contains_states(numerator(func)) || contains_states(denominator(func)) | ||||||
| function print_header(state) | ||||||
| println(output, "\n$(state[:counter]). Info.") | ||||||
| println(output, " Original states: $(join(string.(ode.x_vars), ", "))") | ||||||
| println(output, " Original parameters: $(join(string.(ode.parameters), ", "))") | ||||||
| println(output, " Identifiable functions: $(join(string.(id_funcs_simple), ", "))") | ||||||
| end | ||||||
| function print_state(state) | ||||||
| counter, chosen_funcs, variable_names = | ||||||
| state[:counter], state[:chosen_funcs], state[:variable_names] | ||||||
| if isempty(chosen_funcs.states) && isempty(chosen_funcs.params) | ||||||
| return | ||||||
| end | ||||||
| println(output, "\n$counter. Current selection:") | ||||||
| for (names, funcs) in [ | ||||||
| (variable_names.states, chosen_funcs.states), | ||||||
| (variable_names.params, chosen_funcs.params), | ||||||
| ] | ||||||
| for (name, func) in zip(names, funcs) | ||||||
| println(output, " ", name, " := ", func) | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| function make_choice(state) | ||||||
| counter, id_funcs = state[:counter], state[:id_funcs] | ||||||
| terminal = | ||||||
| REPL.TerminalMenus.default_terminal(in = input, out = output, err = output) | ||||||
| menu = MultiSelectMenu(vcat("Enter a custom function", map(string, id_funcs))) | ||||||
| choice = request( | ||||||
| terminal, | ||||||
| "\n$counter. Select identifiable function(s) for reparametrization.", | ||||||
| menu, | ||||||
| ) | ||||||
| if 1 in choice # a custom function | ||||||
| funcs = empty(id_funcs) | ||||||
| while true | ||||||
| varnames = map( | ||||||
| f -> chopsuffix(f, "(t)"), | ||||||
| string.(vcat(ode.x_vars, ode.parameters)), | ||||||
| ) | ||||||
| res = Base.prompt( | ||||||
| input, | ||||||
| output, | ||||||
| "\n$counter. Enter a rational function in the variables: $(join(varnames, ", "))\n", | ||||||
| ) | ||||||
| func = nothing | ||||||
| try | ||||||
| func = myeval( | ||||||
| Meta.parse(res), | ||||||
| Dict(Symbol.(varnames) .=> vcat(ode.x_vars, ode.parameters)), | ||||||
| ) | ||||||
| catch e | ||||||
| @info "" e | ||||||
| printstyled( | ||||||
| output, | ||||||
| "\n ==> Error when parsing $res. Trying again..\n", | ||||||
| bold = true, | ||||||
| ) | ||||||
| continue | ||||||
| end | ||||||
| ffring = fraction_field(parent(ode)) | ||||||
| funcs = [ffring(func)] | ||||||
| if all( | ||||||
| field_contains( | ||||||
| RationalFunctionField(id_funcs_simple), | ||||||
| funcs, | ||||||
| prob_threshold, | ||||||
| ), | ||||||
| ) | ||||||
| break | ||||||
| else | ||||||
| printstyled( | ||||||
| output, | ||||||
| "\n ==> The given function $(funcs[1]) is not identifiable. Trying again..\n", | ||||||
| bold = true, | ||||||
| ) | ||||||
| continue | ||||||
| end | ||||||
| end | ||||||
| else | ||||||
| funcs = id_funcs[sort(collect(choice)) .- 1] | ||||||
| end | ||||||
| funcs | ||||||
| end | ||||||
| function query_names(state, funcs) | ||||||
| counter, chosen_funcs, variable_names = | ||||||
| state[:counter], state[:chosen_funcs], state[:variable_names] | ||||||
| new_states = filter(contains_states, funcs) | ||||||
| new_params = setdiff(funcs, new_states) | ||||||
| idx_states, idx_params = length(chosen_funcs.states), length(chosen_funcs.params) | ||||||
| append!(chosen_funcs.states, new_states) | ||||||
| append!(chosen_funcs.params, new_params) | ||||||
| default_names = default_variable_names(chosen_funcs.states, chosen_funcs.params) | ||||||
| for (kind, vars, defaults, new_vars, idx) in [ | ||||||
| ("state", new_states, default_names.states, variable_names.states, idx_states), | ||||||
| ( | ||||||
| "parameter", | ||||||
| new_params, | ||||||
| default_names.params, | ||||||
| variable_names.params, | ||||||
| idx_params, | ||||||
| ), | ||||||
| ] | ||||||
| for (i, var) in enumerate(vars) | ||||||
| default = defaults[idx + i] | ||||||
| res = Base.prompt( | ||||||
| input, | ||||||
| output, | ||||||
| "\n$counter. Enter a name for the new $kind: $var. Leave empty for default: $default.\n", | ||||||
| ) | ||||||
| if isnothing(res) || (res isa String && isempty(strip(res))) | ||||||
| res = default | ||||||
| end | ||||||
| push!(new_vars, res) | ||||||
| printstyled(output, " ==> New variable: $res := $var\n", bold = true) | ||||||
| end | ||||||
| end | ||||||
| end | ||||||
| function try_to_reparametrize(state) | ||||||
| counter, chosen_funcs, variable_names, id_funcs = | ||||||
| state[:counter], state[:chosen_funcs], state[:variable_names], state[:id_funcs] | ||||||
| rff = RationalFunctionField(vcat(chosen_funcs.states, chosen_funcs.params)) | ||||||
| state[:id_funcs] = id_funcs[.! field_contains(rff, id_funcs, prob_threshold)] | ||||||
| if isempty(chosen_funcs.states) | ||||||
| printstyled( | ||||||
| output, | ||||||
| "\n ==> Please select at least one new state in order to reparametrize.\n", | ||||||
| bold = true, | ||||||
| ) | ||||||
| return nothing | ||||||
| end | ||||||
| try | ||||||
| new_vector_field, new_inputs, new_outputs, new_vars, implicit_relations = | ||||||
| reparametrize_with_respect_to( | ||||||
| ode, | ||||||
| chosen_funcs.states, | ||||||
| chosen_funcs.params, | ||||||
| new_variable_names = variable_names, | ||||||
| ) | ||||||
| new_ode = ODE{P}(new_vector_field, new_outputs, new_inputs) | ||||||
| return ( | ||||||
| new_ode = new_ode, | ||||||
| new_vars = new_vars, | ||||||
| implicit_relations = implicit_relations, | ||||||
| ) | ||||||
| catch e | ||||||
| printstyled( | ||||||
| output, | ||||||
| "\n ==> Selected functions is not enough to reparametrize. Please select more.\n", | ||||||
|
||||||
| "\n ==> Selected functions is not enough to reparametrize. Please select more.\n", | |
| "\n ==> Selected functions are not enough to reparametrize. Please select more.\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inconsistent indentation: line 281 uses tabs while the rest of the file uses spaces. This should be changed to spaces for consistency.