@@ -36,17 +36,18 @@ def derivative(self, t):
3636 term = diffrax .ControlTerm (vector_field , control )
3737 args = getkey ()
3838 dx = term .contr (0 , 1 )
39- vf = term .vf (0 , None , args )
40- vf_prod = term .vf_prod (0 , None , args , dx )
39+ y = jnp .array ([1.0 , 2.0 , 3.0 ])
40+ vf = term .vf (0 , y , args )
41+ vf_prod = term .vf_prod (0 , y , args , dx )
4142 assert dx .shape == (2 ,)
4243 assert vf .shape == (3 , 2 )
4344 assert vf_prod .shape == (3 ,)
4445 assert shaped_allclose (vf_prod , term .prod (vf , dx ))
4546
4647 term = term .to_ode ()
4748 dt = term .contr (0 , 1 )
48- vf = term .vf (0 , None , args )
49- vf_prod = term .vf_prod (0 , None , args , dt )
49+ vf = term .vf (0 , y , args )
50+ vf_prod = term .vf_prod (0 , y , args , dt )
5051 assert vf .shape == (3 ,)
5152 assert vf_prod .shape == (3 ,)
5253 assert shaped_allclose (vf_prod , term .prod (vf , dt ))
@@ -70,17 +71,18 @@ def derivative(self, t):
7071 term = diffrax .WeaklyDiagonalControlTerm (vector_field , control )
7172 args = getkey ()
7273 dx = term .contr (0 , 1 )
73- vf = term .vf (0 , None , args )
74- vf_prod = term .vf_prod (0 , None , args , dx )
74+ y = jnp .array ([1.0 , 2.0 , 3.0 ])
75+ vf = term .vf (0 , y , args )
76+ vf_prod = term .vf_prod (0 , y , args , dx )
7577 assert dx .shape == (3 ,)
7678 assert vf .shape == (3 ,)
7779 assert vf_prod .shape == (3 ,)
7880 assert shaped_allclose (vf_prod , term .prod (vf , dx ))
7981
8082 term = term .to_ode ()
8183 dt = term .contr (0 , 1 )
82- vf = term .vf (0 , None , args )
83- vf_prod = term .vf_prod (0 , None , args , dt )
84+ vf = term .vf (0 , y , args )
85+ vf_prod = term .vf_prod (0 , y , args , dt )
8486 assert vf .shape == (3 ,)
8587 assert vf_prod .shape == (3 ,)
8688 assert shaped_allclose (vf_prod , term .prod (vf , dt ))
0 commit comments