Skip to content

Commit 2adfd52

Browse files
committed
Update timeevol.py
change two functions
1 parent cb8dbca commit 2adfd52

File tree

1 file changed

+238
-78
lines changed

1 file changed

+238
-78
lines changed

tensorcircuit/timeevol.py

Lines changed: 238 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -427,37 +427,40 @@ def _evol(t: Tensor) -> Tensor:
427427
ed_evol = hamiltonian_evol
428428

429429

430-
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
431-
def evol_local(
432-
c: Circuit,
433-
index: Sequence[int],
434-
h_fun: Callable[..., Tensor],
435-
t: float,
436-
*args: Any,
437-
**solver_kws: Any,
438-
) -> Circuit:
439-
"""
440-
ode evolution of time dependent Hamiltonian on circuit of given indices
441-
[only jax backend support for now]
430+
def _solve_ode(f, s, times, args, solver_kws):
431+
rtol = solver_kws.get("rtol", 1e-12)
432+
atol = solver_kws.get("atol", 1e-12)
442433

443-
:param c: _description_
444-
:type c: Circuit
445-
:param index: qubit sites to evolve
446-
:type index: Sequence[int]
447-
:param h_fun: h_fun should return a dense Hamiltonian matrix
448-
with input arguments time and *args
449-
:type h_fun: Callable[..., Tensor]
450-
:param t: evolution time
451-
:type t: float
452-
:return: _description_
453-
:rtype: Circuit
454-
"""
455-
s = c.state()
456-
n = int(np.log2(s.shape[-1]) + 1e-7)
457-
if isinstance(t, float):
458-
t = backend.stack([0.0, t])
459-
s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws)
460-
return type(c)(n, inputs=s1[-1])
434+
ts = backend.convert_to_tensor(times)
435+
ts = backend.cast(ts, dtype=rdtypestr)
436+
437+
if (solver := solver_kws.get("solver", "Dopri5")) == "Dopri5":
438+
from jax.experimental.ode import odeint
439+
s1 = odeint(f, s, ts, rtol=rtol, atol=atol, *args)
440+
441+
else:
442+
import diffrax, warnings
443+
# Ignore complex warning
444+
warnings.simplefilter("ignore", category=UserWarning, append=True)
445+
dt0 = solver_kws.get("dt0", 0.01)
446+
all_solvers = {"Tsit5": diffrax.Tsit5, "Dopri8": diffrax.Dopri8, "Kvaerno5": diffrax.Kvaerno5}
447+
448+
# ODE
449+
term = diffrax.ODETerm(lambda t, y, args: f(y, t, *args))
450+
451+
# solve ODE
452+
s1 = diffrax.diffeqsolve(
453+
terms = term,
454+
solver = all_solvers[solver](),
455+
t0 = times[0],
456+
t1 = times[-1],
457+
dt0 = dt0,
458+
y0 = s,
459+
saveat = diffrax.SaveAt(ts=times),
460+
args = args,
461+
stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol),
462+
).ys
463+
return s1
461464

462465

463466
def ode_evol_local(
@@ -466,14 +469,17 @@ def ode_evol_local(
466469
times: Tensor,
467470
index: Sequence[int],
468471
callback: Optional[Callable[..., Tensor]] = None,
469-
*args: Any,
470-
**solver_kws: Any,
472+
args: tuple | list = tuple(),
473+
solver_kws: dict = dict(),
471474
) -> Tensor:
472475
"""
473476
ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.
474477
475478
This function solves the time-dependent Schrodinger equation using numerical ODE integration.
476-
The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
479+
The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
480+
481+
If the solver is 'Dopri5' (default), calls `jax.experimental.ode.odeint`;
482+
otherwise calls `diffrax`.
477483
478484
Note: This function currently only supports the JAX backend.
479485
@@ -489,14 +495,17 @@ def ode_evol_local(
489495
:param callback: Optional function to apply to the state at each time step.
490496
:type callback: Optional[Callable[..., Tensor]]
491497
:param args: Additional arguments to pass to the Hamiltonian function.
492-
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
498+
:type args: tuple | list
499+
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
500+
The solver type can be specified: {'Dopri5' (default), 'Tsit5', 'Dopri8', 'Kvaerno5'}.
501+
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would like the numerical approximation to your equation.
502+
dt0 (default: 0.01) specifies the initial step size.
503+
:type solver_kws: dict
493504
:return: Evolved quantum states at the specified time points. If callback is provided,
494505
returns the callback results; otherwise returns the state vectors.
495506
:rtype: Tensor
496507
"""
497-
from jax.experimental.ode import odeint
498508

499-
s = initial_state
500509
n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7)
501510
l = len(index)
502511

@@ -517,47 +526,20 @@ def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
517526
y = contractor([y, h], output_edge_order=edges)
518527
return backend.reshape(y.tensor, [-1])
519528

520-
ts = backend.convert_to_tensor(times)
521-
ts = backend.cast(ts, dtype=rdtypestr)
522-
s1 = odeint(f, s, ts, *args, **solver_kws)
523-
if not callback:
529+
s1 = _solve_ode(f, initial_state, times, args, solver_kws)
530+
531+
if callback is None:
524532
return s1
525-
return backend.stack([callback(s1[i]) for i in range(len(s1))])
526-
527-
528-
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
529-
def evol_global(
530-
c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
531-
) -> Circuit:
532-
"""
533-
ode evolution of time dependent Hamiltonian on circuit of all qubits
534-
[only jax backend support for now]
535-
536-
:param c: _description_
537-
:type c: Circuit
538-
:param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
539-
with input arguments time and *args
540-
:type h_fun: Callable[..., Tensor]
541-
:param t: _description_
542-
:type t: float
543-
:return: _description_
544-
:rtype: Circuit
545-
"""
546-
s = c.state()
547-
n = c._nqubits
548-
if isinstance(t, float):
549-
t = backend.stack([0.0, t])
550-
s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws)
551-
return type(c)(n, inputs=s1[-1])
533+
return backend.stack([callback(a_state) for a_state in s1])
552534

553535

554536
def ode_evol_global(
555537
hamiltonian: Callable[..., Tensor],
556538
initial_state: Tensor,
557539
times: Tensor,
558540
callback: Optional[Callable[..., Tensor]] = None,
559-
*args: Any,
560-
**solver_kws: Any,
541+
args: tuple | list = tuple(),
542+
solver_kws: dict = dict(),
561543
) -> Tensor:
562544
"""
563545
ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.
@@ -566,6 +548,9 @@ def ode_evol_global(
566548
The Hamiltonian is applied to the full system and should be provided in sparse matrix format
567549
for efficiency.
568550
551+
If the solver is 'Dopri5' (default), calls `jax.experimental.ode.odeint`;
552+
otherwise calls `diffrax`.
553+
569554
Note: This function currently only supports the JAX backend.
570555
571556
:param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
@@ -578,25 +563,200 @@ def ode_evol_global(
578563
:param callback: Optional function to apply to the state at each time step.
579564
:type callback: Optional[Callable[..., Tensor]]
580565
:param args: Additional arguments to pass to the Hamiltonian function.
581-
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
566+
:type args: tuple | list
567+
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
568+
The solver type can be specified: {'Dopri5' (default), 'Tsit5', 'Dopri8', 'Kvaerno5'}.
569+
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would like the numerical approximation to your equation.
570+
dt0 (default: 0.01) specifies the initial step size.
571+
:type solver_kws: dict
582572
:return: Evolved quantum states at the specified time points. If callback is provided,
583573
returns the callback results; otherwise returns the state vectors.
584574
:rtype: Tensor
585575
"""
586-
from jax.experimental.ode import odeint
587-
588-
s = initial_state
589-
ts = backend.convert_to_tensor(times)
590-
ts = backend.cast(ts, dtype=rdtypestr)
591576

592577
def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
593578
h = -1.0j * hamiltonian(t, *args)
594579
return backend.sparse_dense_matmul(h, y)
580+
581+
s1 = _solve_ode(f, initial_state, times, args, solver_kws)
595582

596-
s1 = odeint(f, s, ts, *args, **solver_kws)
597-
if not callback:
583+
if callback is None:
598584
return s1
599-
return backend.stack([callback(s1[i]) for i in range(len(s1))])
585+
return backend.stack([callback(a_state) for a_state in s1])
586+
587+
588+
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
589+
def evol_local(
590+
c: Circuit,
591+
index: Sequence[int],
592+
h_fun: Callable[..., Tensor],
593+
t: float,
594+
*args: Any,
595+
**solver_kws: Any,
596+
) -> Circuit:
597+
"""
598+
ode evolution of time dependent Hamiltonian on circuit of given indices
599+
[only jax backend support for now]
600+
601+
:param c: _description_
602+
:type c: Circuit
603+
:param index: qubit sites to evolve
604+
:type index: Sequence[int]
605+
:param h_fun: h_fun should return a dense Hamiltonian matrix
606+
with input arguments time and *args
607+
:type h_fun: Callable[..., Tensor]
608+
:param t: evolution time
609+
:type t: float
610+
:return: _description_
611+
:rtype: Circuit
612+
"""
613+
s = c.state()
614+
n = int(np.log2(s.shape[-1]) + 1e-7)
615+
if isinstance(t, float):
616+
t = backend.stack([0.0, t])
617+
s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws)
618+
return type(c)(n, inputs=s1[-1])
619+
620+
621+
# def ode_evol_local(
622+
# hamiltonian: Callable[..., Tensor],
623+
# initial_state: Tensor,
624+
# times: Tensor,
625+
# index: Sequence[int],
626+
# callback: Optional[Callable[..., Tensor]] = None,
627+
# *args: Any,
628+
# **solver_kws: Any,
629+
# ) -> Tensor:
630+
# """
631+
# ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.
632+
633+
# This function solves the time-dependent Schrodinger equation using numerical ODE integration.
634+
# The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
635+
636+
# Note: This function currently only supports the JAX backend.
637+
638+
# :param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified
639+
# subsystem size. The function signature should be hamiltonian(time, *args) -> Tensor.
640+
# :type hamiltonian: Callable[..., Tensor]
641+
# :param initial_state: The initial quantum state vector of the full system.
642+
# :type initial_state: Tensor
643+
# :param times: Time points for which to compute the evolution. Should be a 1D array of times.
644+
# :type times: Tensor
645+
# :param index: Indices of qubits where the Hamiltonian is applied.
646+
# :type index: Sequence[int]
647+
# :param callback: Optional function to apply to the state at each time step.
648+
# :type callback: Optional[Callable[..., Tensor]]
649+
# :param args: Additional arguments to pass to the Hamiltonian function.
650+
# :param solver_kws: Additional keyword arguments to pass to the ODE solver.
651+
# :return: Evolved quantum states at the specified time points. If callback is provided,
652+
# returns the callback results; otherwise returns the state vectors.
653+
# :rtype: Tensor
654+
# """
655+
# from jax.experimental.ode import odeint
656+
657+
# s = initial_state
658+
# n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7)
659+
# l = len(index)
660+
661+
# def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
662+
# y = backend.reshape2(y)
663+
# y = Gate(y)
664+
# h = -1.0j * hamiltonian(t, *args)
665+
# h = backend.reshape2(h)
666+
# h = Gate(h)
667+
# edges = []
668+
# for i in range(n):
669+
# if i not in index:
670+
# edges.append(y[i])
671+
# else:
672+
# j = index.index(i)
673+
# edges.append(h[j])
674+
# h[j + l] ^ y[i]
675+
# y = contractor([y, h], output_edge_order=edges)
676+
# return backend.reshape(y.tensor, [-1])
677+
678+
# ts = backend.convert_to_tensor(times)
679+
# ts = backend.cast(ts, dtype=rdtypestr)
680+
# s1 = odeint(f, s, ts, *args, **solver_kws)
681+
# if not callback:
682+
# return s1
683+
# return backend.stack([callback(s1[i]) for i in range(len(s1))])
684+
685+
686+
@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
687+
def evol_global(
688+
c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
689+
) -> Circuit:
690+
"""
691+
ode evolution of time dependent Hamiltonian on circuit of all qubits
692+
[only jax backend support for now]
693+
694+
:param c: _description_
695+
:type c: Circuit
696+
:param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
697+
with input arguments time and *args
698+
:type h_fun: Callable[..., Tensor]
699+
:param t: _description_
700+
:type t: float
701+
:return: _description_
702+
:rtype: Circuit
703+
"""
704+
s = c.state()
705+
n = c._nqubits
706+
if isinstance(t, float):
707+
t = backend.stack([0.0, t])
708+
s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws)
709+
return type(c)(n, inputs=s1[-1])
710+
711+
712+
# def ode_evol_global(
713+
# hamiltonian: Callable[..., Tensor],
714+
# initial_state: Tensor,
715+
# times: Tensor,
716+
# callback: Optional[Callable[..., Tensor]] = None,
717+
# *args: Any,
718+
# **solver_kws: Any,
719+
# ) -> Tensor:
720+
# """
721+
# ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.
722+
723+
# This function solves the time-dependent Schrodinger equation using numerical ODE integration.
724+
# The Hamiltonian is applied to the full system and should be provided in sparse matrix format
725+
# for efficiency.
726+
727+
# Note: This function currently only supports the JAX backend.
728+
729+
# :param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
730+
# The function signature should be hamiltonian(time, *args) -> Tensor.
731+
# :type hamiltonian: Callable[..., Tensor]
732+
# :param initial_state: The initial quantum state vector.
733+
# :type initial_state: Tensor
734+
# :param times: Time points for which to compute the evolution. Should be a 1D array of times.
735+
# :type times: Tensor
736+
# :param callback: Optional function to apply to the state at each time step.
737+
# :type callback: Optional[Callable[..., Tensor]]
738+
# :param args: Additional arguments to pass to the Hamiltonian function.
739+
# :param solver_kws: Additional keyword arguments to pass to the ODE solver.
740+
# :return: Evolved quantum states at the specified time points. If callback is provided,
741+
# returns the callback results; otherwise returns the state vectors.
742+
# :rtype: Tensor
743+
# """
744+
# from jax.experimental.ode import odeint
745+
746+
# s = initial_state
747+
# ts = backend.convert_to_tensor(times)
748+
# ts = backend.cast(ts, dtype=rdtypestr)
749+
750+
# def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
751+
# h = -1.0j * hamiltonian(t, *args)
752+
# return backend.sparse_dense_matmul(h, y)
753+
754+
# s1 = odeint(f, s, ts, *args, **solver_kws)
755+
# if not callback:
756+
# return s1
757+
# return backend.stack([callback(s1[i]) for i in range(len(s1))])
758+
759+
600760

601761

602762
def chebyshev_evol(

0 commit comments

Comments
 (0)