1919from .saveat import SaveAt , SubSaveAt
2020from .solution import is_okay , is_successful , RESULTS , Solution
2121from .solver import (
22+ AbstractImplicitSolver ,
2223 AbstractItoSolver ,
2324 AbstractSolver ,
2425 AbstractStratonovichSolver ,
@@ -605,6 +606,18 @@ def diffeqsolve(
605606 pred = (t1 - t0 ) * dt0 < 0
606607 dt0 = eqxi .error_if (jnp .array (dt0 ), pred , msg )
607608
609+ # Error checking and warning for complex dtypes
610+ if any (jtu .tree_leaves (jtu .tree_map (jnp .iscomplexobj , y0 ))):
611+ if isinstance (solver , AbstractImplicitSolver ):
612+ raise ValueError (
613+ "Implicit solvers in conjunction with complex dtypes is currently not "
614+ "supported."
615+ )
616+ warnings .warn (
617+ "Complex dtype support is work in progress, please read "
618+ "https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
619+ )
620+
608621 # Backward compatibility
609622 if isinstance (
610623 solver , (EulerHeun , ItoMilstein , StratonovichMilstein )
@@ -664,8 +677,10 @@ def _get_subsaveat_ts(saveat):
664677 )
665678
666679 # Time will affect state, so need to promote the state dtype as well if necessary.
680+ # fixing issue with float64 and weak dtypes, see discussion at:
681+ # https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527
667682 def _promote (yi ):
668- _dtype = jnp .result_type (yi , * timelikes ) # noqa: F821
683+ _dtype = jnp .result_type (yi , dtype ) # noqa: F821
669684 return jnp .asarray (yi , dtype = _dtype )
670685
671686 y0 = jtu .tree_map (_promote , y0 )
@@ -759,7 +774,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
759774 save_index = 0
760775 ts = jnp .full (out_size , jnp .inf )
761776 struct = eqx .filter_eval_shape (subsaveat .fn , t0 , y0 , args )
762- ys = jtu .tree_map (lambda y : jnp .full ((out_size ,) + y .shape , jnp .inf ), struct )
777+ ys = jtu .tree_map (
778+ lambda y : jnp .full ((out_size ,) + y .shape , jnp .inf , dtype = y .dtype ), struct
779+ )
763780 return SaveState (
764781 ts = ts , ys = ys , save_index = save_index , saveat_ts_index = saveat_ts_index
765782 )
@@ -779,7 +796,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
779796 solver .step , terms , tprev , tnext , y0 , args , solver_state , made_jump
780797 )
781798 dense_ts = jnp .full (max_steps + 1 , jnp .inf )
782- _make_full = lambda x : jnp .full ((max_steps ,) + jnp .shape (x ), jnp .inf )
799+ _make_full = lambda x : jnp .full (
800+ (max_steps ,) + jnp .shape (x ), jnp .inf , dtype = x .dtype
801+ )
783802 dense_infos = jtu .tree_map (_make_full , dense_info )
784803 dense_save_index = 0
785804 else :
0 commit comments