Description
Hi there
I'm working with dynamical systems and have run into a small issue regarding differentiation with respect to control inputs.
From studying the examples provided, I assumed the main way to pass control inputs to a dynamical system is to pass in a `LinearInterpolation` object. I implemented this throughout my code, but now I am stuck trying to calculate the sensitivities of my system with respect to this control input.
I have not found any examples covering this case, so I will present my attempts to work around the issue.
Let's take the following toy dynamical system as an example:
```python
import diffrax as dfx
import jax.numpy as jnp
import jax

# Create a linear interpolation of a 2-channel control signal on 10 knots
ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.block([[jnp.linspace(0.0, 1.0, 10)], [jnp.linspace(0.0, 1.0, 10)]]).T  # shape (10, 2)
u_interp = dfx.LinearInterpolation(ts=ts, ys=ys)

# Toy dynamical system: dy/dt = A y + B u(t), with A = B = I
def f(t, y, args):
    """Dynamical system function"""
    u = args  # the control interpolation is passed in through args
    return jnp.array([[1, 0], [0, 1]]) @ y + jnp.array([[1, 0], [0, 1]]) @ u.evaluate(t)
```

To calculate the sensitivities w.r.t. `args`:
```python
dfdu = jax.jacfwd(f, argnums=2)(0, jnp.array([2.0, 1.0]), u_interp)
print(dfdu)
```

Then we get a `LinearInterpolation` object, which surprised me, as I thought it would just spit out an error:
```
LinearInterpolation(ts=f32[2,10], ys=f32[2,10,2])
```

But if we now try to evaluate it at a given instant (maybe it should take another kind of input?):
```python
print(dfdu.evaluate(0))
```

We get:
```
ValueError: a should be 1-dimensional
```

It is not clear to me which value `a` refers to, but I believe this means that the `LinearInterpolation` object now has a 2-dimensional f32 array as `ts`.
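To check this, here is a minimal sketch with a plain dict standing in for the `LinearInterpolation` pytree and `jnp.interp` (applied per channel) standing in for `evaluate`; both stand-ins are my own assumptions, not diffrax internals. As far as I understand, `jax.jacfwd` with respect to a pytree argument returns a pytree of the same structure, where each leaf has shape `output_shape + leaf_shape`, which would explain `ts` becoming `(2, 10)` and `ys` becoming `(2, 10, 2)`:

```python
import jax
import jax.numpy as jnp

ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.block([[jnp.linspace(0.0, 1.0, 10)], [jnp.linspace(0.0, 1.0, 10)]]).T  # shape (10, 2)

# Hypothetical stand-in for the LinearInterpolation pytree: same two leaves.
params = {"ts": ts, "ys": ys}


def interp(t, p):
    # Per-channel linear interpolation (stand-in for u.evaluate(t)).
    return jnp.stack(
        [jnp.interp(t, p["ts"], p["ys"][:, j]) for j in range(p["ys"].shape[1])]
    )


def f(t, y, p):
    # Same toy system as above, with A = B = I.
    return y + interp(t, p)


jac = jax.jacfwd(f, argnums=2)(0.0, jnp.array([2.0, 1.0]), params)
print({k: v.shape for k, v in jac.items()})  # {'ts': (2, 10), 'ys': (2, 10, 2)}
```

If this holds for diffrax too, the Jacobian is a per-leaf array of partial derivatives packed into a `LinearInterpolation` container, so its `ts` is no longer a 1-D time grid and `evaluate` cannot be used on it.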
Inspecting the `ts` and `ys` values, we get some pretty nonsensical values (to me, at least):
```python
print(dfdu.ts)
```

```
[[-1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [-1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
```

```python
print(dfdu.ys)
```

```
[[[1. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]
 [[0. 1.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
```

Is this expected behavior?
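Checking the non-zero entries by hand suggests they are ordinary partial derivatives with respect to the interpolation knots. This is a sketch assuming that, at $t = 0$, `evaluate` uses the first segment $[t_0, t_1]$; I write $t_k$ for `ts[k]`, $u_k$ for `ys[k]`, and take $A = B = I$ as in the toy system:

```math
\begin{aligned}
u(t) &= u_0 + \frac{t - t_0}{t_1 - t_0}\,(u_1 - u_0), \qquad f(t, y) = A y + B\,u(t) \\
\frac{\partial f}{\partial u_0}\bigg|_{t = t_0} &= B \left(1 - \frac{t - t_0}{t_1 - t_0}\right)\bigg|_{t = t_0} = B = I \\
\frac{\partial f}{\partial t_0}\bigg|_{t = t_0} &= B\,(u_1 - u_0)\left(\frac{t - t_0}{(t_1 - t_0)^2} - \frac{1}{t_1 - t_0}\right)\bigg|_{t = t_0} = -\frac{B\,(u_1 - u_0)}{t_1 - t_0} = -\begin{bmatrix} 1 \\ 1 \end{bmatrix}
\end{aligned}
```

Here $u_1 - u_0 = [1/9,\ 1/9]^\top$ and $t_1 - t_0 = 1/9$. Every other knot either does not enter $u(t_0)$ or enters multiplied by $t - t_0 = 0$, which matches the zeros elsewhere. Read this way, `dfdu.ys[i, k, j]` would be $\partial f_i / \partial\,\texttt{ys}[k, j]$ and `dfdu.ts[i, k]` would be $\partial f_i / \partial\,\texttt{ts}[k]$.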
Should I pursue other solutions?
Is there a more direct/standard way to calculate the sensitivity of my dynamical system?
Thank you very much for your attention and support. Any assistance is welcome.