Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions tests/test_timeevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,142 @@ def objective_function(params):
print(objective_function(tc.backend.ones(4)))


def test_ode_evol_jit_grad(highp, jaxb):
try:
import diffrax # pylint: disable=unused-import
except ImportError:
pytest.skip("diffrax not installed, skipping test")

zz_ham = tc.quantum.PauliStringSum2COO([[3, 3, 0, 0], [0, 3, 3, 0]], [1, 1])
x_ham = tc.quantum.PauliStringSum2COO([[1, 0, 0, 0], [0, 1, 0, 0]], [1, 1])

c = tc.Circuit(4)
c.x([1, 3])
psi0 = c.state()

# Example with parameterized Hamiltonian and optimization
def parametrized_hamiltonian(t, *params):
# params = [J0, J1, h0, h1] - parameters to optimize
J_t = params[0] + params[1] * tc.backend.sin(2.0 * t)
h_t = params[2] + params[3] * tc.backend.cos(1.5 * t)

return J_t * zz_ham + h_t * x_ham

def zz_correlation(state):
n = int(np.log2(state.shape[0]))
circuit = tc.Circuit(n, inputs=state)
return circuit.expectation_ps(z=[0, 1])

@tc.backend.jit
@tc.backend.value_and_grad
def kv_ode_solver_(params):
states = tc.timeevol.ode_evol_global(
parametrized_hamiltonian,
psi0,
tc.backend.convert_to_tensor([0, 10.0]),
None,
*params,
atol=1.0e-15,
rtol=1.0e-15,
solver="Kvaerno5",
ode_backend="diffrax",
)
return tc.backend.real(zz_correlation(states[-1]))

@tc.backend.jit
@tc.backend.value_and_grad
def ts_ode_solver_(params):
states = tc.timeevol.ode_evol_global(
parametrized_hamiltonian,
psi0,
tc.backend.convert_to_tensor([0, 10.0]),
None,
*params,
ode_backend="diffrax",
atol=1.0e-13,
rtol=1.0e-13,
dt0=0.005,
)
return tc.backend.real(zz_correlation(states[-1]))

@tc.backend.jit
@tc.backend.value_and_grad
def do5_ode_solver_(params):
states = tc.timeevol.ode_evol_global(
parametrized_hamiltonian,
psi0,
tc.backend.convert_to_tensor([0, 10.0]),
None,
*params,
)
return tc.backend.real(zz_correlation(states[-1]))

paras = np.random.rand(4)
s1 = kv_ode_solver_(paras)
s2 = ts_ode_solver_(paras)
s3 = do5_ode_solver_(paras)

v1, g1 = s1
v2, g2 = s2
v3, g3 = s3

assert (np.linalg.norm(v1 - v2) < 1e-8) & (np.linalg.norm(v1 - v3) < 1e-8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use np.testing.assert_allclose

assert (np.linalg.norm(g1 - g2) < 1e-8) & (np.linalg.norm(g1 - g3) < 1e-8)

######################################################################

def local_hamiltonian(t, Omega, phi):
angle = phi * t
coeff = Omega * tc.backend.cos(2.0 * t) # Amplitude modulation

# Single-qubit Rabi Hamiltonian (2x2 matrix)
hx = coeff * tc.backend.cos(angle) * tc.gates.x().tensor
hy = coeff * tc.backend.sin(angle) * tc.gates.y().tensor
return hx + hy

# Initial state: GHZ state |0000⟩ + |1111⟩
c = tc.Circuit(4)
c.h(0)
for i in range(3):
c.cnot(i, i + 1)
psi0 = c.state()

# Evolve with local Hamiltonian acting on qubit 1
@tc.backend.jit
@tc.backend.value_and_grad
def ts_ode_solver_local(paras):
states = tc.timeevol.ode_evol_local(
local_hamiltonian,
psi0,
tc.backend.convert_to_tensor([0, 10.0]),
[2], # Apply to qubit 1
None,
*paras, # Omega=1.0, phi=2.0
ode_backend="diffrax",
)
return tc.backend.real(zz_correlation(states[-1]))

@tc.backend.jit
@tc.backend.value_and_grad
def do5_ode_solver_local(paras):
states = tc.timeevol.ode_evol_local(
local_hamiltonian,
psi0,
tc.backend.convert_to_tensor([0, 10.0]),
[2], # Apply to qubit 1
None,
*paras, # Omega=1.0, phi=2.0
)
return tc.backend.real(zz_correlation(states[-1]))

paras = np.random.rand(2)
s1 = ts_ode_solver_local(paras)
s2 = do5_ode_solver_local(paras)
v1, g1 = s1
v2, g2 = s2
assert (np.linalg.norm(v1 - v2) < 1e-8) & (np.linalg.norm(g1 - g2) < 1e-8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.testing.assert_allclose



@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
def test_ed_evol(backend):
n = 4
Expand Down