
Differentiating w.r.t. a LinearInterpolation #686

@JeanCMthomasio

Hi there

I'm working with dynamical systems and have run into a small issue regarding differentiation with respect to control inputs.

From studying the provided examples, I assumed that the main way to pass control inputs to a dynamical system is through a `LinearInterpolation` object passed via `args`. I used this approach throughout my code, but I am now stuck trying to calculate the sensitivities of my system with respect to this control input.

I have not found any examples covering this case, so I will present my attempts to work around it:

Let's take the following toy dynamical system as an example:

```python
import diffrax as dfx
import jax
import jax.numpy as jnp

# Control input: linear interpolation of 2 channels over 10 time points
ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.block([[jnp.linspace(0.0, 1.0, 10)], [jnp.linspace(0.0, 1.0, 10)]]).T  # shape (10, 2)
u_interp = dfx.LinearInterpolation(ts=ts, ys=ys)

# Toy dynamical system: dy/dt = A y + B u(t), with A = B = identity
def f(t, y, args):
    """Dynamical system function"""
    u = args  # the LinearInterpolation control, passed in via args
    return jnp.array([[1, 0], [0, 1]]) @ y + jnp.array([[1, 0], [0, 1]]) @ u.evaluate(t)
```

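To calculate the sensitivities with respect to `args` (the control input):

```python
# Jacobian of f with respect to its third argument (args = u_interp), at t = 0
dfdu = jax.jacfwd(f, argnums=2)(0, jnp.array([2.0, 1.0]), u_interp)
print(dfdu)
```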

Then we get a `LinearInterpolation` object back. This surprised me, as I expected it to raise an error:

```
LinearInterpolation(ts=f32[2,10], ys=f32[2,10,2])
```
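This looks like the general way `jax.jacfwd` treats a pytree argument (diffrax's `LinearInterpolation` is an Equinox module, and therefore a pytree): the Jacobian comes back with the same structure as the input, with the output shape `(2,)` prepended to every leaf. A minimal pure-JAX sketch, using a plain dict and a hand-written interpolation on the first interval as stand-ins (`interp_first_interval` is hypothetical, not diffrax code), reproduces the same shapes:

```python
import jax
import jax.numpy as jnp

# A plain dict stands in for the LinearInterpolation module: both are pytrees
ts = jnp.linspace(0.0, 1.0, 10)
params = {"ts": ts, "ys": jnp.stack([ts, ts], axis=1)}  # ys has shape (10, 2)


def interp_first_interval(p, t=0.0):
    """Hypothetical stand-in: linear interpolation on [ts[0], ts[1]] only."""
    knots, values = p["ts"], p["ys"]
    w = (t - knots[0]) / (knots[1] - knots[0])
    return values[0] + w * (values[1] - values[0])  # output shape (2,)


# Differentiate with respect to the whole dict (argument 0)
jac = jax.jacfwd(interp_first_interval)(params)

# Same structure as the input; each leaf gets the output shape (2,) prepended
print(jac["ts"].shape)  # (2, 10)
print(jac["ys"].shape)  # (2, 10, 2)
```

So the returned "interpolation" is really a container of partial derivatives, one leaf per field, rather than a path that can be evaluated.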

But if we now try to evaluate it at a given instant (maybe it expects a different kind of input?):

```python
print(dfdu.evaluate(0))
```

We get:

```
ValueError: a should be 1-dimensional
```

It is not clear to me which value the error refers to, but I believe it means that the `LinearInterpolation` object now has a 2-dimensional f32 array as `ts`.

Inspecting `ts` and `ys` gives values that look nonsensical (to me, at least):

```
print(dfdu.ts)
[[-1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [-1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
print(dfdu.ys)
[[[1. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 1.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
```
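For reference, a hand derivation of these entries, as a sketch that assumes `evaluate` applies the standard piecewise-linear formula on the interval containing $t$ ($\tau_k$ and $v_k$ are shorthand introduced here for `ts[k]` and `ys[k]`). With both matrices equal to the identity, $f(t, y, u) = y + u(t)$, and for $t \in [\tau_0, \tau_1]$:

$$
u(t) = v_0 + \frac{t - \tau_0}{\tau_1 - \tau_0}\,(v_1 - v_0)
$$

Differentiating at $t = \tau_0$, where the weight $(t - \tau_0)/(\tau_1 - \tau_0)$ vanishes:

$$
\frac{\partial f_i}{\partial v_{0,j}} = \delta_{ij}, \qquad
\frac{\partial f_i}{\partial v_{1,j}} = 0, \qquad
\frac{\partial f_i}{\partial \tau_0} = -\frac{v_{1,i} - v_{0,i}}{\tau_1 - \tau_0}, \qquad
\frac{\partial f_i}{\partial \tau_1} = 0
$$

and no other knot enters. Since `ts` and `ys` are both built from `jnp.linspace(0.0, 1.0, 10)`, $v_{1,i} - v_{0,i} = \tau_1 - \tau_0 = 1/9$, so $\partial f_i / \partial \tau_0 = -1$, which matches the printed `dfdu.ts` and `dfdu.ys`.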

Is this expected behavior?

Should I pursue other solutions?

Is there a more direct/standard way to calculate the sensitivity of my dynamical system?

Thank you very much for your attention and support. Any help is welcome.
