|
| 1 | +""" |
| 2 | +Benchmark: Compare neighbor-building time between KDTree and distance-matrix |
| 3 | +methods in CustomizeLattice for varying lattice sizes. |
| 4 | +""" |
| 5 | + |
| 6 | +import argparse |
| 7 | +import csv |
| 8 | +import time |
| 9 | +from typing import Iterable, List, Tuple, Optional |
| 10 | +import logging |
| 11 | + |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +# Silence verbose infos from the library during benchmarks |
| 15 | + |
| 16 | +logging.basicConfig(level=logging.WARNING) |
| 17 | + |
| 18 | +from tensorcircuit.templates.lattice import CustomizeLattice |
| 19 | + |
| 20 | + |
| 21 | +def _timeit(fn, repeats: int) -> float: |
| 22 | + """Return average wall time (seconds) over repeats for calling fn().""" |
| 23 | + times: List[float] = [] |
| 24 | + for _ in range(repeats): |
| 25 | + t0 = time.perf_counter() |
| 26 | + fn() |
| 27 | + times.append(time.perf_counter() - t0) |
| 28 | + return float(np.mean(times)) |
| 29 | + |
| 30 | + |
| 31 | +def _gen_coords(n: int, d: int, seed: int) -> np.ndarray: |
| 32 | + rng = np.random.default_rng(seed) |
| 33 | + return rng.random((n, d), dtype=float) |
| 34 | + |
| 35 | + |
| 36 | +def run_once( |
| 37 | + n: int, d: int, max_k: int, repeats: int, seed: int |
| 38 | +) -> Tuple[float, float]: |
| 39 | + """Run one size point and return (time_kdtree, time_matrix).""" |
| 40 | + coords = _gen_coords(n, d, seed) |
| 41 | + ids = list(range(n)) |
| 42 | + lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) |
| 43 | + |
| 44 | + # KDTree path |
| 45 | + t_kdtree = _timeit( |
| 46 | + lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats |
| 47 | + ) |
| 48 | + |
| 49 | + # Distance-matrix path (fully differentiable) |
| 50 | + t_matrix = _timeit( |
| 51 | + lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats |
| 52 | + ) |
| 53 | + |
| 54 | + return t_kdtree, t_matrix |
| 55 | + |
| 56 | + |
| 57 | +def parse_sizes(s: str) -> List[int]: |
| 58 | + return [int(x) for x in s.split(",") if x.strip()] |
| 59 | + |
| 60 | + |
| 61 | +def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: |
| 62 | + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") |
| 63 | + return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" |
| 64 | + |
| 65 | + |
| 66 | +def main(argv: Optional[Iterable[str]] = None) -> int: |
| 67 | + p = argparse.ArgumentParser(description="Neighbor-building time comparison") |
| 68 | + p.add_argument( |
| 69 | + "--sizes", |
| 70 | + type=parse_sizes, |
| 71 | + default=[128, 256, 512, 1024, 2048], |
| 72 | + help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", |
| 73 | + ) |
| 74 | + p.add_argument( |
| 75 | + "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" |
| 76 | + ) |
| 77 | + p.add_argument( |
| 78 | + "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" |
| 79 | + ) |
| 80 | + p.add_argument( |
| 81 | + "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" |
| 82 | + ) |
| 83 | + p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") |
| 84 | + p.add_argument("--csv", type=str, default="", help="Optional CSV output path") |
| 85 | + args = p.parse_args(list(argv) if argv is not None else None) |
| 86 | + |
| 87 | + print("=" * 74) |
| 88 | + print( |
| 89 | + f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" |
| 90 | + ) |
| 91 | + print("=" * 74) |
| 92 | + print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") |
| 93 | + print("-" * 74) |
| 94 | + |
| 95 | + rows: List[Tuple[int, float, float]] = [] |
| 96 | + for n in args.sizes: |
| 97 | + t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) |
| 98 | + rows.append((n, t_kdtree, t_matrix)) |
| 99 | + print(format_row(n, t_kdtree, t_matrix)) |
| 100 | + |
| 101 | + if args.csv: |
| 102 | + with open(args.csv, "w", newline="", encoding="utf-8") as f: |
| 103 | + w = csv.writer(f) |
| 104 | + w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) |
| 105 | + for n, t_kdtree, t_matrix in rows: |
| 106 | + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") |
| 107 | + w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) |
| 108 | + |
| 109 | + print("-" * 74) |
| 110 | + print(f"Saved CSV to: {args.csv}") |
| 111 | + |
| 112 | + print("-" * 74) |
| 113 | + print("Done.") |
| 114 | + return 0 |
| 115 | + |
| 116 | + |
| 117 | +if __name__ == "__main__": |
| 118 | + raise SystemExit(main()) |
0 commit comments