
Commit 04aca93

fix according to the review
1 parent 0b38522 commit 04aca93

3 files changed: +53 -61 lines


examples/lennard_jones_optimization.py

Lines changed: 38 additions & 39 deletions
@@ -79,42 +79,41 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
 final_a = K.exp(log_a)
 final_energy = calculate_potential(log_a)

-if not np.isnan(K.numpy(final_energy)):
-    print("\nOptimization finished!")
-    print(f"Final optimized lattice constant: {final_a:.6f}")
-    print(f"Corresponding minimum total energy: {final_energy:.6f}")
-
-    # Vectorized calculation for the potential curve
-    a_vals = np.linspace(0.8, 1.5, 200)
-    log_a_vals = K.log(K.convert_to_tensor(a_vals))
-
-    # Use vmap to create a vectorized version of the potential function
-    vmap_potential = K.vmap(lambda la: calculate_potential(la))
-    potential_curve = vmap_potential(log_a_vals)
-
-    plt.figure(figsize=(10, 6))
-    plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue")
-    plt.scatter(
-        history["a"],
-        history["energy"],
-        color="red",
-        s=20,
-        zorder=5,
-        label="Optimization Steps",
-    )
-    plt.scatter(
-        final_a,
-        final_energy,
-        color="green",
-        s=100,
-        zorder=6,
-        marker="*",
-        label="Final Optimized Point",
-    )
-
-    plt.title("Lennard-Jones Potential Optimization")
-    plt.xlabel("Lattice Constant (a)")
-    plt.ylabel("Total Potential Energy")
-    plt.legend()
-    plt.grid(True)
-    plt.show()
+print("\nOptimization finished!")
+print(f"Final optimized lattice constant: {final_a:.6f}")
+print(f"Corresponding minimum total energy: {final_energy:.6f}")
+
+# Vectorized calculation for the potential curve
+a_vals = np.linspace(0.8, 1.5, 200)
+log_a_vals = K.log(K.convert_to_tensor(a_vals))
+
+# Use vmap to create a vectorized version of the potential function
+vmap_potential = K.vmap(lambda la: calculate_potential(la))
+potential_curve = vmap_potential(log_a_vals)
+
+plt.figure(figsize=(10, 6))
+plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue")
+plt.scatter(
+    history["a"],
+    history["energy"],
+    color="red",
+    s=20,
+    zorder=5,
+    label="Optimization Steps",
+)
+plt.scatter(
+    final_a,
+    final_energy,
+    color="green",
+    s=100,
+    zorder=6,
+    marker="*",
+    label="Final Optimized Point",
+)
+
+plt.title("Lennard-Jones Potential Optimization")
+plt.xlabel("Lattice Constant (a)")
+plt.ylabel("Total Potential Energy")
+plt.legend()
+plt.grid(True)
+plt.show()
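For readers skimming the example: `K.vmap` lifts the scalar `calculate_potential` onto a whole batch of `log_a` values at once, which is what makes the 200-point curve a single call rather than a Python loop. A minimal, self-contained sketch of the same pattern, assuming the JAX backend and substituting a toy parabola (`toy_potential` is illustrative, not part of this diff):

import numpy as np
import tensorcircuit as tc

K = tc.set_backend("jax")  # vmap is natively supported on the jax backend

def toy_potential(log_a):
    # hypothetical stand-in for calculate_potential: a parabola in a = exp(log_a)
    a = K.exp(log_a)
    return (a - 1.0) ** 2

a_vals = np.linspace(0.8, 1.5, 200)
log_a_vals = K.log(K.convert_to_tensor(a_vals))

# vectorize the scalar function over the leading axis, as in the example script
vmap_potential = K.vmap(lambda la: toy_potential(la))
curve = vmap_potential(log_a_vals)
print(K.numpy(curve).shape)  # (200,)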

tensorcircuit/templates/lattice.py

Lines changed: 15 additions & 19 deletions
@@ -767,14 +767,11 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
         """
         # Ensure dtype consistency across backends (especially torch) by explicitly
         # casting size and lattice_vectors to the same floating dtype used internally.
-        # Strategy: prefer existing lattice_vectors dtype, fallback to float32 for efficiency,
-        # then float64 for precision. This avoids dtype mismatches in vectorized ops.
-        target_dt = None
-        try:
-            # prefer existing lattice_vectors dtype if possible
-            target_dt = backend.dtype(self.lattice_vectors)  # type: ignore
-        except (AttributeError, TypeError):  # pragma: no cover - defensive
-            target_dt = "float32"
+        # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
+        # fall back to float32 to avoid mixed-precision issues in vectorized ops.
+        # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
+        # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
+        target_dt = str(backend.dtype(self.lattice_vectors))  # type: ignore
         if target_dt not in ("float32", "float64"):
             # fallback for unusual dtypes
             target_dt = "float32"
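The surviving logic amounts to: read the tensor's dtype as a string, collapse anything outside float32/float64 down to float32, then cast the other operands to match. A minimal sketch of that normalization on the numpy backend; `normalize_float_dtype` is an illustrative name, not a function in the codebase:

import tensorcircuit as tc

tc.set_backend("numpy")
K = tc.backend

def normalize_float_dtype(tensor):
    # hypothetical helper mirroring the merged logic: read the dtype as a
    # string and collapse anything outside float32/float64 down to float32
    target_dt = str(K.dtype(tensor))
    if target_dt not in ("float32", "float64"):
        target_dt = "float32"
    return target_dt

vecs = K.convert_to_tensor([[1.0, 0.0], [0.0, 1.0]])
dt = normalize_float_dtype(vecs)                # "float64" on the numpy backend
size = K.cast(K.convert_to_tensor([2, 2]), dt)  # cast ints to the shared float dtype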
@@ -1626,21 +1623,20 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
         # Unzip the list of tuples into separate lists of identifiers and coordinates
         _, identifiers, _ = zip(*all_sites_info)

-        # Detach-and-copy coordinates in a backend-agnostic way.
-        # Rationale (answering reviewer question "why not keep backend-dependent form?"):
-        # - Passing a tuple/list of backend tensors (e.g., per-row slices) into
-        #   convert_to_tensor can fail on some backends (torch.tensor(list_of_tensors) ValueError),
-        #   whereas a plain nested Python list is accepted everywhere.
-        # - We want CustomizeLattice to be decoupled from the original lattice's computation
-        #   graph and device state (CPU/GPU), so we materialize numeric values here.
-        # - This is a one-shot conversion of the full coordinate array, simpler and faster
-        #   than iterating per row, while preserving the same numeric content.
-        coords_py = backend.numpy(lattice._coordinates).tolist()
+        # Detach-and-copy coordinates while remaining in tensor form to avoid
+        # host roundtrips and device/dtype changes; this keeps CustomizeLattice
+        # decoupled from the original graph but backend-friendly.
+        # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
+        try:
+            coords_detached = backend.stop_gradient(lattice._coordinates)
+        except NotImplementedError:
+            coords_detached = lattice._coordinates
+        coords_tensor = backend.copy(coords_detached)

         return cls(
             dimensionality=lattice.dimensionality,
             identifiers=list(identifiers),
-            coordinates=coords_py,
+            coordinates=coords_tensor,
         )

     def add_sites(
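The new detach-and-copy pattern can be read in isolation: `stop_gradient` cuts the coordinates out of any autodiff graph (with a fallback for backends that don't implement it, as the comment notes), and `backend.copy` guarantees fresh storage so the new lattice cannot alias the original's tensor. A sketch of the same pattern as a standalone helper; `detach_copy` is illustrative, not part of the library:

import tensorcircuit as tc

tc.set_backend("numpy")
K = tc.backend

def detach_copy(tensor):
    # illustrative helper showing the merged pattern: detach from any
    # autodiff graph, then copy so the result owns its own storage
    try:
        detached = K.stop_gradient(tensor)
    except NotImplementedError:
        # e.g., the numpy backend has no autodiff graph to detach from
        detached = tensor
    return K.copy(detached)

coords = K.convert_to_tensor([[0.0, 0.0], [0.5, 0.5]])
coords_copy = detach_copy(coords)  # same values, independent tensor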

tests/test_lattice.py

Lines changed: 0 additions & 3 deletions
@@ -2090,14 +2090,12 @@ def test_layering_on_various_lattices(lattice_instance):
     _validate_layers(bonds, layers)


-# --- Regression tests for backend-scalar lattice constants (PR fix) ---
 @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_square_lattice_accepts_backend_scalar_lattice_constant(backend):
     """
     Ensure SquareLattice can be constructed when lattice_constant is a backend scalar tensor
     (e.g., tf.constant, jnp.array, torch.tensor), without mixed-type errors.
     """
-    tc.set_backend(backend)

     lc = tc.backend.convert_to_tensor(0.5)
     lat = SquareLattice(size=(2, 2), lattice_constant=lc, pbc=False)
@@ -2121,7 +2119,6 @@ def test_rectangular_lattice_mixed_type_constants(backend):
     RectangularLattice should accept a tuple where one constant is a backend scalar tensor
     and the other is a Python float.
     """
-    tc.set_backend(backend)

     ax = tc.backend.convert_to_tensor(0.5)  # tensor scalar
     ay = 2.0  # python float
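Both deletions rely on the parametrized `lf(...)` fixtures already activating the requested backend, which made the explicit `tc.set_backend(backend)` calls inside the test bodies redundant. A sketch of what such a fixture conventionally looks like; the body below is an assumption about the repo's conftest, not copied from it:

import pytest
import tensorcircuit as tc

@pytest.fixture(scope="function")
def npb():
    # assumed shape of a backend fixture like `npb`: activate the backend
    # for the duration of the test, then restore the default afterwards
    tc.set_backend("numpy")
    yield "numpy"
    tc.set_backend("numpy")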
