Skip to content

Commit d372f72

Browse files
committed
update lattice_neighbor_time_compare.py to enhance the accuracy
1 parent 9575be5 commit d372f72

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

examples/lattice_neighbor_time_compare.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,33 @@
1818
from tensorcircuit.templates.lattice import CustomizeLattice
1919

2020

21-
def _timeit(fn, repeats: int) -> float:
22-
"""Return average wall time (seconds) over repeats for calling fn()."""
23-
times: List[float] = []
24-
for _ in range(repeats):
25-
t0 = time.perf_counter()
26-
fn()
27-
times.append(time.perf_counter() - t0)
28-
return float(np.mean(times))
29-
30-
31-
def _gen_coords(n: int, d: int, seed: int) -> np.ndarray:
32-
rng = np.random.default_rng(seed)
33-
return rng.random((n, d), dtype=float)
34-
35-
3621
def run_once(
3722
n: int, d: int, max_k: int, repeats: int, seed: int
3823
) -> Tuple[float, float]:
3924
"""Run one size point and return (time_kdtree, time_matrix)."""
40-
coords = _gen_coords(n, d, seed)
25+
rng = np.random.default_rng(seed)
4126
ids = list(range(n))
42-
lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords)
43-
44-
# KDTree path
45-
t_kdtree = _timeit(
46-
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats
47-
)
48-
49-
# Distance-matrix path (fully differentiable)
50-
t_matrix = _timeit(
51-
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats
52-
)
27+
28+
# Collect times for each repeat with different random coordinates
29+
kdtree_times: List[float] = []
30+
matrix_times: List[float] = []
31+
32+
for i in range(repeats):
33+
# Generate different coordinates for each repeat
34+
coords = rng.random((n, d), dtype=float)
35+
lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords)
36+
37+
# KDTree path - single measurement
38+
t0 = time.perf_counter()
39+
lat._build_neighbors(max_k=max_k, use_kdtree=True)
40+
kdtree_times.append(time.perf_counter() - t0)
41+
42+
# Distance-matrix path - single measurement
43+
t0 = time.perf_counter()
44+
lat._build_neighbors(max_k=max_k, use_kdtree=False)
45+
matrix_times.append(time.perf_counter() - t0)
5346

54-
return t_kdtree, t_matrix
47+
return float(np.mean(kdtree_times)), float(np.mean(matrix_times))
5548

5649

5750
def parse_sizes(s: str) -> List[int]:

0 commit comments

Comments
 (0)