Skip to content

Commit edbf1d9

Browse files
author
abhinavd
committed
Fix JAX eigh
1 parent 72b70ed commit edbf1d9

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tensorcircuit/backends/jax_ops.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,16 @@ def _QrGradSquareAndDeepMatrices(q: Array, r: Array, dq: Array, dr: Array) -> Ar
151151

152152
@jax.custom_vjp
153153
def adaware_eigh(A: Array) -> Array:
154-
return jnp.linalg.eigh(A)
154+
result = jnp.linalg.eigh(A)
155+
e = result.eigenvalues
156+
v = result.eigenvectors
157+
return e, v
155158

156159

157160
def jaxeigh_fwd(A: Array) -> Array:
158-
e, v = jnp.linalg.eigh(A)
161+
result = jnp.linalg.eigh(A)
162+
e = result.eigenvalues
163+
v = result.eigenvectors
159164
return (e, v), (A, e, v)
160165

161166

0 commit comments

Comments
 (0)