@@ -767,14 +767,11 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
767767 """
768768 # Ensure dtype consistency across backends (especially torch) by explicitly
769769 # casting size and lattice_vectors to the same floating dtype used internally.
770- # Strategy: prefer existing lattice_vectors dtype, fallback to float32 for efficiency,
771- # then float64 for precision. This avoids dtype mismatches in vectorized ops.
772- target_dt = None
773- try :
774- # prefer existing lattice_vectors dtype if possible
775- target_dt = backend .dtype (self .lattice_vectors ) # type: ignore
776- except (AttributeError , TypeError ): # pragma: no cover - defensive
777- target_dt = "float32"
770+ # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
771+ # fall back to float32 to avoid mixed-precision issues in vectorized ops.
772+ # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
773+ # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
774+ target_dt = str (backend .dtype (self .lattice_vectors )) # type: ignore
778775 if target_dt not in ("float32" , "float64" ):
779776 # fallback for unusual dtypes
780777 target_dt = "float32"
@@ -1626,21 +1623,20 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
16261623 # Unzip the list of tuples into separate lists of identifiers and coordinates
16271624 _ , identifiers , _ = zip (* all_sites_info )
16281625
1629- # Detach-and-copy coordinates in a backend-agnostic way.
1630- # Rationale (answering reviewer question "why not keep backend-dependent form?"):
1631- # - Passing a tuple/list of backend tensors (e.g., per-row slices) into
1632- # convert_to_tensor can fail on some backends (torch.tensor(list_of_tensors) ValueError),
1633- # whereas a plain nested Python list is accepted everywhere.
1634- # - We want CustomizeLattice to be decoupled from the original lattice's computation
1635- # graph and device state (CPU/GPU), so we materialize numeric values here.
1636- # - This is a one-shot conversion of the full coordinate array, simpler and faster
1637- # than iterating per row, while preserving the same numeric content.
1638- coords_py = backend .numpy (lattice ._coordinates ).tolist ()
1626+ # Detach-and-copy coordinates while remaining in tensor form to avoid
1627+ # host roundtrips and device/dtype changes; this keeps CustomizeLattice
1628+ # decoupled from the original graph but backend-friendly.
1629+ # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
1630+ try :
1631+ coords_detached = backend .stop_gradient (lattice ._coordinates )
1632+ except NotImplementedError :
1633+ coords_detached = lattice ._coordinates
1634+ coords_tensor = backend .copy (coords_detached )
16391635
16401636 return cls (
16411637 dimensionality = lattice .dimensionality ,
16421638 identifiers = list (identifiers ),
1643- coordinates = coords_py ,
1639+ coordinates = coords_tensor ,
16441640 )
16451641
16461642 def add_sites (
0 commit comments