Skip to content

Commit 9575be5

Browse files
committed
fix according to the review
1 parent daa3ff2 commit 9575be5

File tree

6 files changed

+304
-143
lines changed

6 files changed

+304
-143
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Benchmark: Compare neighbor-building time between KDTree and distance-matrix
3+
methods in CustomizeLattice for varying lattice sizes.
4+
"""
5+
6+
import argparse
7+
import csv
8+
import time
9+
from typing import Iterable, List, Tuple, Optional
10+
import logging
11+
12+
import numpy as np
13+
14+
# Silence verbose infos from the library during benchmarks
15+
16+
logging.basicConfig(level=logging.WARNING)
17+
18+
from tensorcircuit.templates.lattice import CustomizeLattice
19+
20+
21+
def _timeit(fn, repeats: int) -> float:
22+
"""Return average wall time (seconds) over repeats for calling fn()."""
23+
times: List[float] = []
24+
for _ in range(repeats):
25+
t0 = time.perf_counter()
26+
fn()
27+
times.append(time.perf_counter() - t0)
28+
return float(np.mean(times))
29+
30+
31+
def _gen_coords(n: int, d: int, seed: int) -> np.ndarray:
32+
rng = np.random.default_rng(seed)
33+
return rng.random((n, d), dtype=float)
34+
35+
36+
def run_once(
37+
n: int, d: int, max_k: int, repeats: int, seed: int
38+
) -> Tuple[float, float]:
39+
"""Run one size point and return (time_kdtree, time_matrix)."""
40+
coords = _gen_coords(n, d, seed)
41+
ids = list(range(n))
42+
lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords)
43+
44+
# KDTree path
45+
t_kdtree = _timeit(
46+
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats
47+
)
48+
49+
# Distance-matrix path (fully differentiable)
50+
t_matrix = _timeit(
51+
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats
52+
)
53+
54+
return t_kdtree, t_matrix
55+
56+
57+
def parse_sizes(s: str) -> List[int]:
58+
return [int(x) for x in s.split(",") if x.strip()]
59+
60+
61+
def format_row(n: int, t_kdtree: float, t_matrix: float) -> str:
62+
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf")
63+
return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x"
64+
65+
66+
def main(argv: Optional[Iterable[str]] = None) -> int:
67+
p = argparse.ArgumentParser(description="Neighbor-building time comparison")
68+
p.add_argument(
69+
"--sizes",
70+
type=parse_sizes,
71+
default=[128, 256, 512, 1024, 2048],
72+
help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)",
73+
)
74+
p.add_argument(
75+
"--dims", type=int, default=2, help="Lattice dimensionality (default: 2)"
76+
)
77+
p.add_argument(
78+
"--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)"
79+
)
80+
p.add_argument(
81+
"--repeats", type=int, default=5, help="Repeats per measurement (default: 5)"
82+
)
83+
p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)")
84+
p.add_argument("--csv", type=str, default="", help="Optional CSV output path")
85+
args = p.parse_args(list(argv) if argv is not None else None)
86+
87+
print("=" * 74)
88+
print(
89+
f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}"
90+
)
91+
print("=" * 74)
92+
print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}")
93+
print("-" * 74)
94+
95+
rows: List[Tuple[int, float, float]] = []
96+
for n in args.sizes:
97+
t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed)
98+
rows.append((n, t_kdtree, t_matrix))
99+
print(format_row(n, t_kdtree, t_matrix))
100+
101+
if args.csv:
102+
with open(args.csv, "w", newline="", encoding="utf-8") as f:
103+
w = csv.writer(f)
104+
w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"])
105+
for n, t_kdtree, t_matrix in rows:
106+
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf")
107+
w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"])
108+
109+
print("-" * 74)
110+
print(f"Saved CSV to: {args.csv}")
111+
112+
print("-" * 74)
113+
print("Done.")
114+
return 0
115+
116+
117+
if __name__ == "__main__":
118+
raise SystemExit(main())

examples/lennard_jones_optimization.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
6767
history["a"].append(K.exp(log_a))
6868
history["energy"].append(energy)
6969

70-
# (Removed previously added blanket NaN guard per reviewer request to keep example minimal.)
71-
7270
updates, opt_state = optimizer.update(grad, opt_state)
7371
log_a = optax.apply_updates(log_a, updates)
7472

@@ -120,5 +118,3 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
120118
plt.legend()
121119
plt.grid(True)
122120
plt.show()
123-
else:
124-
print("\nOptimization failed. Final energy is NaN.")

tensorcircuit/backends/abstract_backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,8 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
631631
"""
632632
Return coordinate matrices from coordinate vectors.
633633
634-
:param args: coordinate vectors
635-
:type args: Any
634+
:param args: coordinate vectors
635+
:type args: Any
636636
:param kwargs: keyword arguments for meshgrid, typically includes 'indexing'
637637
which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing).
638638
- 'ij': matrix indexing, first dimension corresponds to rows (default)
@@ -647,9 +647,9 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
647647
[0, 1]]
648648
y = [[0, 0],
649649
[2, 2]]
650-
:type kwargs: Any
651-
:return: list of coordinate matrices
652-
:rtype: Any
650+
:type kwargs: Any
651+
:return: list of coordinate matrices
652+
:rtype: Any
653653
"""
654654
raise NotImplementedError(
655655
"Backend '{}' has not implemented `meshgrid`.".format(self.name)

tensorcircuit/backends/numpy_backend.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,6 @@ def kron(self, a: Tensor, b: Tensor) -> Tensor:
137137
return np.kron(a, b)
138138

139139
def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
140-
"""
141-
Backend-agnostic meshgrid function.
142-
"""
143140
return np.meshgrid(*args, **kwargs)
144141

145142
def dtype(self, a: Tensor) -> str:

0 commit comments

Comments
 (0)