Commit daa3ff2

fix according to the review

1 parent 589763e commit daa3ff2

File tree

7 files changed: +421 additions, -249 deletions


examples/lennard_jones_optimization.py

Lines changed: 4 additions & 18 deletions
@@ -12,16 +12,7 @@
 import optax
 import numpy as np
 import matplotlib.pyplot as plt
-
-# Try to enable JAX 64-bit precision if available (safe fallback)
-
-try:  # pragma: no cover - optional optimization
-    from jax import config as jax_config  # type: ignore
-
-    jax_config.update("jax_enable_x64", True)
-except Exception:  # broad: environment may not have config attribute
-    pass
-import tensorcircuit as tc  # noqa: E402
+import tensorcircuit as tc
 
 
 tc.set_dtype("float64")  # Use tc for universal control
@@ -58,9 +49,8 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
     return potential_energy
 
 
-# Create a lambda function for optimization
-potential_fun_for_grad = lambda log_a: calculate_potential(log_a)
-value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad))
+# Create value and grad function for optimization
+value_and_grad_fun = K.jit(K.value_and_grad(calculate_potential))
 
 optimizer = optax.adam(learning_rate=0.01)
 
@@ -77,11 +67,7 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
     history["a"].append(K.exp(log_a))
     history["energy"].append(energy)
 
-    # Check for NaN gradients using TensorCircuit's backend-agnostic approach
-    if K.sum(tc.num_to_tensor(np.isnan(K.numpy(grad)))) > 0:
-        print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.")
-        print(f"Current energy: {energy}, Current log_a: {log_a}")
-        break
+    # (Removed previously added blanket NaN guard per reviewer request to keep example minimal.)
 
     updates, opt_state = optimizer.update(grad, opt_state)
     log_a = optax.apply_updates(log_a, updates)
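
Note on the change above: passing calculate_potential directly to K.value_and_grad removes a redundant lambda, since value_and_grad differentiates with respect to the first positional argument and the wrapper added nothing. A minimal runnable sketch of the same jit + value_and_grad + optax pattern, with a stand-in quadratic objective and an assumed iteration count (the real example evaluates a Lennard-Jones lattice sum):

import optax
import tensorcircuit as tc

tc.set_dtype("float64")
K = tc.set_backend("jax")


def objective(log_a):
    # Stand-in for calculate_potential: a quadratic in a = exp(log_a)
    # keeps the sketch self-contained; its minimum sits at a = 1.5.
    a = K.exp(log_a)
    return (a - 1.5) ** 2


# Pass the function directly; no lambda wrapper needed.
value_and_grad_fun = K.jit(K.value_and_grad(objective))

optimizer = optax.adam(learning_rate=0.01)
log_a = tc.num_to_tensor(0.0, dtype="float64")
opt_state = optimizer.init(log_a)

for i in range(500):  # iteration count is an assumption, not from the example
    energy, grad = value_and_grad_fun(log_a)
    updates, opt_state = optimizer.update(grad, opt_state)
    log_a = optax.apply_updates(log_a, updates)

print("optimized a:", K.numpy(K.exp(log_a)))  # should approach 1.5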

tensorcircuit/backends/abstract_backend.py

Lines changed: 34 additions & 7 deletions
@@ -629,15 +629,27 @@ def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
 
     def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
         """
-        Return coordinate matrices from coordinate vectors.
+        Return coordinate matrices from coordinate vectors.
 
-        :param args: coordinate vectors
-        :type args: Any
+        :param args: coordinate vectors
+        :type args: Any
         :param kwargs: keyword arguments for meshgrid, typically includes 'indexing'
-            which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing)
-        :type kwargs: Any
-        :return: list of coordinate matrices
-        :rtype: Any
+            which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing).
+            - 'ij': matrix indexing, first dimension corresponds to rows (default)
+            - 'xy': Cartesian indexing, first dimension corresponds to columns
+            Example:
+                >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy')
+            Shapes:
+                - x.shape == (2, 2)  # rows correspond to y vector length
+                - y.shape == (2, 2)
+            Values:
+                x = [[0, 1],
+                     [0, 1]]
+                y = [[0, 0],
+                     [2, 2]]
+        :type kwargs: Any
+        :return: list of coordinate matrices
+        :rtype: Any
         """
         raise NotImplementedError(
             "Backend '{}' has not implemented `meshgrid`.".format(self.name)
@@ -797,6 +809,21 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
             "Backend '{}' has not implemented `cast`.".format(self.name)
         )
 
+    def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
+        """
+        Convert input to tensor.
+
+        :param a: input data to be converted
+        :type a: Tensor
+        :param dtype: target dtype, optional
+        :type dtype: Optional[str]
+        :return: converted tensor
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name)
+        )
+
     def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
         """
         Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
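
The indexing semantics documented in the new meshgrid docstring mirror numpy.meshgrid, so the 'xy' example above can be sanity-checked with NumPy directly (a quick check with NumPy itself, not tensorcircuit API):

import numpy as np

# 'xy' (Cartesian indexing): first axis runs over the second vector
x, y = np.meshgrid([0, 1], [0, 2], indexing="xy")
print(x)  # [[0 1]
          #  [0 1]]
print(y)  # [[0 0]
          #  [2 2]]

# 'ij' (matrix indexing) swaps the roles of the two axes
xi, yi = np.meshgrid([0, 1], [0, 2], indexing="ij")
print(xi)  # [[0 0]
           #  [1 1]]
print(yi)  # [[0 2]
           #  [0 2]]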
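
convert_to_tensor is only declared abstract here; each concrete backend is expected to override it. A hypothetical sketch of what a NumPy-flavored override could look like (class name and body are illustrative assumptions, not the PR's implementation):

import numpy as np


class NumpyBackendSketch:
    # Hypothetical backend stub; the real overrides live in the
    # per-backend modules of tensorcircuit.
    name = "numpy"

    def convert_to_tensor(self, a, dtype=None):
        t = np.asarray(a)
        if dtype is not None:
            # dtype arrives as a string such as "float64",
            # matching the abstract signature above
            t = t.astype(dtype)
        return t


backend = NumpyBackendSketch()
print(backend.convert_to_tensor([1, 2, 3], dtype="float32").dtype)  # float32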

tensorcircuit/backends/pytorch_backend.py

Lines changed: 0 additions & 2 deletions
@@ -253,8 +253,6 @@ def expm(self, a: Tensor) -> Tensor:
         # it doesn't support complex numbers which is more severe issue.
         # see https://github.com/pytorch/pytorch/issues/9983
 
-    # see https://github.com/pytorch/pytorch/issues/9983
-
     def sin(self, a: Tensor) -> Tensor:
         return torchlib.sin(a)
 

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 2 additions & 4 deletions
@@ -77,10 +77,8 @@ def _tensordot_tf(
 ) -> Tensor:
     # Use TensorFlow's dtype promotion rules by converting both to a common dtype
     if a.dtype != b.dtype:
-        # Find the result dtype by performing a dummy operation
-        common_dtype = (
-            tf.constant(0, dtype=a.dtype) + tf.constant(0, dtype=b.dtype)
-        ).dtype
+        # Find the result dtype using TensorFlow's type promotion rules
+        common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype)
         a = tf.cast(a, common_dtype)
         b = tf.cast(b, common_dtype)
     return tf.tensordot(a, b, axes)
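
The replacement relies on tf.experimental.numpy.result_type, which applies NumPy-style promotion rules without allocating the throwaway tensors the old dummy-addition trick needed. A quick illustration (dtype pairs chosen for demonstration):

import tensorflow as tf

# NumPy-style promotion, no dummy tensors required
print(tf.experimental.numpy.result_type(tf.float32, tf.float64))    # float64
print(tf.experimental.numpy.result_type(tf.int32, tf.int64))        # int64
print(tf.experimental.numpy.result_type(tf.complex64, tf.float64))  # complex128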
