From edbf1d91283b535fa1f30d0454b90ae67bc5e040 Mon Sep 17 00:00:00 2001 From: abhinavd Date: Thu, 16 Jan 2025 15:47:49 -0800 Subject: [PATCH 1/2] Fix JAX eigh --- tensorcircuit/backends/jax_ops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorcircuit/backends/jax_ops.py b/tensorcircuit/backends/jax_ops.py index 2eaf2f1b..bc2701a4 100644 --- a/tensorcircuit/backends/jax_ops.py +++ b/tensorcircuit/backends/jax_ops.py @@ -151,11 +151,16 @@ def _QrGradSquareAndDeepMatrices(q: Array, r: Array, dq: Array, dr: Array) -> Ar @jax.custom_vjp def adaware_eigh(A: Array) -> Array: - return jnp.linalg.eigh(A) + result = jnp.linalg.eigh(A) + e = result.eigenvalues + v = result.eigenvectors + return e, v def jaxeigh_fwd(A: Array) -> Array: - e, v = jnp.linalg.eigh(A) + result = jnp.linalg.eigh(A) + e = result.eigenvalues + v = result.eigenvectors return (e, v), (A, e, v) From 7df72d0cd810ba52c60c1dfae2947d778b2e3c0c Mon Sep 17 00:00:00 2001 From: abhinavd Date: Fri, 17 Jan 2025 22:21:04 -0800 Subject: [PATCH 2/2] Making compatible with older JAX versions --- tensorcircuit/backends/jax_ops.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorcircuit/backends/jax_ops.py b/tensorcircuit/backends/jax_ops.py index bc2701a4..b22d0a70 100644 --- a/tensorcircuit/backends/jax_ops.py +++ b/tensorcircuit/backends/jax_ops.py @@ -151,16 +151,12 @@ def _QrGradSquareAndDeepMatrices(q: Array, r: Array, dq: Array, dr: Array) -> Ar @jax.custom_vjp def adaware_eigh(A: Array) -> Array: - result = jnp.linalg.eigh(A) - e = result.eigenvalues - v = result.eigenvectors + e, v = jnp.linalg.eigh(A) return e, v def jaxeigh_fwd(A: Array) -> Array: - result = jnp.linalg.eigh(A) - e = result.eigenvalues - v = result.eigenvectors + e, v = jnp.linalg.eigh(A) return (e, v), (A, e, v)