@@ -427,37 +427,40 @@ def _evol(t: Tensor) -> Tensor:
427427ed_evol = hamiltonian_evol
428428
429429
430- @partial (arg_alias , alias_dict = {"h_fun" : ["hamiltonian" ], "t" : ["times" ]})
431- def evol_local (
432- c : Circuit ,
433- index : Sequence [int ],
434- h_fun : Callable [..., Tensor ],
435- t : float ,
436- * args : Any ,
437- ** solver_kws : Any ,
438- ) -> Circuit :
439- """
440- ode evolution of time dependent Hamiltonian on circuit of given indices
441- [only jax backend support for now]
430+ def _solve_ode (f , s , times , args , solver_kws ):
431+ rtol = solver_kws .get ("rtol" , 1e-12 )
432+ atol = solver_kws .get ("atol" , 1e-12 )
442433
443- :param c: _description_
444- :type c: Circuit
445- :param index: qubit sites to evolve
446- :type index: Sequence[int]
447- :param h_fun: h_fun should return a dense Hamiltonian matrix
448- with input arguments time and *args
449- :type h_fun: Callable[..., Tensor]
450- :param t: evolution time
451- :type t: float
452- :return: _description_
453- :rtype: Circuit
454- """
455- s = c .state ()
456- n = int (np .log2 (s .shape [- 1 ]) + 1e-7 )
457- if isinstance (t , float ):
458- t = backend .stack ([0.0 , t ])
459- s1 = ode_evol_local (h_fun , s , t , index , None , * args , ** solver_kws )
460- return type (c )(n , inputs = s1 [- 1 ])
434+ ts = backend .convert_to_tensor (times )
435+ ts = backend .cast (ts , dtype = rdtypestr )
436+
437+ if (solver := solver_kws .get ("solver" , "Dopri5" )) == "Dopri5" :
438+ from jax .experimental .ode import odeint
439+ s1 = odeint (f , s , ts , rtol = rtol , atol = atol , * args )
440+
441+ else :
442+ import diffrax , warnings
443+ # Ignore complex warning
444+ warnings .simplefilter ("ignore" , category = UserWarning , append = True )
445+ dt0 = solver_kws .get ("dt0" , 0.01 )
446+ all_solvers = {"Tsit5" : diffrax .Tsit5 , "Dopri8" : diffrax .Dopri8 , "Kvaerno5" : diffrax .Kvaerno5 }
447+
448+ # ODE
449+ term = diffrax .ODETerm (lambda t , y , args : f (y , t , * args ))
450+
451+ # solve ODE
452+ s1 = diffrax .diffeqsolve (
453+ terms = term ,
454+ solver = all_solvers [solver ](),
455+ t0 = times [0 ],
456+ t1 = times [- 1 ],
457+ dt0 = dt0 ,
458+ y0 = s ,
459+ saveat = diffrax .SaveAt (ts = times ),
460+ args = args ,
461+ stepsize_controller = diffrax .PIDController (rtol = rtol , atol = atol ),
462+ ).ys
463+ return s1
461464
462465
463466def ode_evol_local (
@@ -466,14 +469,17 @@ def ode_evol_local(
466469 times : Tensor ,
467470 index : Sequence [int ],
468471 callback : Optional [Callable [..., Tensor ]] = None ,
469- * args : Any ,
470- ** solver_kws : Any ,
472+ args : tuple | list = tuple () ,
473+ solver_kws : dict = dict () ,
471474) -> Tensor :
472475 """
473476 ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.
474477
475478 This function solves the time-dependent Schrodinger equation using numerical ODE integration.
476- The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
479+ The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
480+
481+ If the solver is 'Dopri5' (default), calls `jax.experimental.ode.odeint`;
482+ otherwise calls `diffrax`.
477483
478484 Note: This function currently only supports the JAX backend.
479485
@@ -489,14 +495,17 @@ def ode_evol_local(
489495 :param callback: Optional function to apply to the state at each time step.
490496 :type callback: Optional[Callable[..., Tensor]]
491497 :param args: Additional arguments to pass to the Hamiltonian function.
492- :param solver_kws: Additional keyword arguments to pass to the ODE solver.
498+ :type args: tuple | list
499+ :param solver_kws: Additional keyword arguments to pass to the ODE solver.
500+ The solver type can be specified: {'Dopri5' (default), 'Tsit5', 'Dopri8', 'Kvaerno5'}.
501+ rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would like the numerical approximation to your equation.
502+ dt0 (default: 0.01) specifies the initial step size.
503+ :type solver_kws: dict
493504 :return: Evolved quantum states at the specified time points. If callback is provided,
494505 returns the callback results; otherwise returns the state vectors.
495506 :rtype: Tensor
496507 """
497- from jax .experimental .ode import odeint
498508
499- s = initial_state
500509 n = int (np .log2 (backend .shape_tuple (initial_state )[- 1 ]) + 1e-7 )
501510 l = len (index )
502511
@@ -517,47 +526,20 @@ def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
517526 y = contractor ([y , h ], output_edge_order = edges )
518527 return backend .reshape (y .tensor , [- 1 ])
519528
520- ts = backend .convert_to_tensor (times )
521- ts = backend .cast (ts , dtype = rdtypestr )
522- s1 = odeint (f , s , ts , * args , ** solver_kws )
523- if not callback :
529+ s1 = _solve_ode (f , initial_state , times , args , solver_kws )
530+
531+ if callback is None :
524532 return s1
525- return backend .stack ([callback (s1 [i ]) for i in range (len (s1 ))])
526-
527-
528- @partial (arg_alias , alias_dict = {"h_fun" : ["hamiltonian" ], "t" : ["times" ]})
529- def evol_global (
530- c : Circuit , h_fun : Callable [..., Tensor ], t : float , * args : Any , ** solver_kws : Any
531- ) -> Circuit :
532- """
533- ode evolution of time dependent Hamiltonian on circuit of all qubits
534- [only jax backend support for now]
535-
536- :param c: _description_
537- :type c: Circuit
538- :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
539- with input arguments time and *args
540- :type h_fun: Callable[..., Tensor]
541- :param t: _description_
542- :type t: float
543- :return: _description_
544- :rtype: Circuit
545- """
546- s = c .state ()
547- n = c ._nqubits
548- if isinstance (t , float ):
549- t = backend .stack ([0.0 , t ])
550- s1 = ode_evol_global (h_fun , s , t , None , * args , ** solver_kws )
551- return type (c )(n , inputs = s1 [- 1 ])
533+ return backend .stack ([callback (a_state ) for a_state in s1 ])
552534
553535
554536def ode_evol_global (
555537 hamiltonian : Callable [..., Tensor ],
556538 initial_state : Tensor ,
557539 times : Tensor ,
558540 callback : Optional [Callable [..., Tensor ]] = None ,
559- * args : Any ,
560- ** solver_kws : Any ,
541+ args : tuple | list = tuple () ,
542+ solver_kws : dict = dict () ,
561543) -> Tensor :
562544 """
563545 ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.
@@ -566,6 +548,9 @@ def ode_evol_global(
566548 The Hamiltonian is applied to the full system and should be provided in sparse matrix format
567549 for efficiency.
568550
551+ If the solver is 'Dopri5' (default), calls `jax.experimental.ode.odeint`;
552+ otherwise calls `diffrax`.
553+
569554 Note: This function currently only supports the JAX backend.
570555
571556 :param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
@@ -578,25 +563,200 @@ def ode_evol_global(
578563 :param callback: Optional function to apply to the state at each time step.
579564 :type callback: Optional[Callable[..., Tensor]]
580565 :param args: Additional arguments to pass to the Hamiltonian function.
581- :param solver_kws: Additional keyword arguments to pass to the ODE solver.
566+ :type args: tuple | list
567+ :param solver_kws: Additional keyword arguments to pass to the ODE solver.
568+ The solver type can be specified: {'Dopri5' (default), 'Tsit5', 'Dopri8', 'Kvaerno5'}.
569+ rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would like the numerical approximation to your equation.
570+ dt0 (default: 0.01) specifies the initial step size.
571+ :type solver_kws: dict
582572 :return: Evolved quantum states at the specified time points. If callback is provided,
583573 returns the callback results; otherwise returns the state vectors.
584574 :rtype: Tensor
585575 """
586- from jax .experimental .ode import odeint
587-
588- s = initial_state
589- ts = backend .convert_to_tensor (times )
590- ts = backend .cast (ts , dtype = rdtypestr )
591576
592577 def f (y : Tensor , t : Tensor , * args : Any ) -> Tensor :
593578 h = - 1.0j * hamiltonian (t , * args )
594579 return backend .sparse_dense_matmul (h , y )
580+
581+ s1 = _solve_ode (f , initial_state , times , args , solver_kws )
595582
596- s1 = odeint (f , s , ts , * args , ** solver_kws )
597- if not callback :
583+ if callback is None :
598584 return s1
599- return backend .stack ([callback (s1 [i ]) for i in range (len (s1 ))])
585+ return backend .stack ([callback (a_state ) for a_state in s1 ])
586+
587+
588+ @partial (arg_alias , alias_dict = {"h_fun" : ["hamiltonian" ], "t" : ["times" ]})
589+ def evol_local (
590+ c : Circuit ,
591+ index : Sequence [int ],
592+ h_fun : Callable [..., Tensor ],
593+ t : float ,
594+ * args : Any ,
595+ ** solver_kws : Any ,
596+ ) -> Circuit :
597+ """
598+ ode evolution of time dependent Hamiltonian on circuit of given indices
599+ [only jax backend support for now]
600+
601+ :param c: _description_
602+ :type c: Circuit
603+ :param index: qubit sites to evolve
604+ :type index: Sequence[int]
605+ :param h_fun: h_fun should return a dense Hamiltonian matrix
606+ with input arguments time and *args
607+ :type h_fun: Callable[..., Tensor]
608+ :param t: evolution time
609+ :type t: float
610+ :return: _description_
611+ :rtype: Circuit
612+ """
613+ s = c .state ()
614+ n = int (np .log2 (s .shape [- 1 ]) + 1e-7 )
615+ if isinstance (t , float ):
616+ t = backend .stack ([0.0 , t ])
617+ s1 = ode_evol_local (h_fun , s , t , index , None , * args , ** solver_kws )
618+ return type (c )(n , inputs = s1 [- 1 ])
619+
620+
621+ # def ode_evol_local(
622+ # hamiltonian: Callable[..., Tensor],
623+ # initial_state: Tensor,
624+ # times: Tensor,
625+ # index: Sequence[int],
626+ # callback: Optional[Callable[..., Tensor]] = None,
627+ # *args: Any,
628+ # **solver_kws: Any,
629+ # ) -> Tensor:
630+ # """
631+ # ODE-based time evolution for a time-dependent Hamiltonian acting on a subsystem of qubits.
632+
633+ # This function solves the time-dependent Schrodinger equation using numerical ODE integration.
634+ # The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.
635+
636+ # Note: This function currently only supports the JAX backend.
637+
638+ # :param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified
639+ # subsystem size. The function signature should be hamiltonian(time, *args) -> Tensor.
640+ # :type hamiltonian: Callable[..., Tensor]
641+ # :param initial_state: The initial quantum state vector of the full system.
642+ # :type initial_state: Tensor
643+ # :param times: Time points for which to compute the evolution. Should be a 1D array of times.
644+ # :type times: Tensor
645+ # :param index: Indices of qubits where the Hamiltonian is applied.
646+ # :type index: Sequence[int]
647+ # :param callback: Optional function to apply to the state at each time step.
648+ # :type callback: Optional[Callable[..., Tensor]]
649+ # :param args: Additional arguments to pass to the Hamiltonian function.
650+ # :param solver_kws: Additional keyword arguments to pass to the ODE solver.
651+ # :return: Evolved quantum states at the specified time points. If callback is provided,
652+ # returns the callback results; otherwise returns the state vectors.
653+ # :rtype: Tensor
654+ # """
655+ # from jax.experimental.ode import odeint
656+
657+ # s = initial_state
658+ # n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7)
659+ # l = len(index)
660+
661+ # def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
662+ # y = backend.reshape2(y)
663+ # y = Gate(y)
664+ # h = -1.0j * hamiltonian(t, *args)
665+ # h = backend.reshape2(h)
666+ # h = Gate(h)
667+ # edges = []
668+ # for i in range(n):
669+ # if i not in index:
670+ # edges.append(y[i])
671+ # else:
672+ # j = index.index(i)
673+ # edges.append(h[j])
674+ # h[j + l] ^ y[i]
675+ # y = contractor([y, h], output_edge_order=edges)
676+ # return backend.reshape(y.tensor, [-1])
677+
678+ # ts = backend.convert_to_tensor(times)
679+ # ts = backend.cast(ts, dtype=rdtypestr)
680+ # s1 = odeint(f, s, ts, *args, **solver_kws)
681+ # if not callback:
682+ # return s1
683+ # return backend.stack([callback(s1[i]) for i in range(len(s1))])
684+
685+
686+ @partial (arg_alias , alias_dict = {"h_fun" : ["hamiltonian" ], "t" : ["times" ]})
687+ def evol_global (
688+ c : Circuit , h_fun : Callable [..., Tensor ], t : float , * args : Any , ** solver_kws : Any
689+ ) -> Circuit :
690+ """
691+ ode evolution of time dependent Hamiltonian on circuit of all qubits
692+ [only jax backend support for now]
693+
694+ :param c: _description_
695+ :type c: Circuit
696+ :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
697+ with input arguments time and *args
698+ :type h_fun: Callable[..., Tensor]
699+ :param t: _description_
700+ :type t: float
701+ :return: _description_
702+ :rtype: Circuit
703+ """
704+ s = c .state ()
705+ n = c ._nqubits
706+ if isinstance (t , float ):
707+ t = backend .stack ([0.0 , t ])
708+ s1 = ode_evol_global (h_fun , s , t , None , * args , ** solver_kws )
709+ return type (c )(n , inputs = s1 [- 1 ])
710+
711+
712+ # def ode_evol_global(
713+ # hamiltonian: Callable[..., Tensor],
714+ # initial_state: Tensor,
715+ # times: Tensor,
716+ # callback: Optional[Callable[..., Tensor]] = None,
717+ # *args: Any,
718+ # **solver_kws: Any,
719+ # ) -> Tensor:
720+ # """
721+ # ODE-based time evolution for a time-dependent Hamiltonian acting on the entire system.
722+
723+ # This function solves the time-dependent Schrodinger equation using numerical ODE integration.
724+ # The Hamiltonian is applied to the full system and should be provided in sparse matrix format
725+ # for efficiency.
726+
727+ # Note: This function currently only supports the JAX backend.
728+
729+ # :param hamiltonian: A function that returns a sparse Hamiltonian matrix for the full system.
730+ # The function signature should be hamiltonian(time, *args) -> Tensor.
731+ # :type hamiltonian: Callable[..., Tensor]
732+ # :param initial_state: The initial quantum state vector.
733+ # :type initial_state: Tensor
734+ # :param times: Time points for which to compute the evolution. Should be a 1D array of times.
735+ # :type times: Tensor
736+ # :param callback: Optional function to apply to the state at each time step.
737+ # :type callback: Optional[Callable[..., Tensor]]
738+ # :param args: Additional arguments to pass to the Hamiltonian function.
739+ # :param solver_kws: Additional keyword arguments to pass to the ODE solver.
740+ # :return: Evolved quantum states at the specified time points. If callback is provided,
741+ # returns the callback results; otherwise returns the state vectors.
742+ # :rtype: Tensor
743+ # """
744+ # from jax.experimental.ode import odeint
745+
746+ # s = initial_state
747+ # ts = backend.convert_to_tensor(times)
748+ # ts = backend.cast(ts, dtype=rdtypestr)
749+
750+ # def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
751+ # h = -1.0j * hamiltonian(t, *args)
752+ # return backend.sparse_dense_matmul(h, y)
753+
754+ # s1 = odeint(f, s, ts, *args, **solver_kws)
755+ # if not callback:
756+ # return s1
757+ # return backend.stack([callback(s1[i]) for i in range(len(s1))])
758+
759+
600760
601761
602762def chebyshev_evol (
0 commit comments