Skip to content

Commit ee967c5

Browse files
committed
add test to jax output shape
1 parent c02dfc7 commit ee967c5

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def test_network(self):
104104

105105
model_jax.run(chunkwise=True, bold=True, append_outputs=True, chunksize=20000)
106106

107+
# assert same output shape
108+
self.assertTrue(model.exc.shape == model_jax.exc.shape)
109+
107110
# jit changes the exact numerics of outputs
108111
self.assertTrue(np.allclose(model.BOLD.BOLD, model_jax.BOLD.BOLD, rtol=1e-3))
109112

0 commit comments

Comments
 (0)