-
Notifications
You must be signed in to change notification settings - Fork 12
feat(lattice): Make lattice geometries differentiable and backend-agn… #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
9e01be8
9d22384
d71d4a1
bb65592
0ad707c
92bc8e4
efaee05
7063c6f
0660abf
589763e
daa3ff2
9575be5
d372f72
0b38522
04aca93
283e1fd
494a99b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| """ | ||
| Benchmark: Compare neighbor-building time between KDTree and distance-matrix | ||
| methods in CustomizeLattice for varying lattice sizes. | ||
| """ | ||
|
|
||
| import argparse | ||
| import csv | ||
| import time | ||
| from typing import Iterable, List, Tuple, Optional | ||
| import logging | ||
|
|
||
| import numpy as np | ||
|
|
||
| # Silence verbose infos from the library during benchmarks | ||
|
|
||
| logging.basicConfig(level=logging.WARNING) | ||
|
|
||
| from tensorcircuit.templates.lattice import CustomizeLattice | ||
|
|
||
|
|
||
| def run_once( | ||
| n: int, d: int, max_k: int, repeats: int, seed: int | ||
| ) -> Tuple[float, float]: | ||
| """Run one size point and return (time_kdtree, time_matrix).""" | ||
| rng = np.random.default_rng(seed) | ||
| ids = list(range(n)) | ||
|
|
||
| # Collect times for each repeat with different random coordinates | ||
| kdtree_times: List[float] = [] | ||
| matrix_times: List[float] = [] | ||
|
|
||
| for _ in range(repeats): | ||
| # Generate different coordinates for each repeat | ||
| coords = rng.random((n, d), dtype=float) | ||
| lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) | ||
|
|
||
| # KDTree path - single measurement | ||
| t0 = time.perf_counter() | ||
| lat._build_neighbors(max_k=max_k, use_kdtree=True) | ||
| kdtree_times.append(time.perf_counter() - t0) | ||
|
|
||
| # Distance-matrix path - single measurement | ||
| t0 = time.perf_counter() | ||
| lat._build_neighbors(max_k=max_k, use_kdtree=False) | ||
| matrix_times.append(time.perf_counter() - t0) | ||
|
|
||
| return float(np.mean(kdtree_times)), float(np.mean(matrix_times)) | ||
|
|
||
|
|
||
| def parse_sizes(s: str) -> List[int]: | ||
| return [int(x) for x in s.split(",") if x.strip()] | ||
|
|
||
|
|
||
| def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: | ||
| speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") | ||
| return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" | ||
|
|
||
|
|
||
| def main(argv: Optional[Iterable[str]] = None) -> int: | ||
| p = argparse.ArgumentParser(description="Neighbor-building time comparison") | ||
| p.add_argument( | ||
| "--sizes", | ||
| type=parse_sizes, | ||
| default=[128, 256, 512, 1024, 2048], | ||
| help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", | ||
| ) | ||
| p.add_argument( | ||
| "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" | ||
| ) | ||
| p.add_argument( | ||
| "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" | ||
| ) | ||
| p.add_argument( | ||
| "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" | ||
| ) | ||
| p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") | ||
| p.add_argument("--csv", type=str, default="", help="Optional CSV output path") | ||
| args = p.parse_args(list(argv) if argv is not None else None) | ||
|
|
||
| print("=" * 74) | ||
| print( | ||
| f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" | ||
| ) | ||
| print("=" * 74) | ||
| print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") | ||
| print("-" * 74) | ||
|
|
||
| rows: List[Tuple[int, float, float]] = [] | ||
| for n in args.sizes: | ||
| t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) | ||
| rows.append((n, t_kdtree, t_matrix)) | ||
| print(format_row(n, t_kdtree, t_matrix)) | ||
|
|
||
| if args.csv: | ||
| with open(args.csv, "w", newline="", encoding="utf-8") as f: | ||
| w = csv.writer(f) | ||
| w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) | ||
| for n, t_kdtree, t_matrix in rows: | ||
| speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") | ||
| w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) | ||
|
|
||
| print("-" * 74) | ||
| print(f"Saved CSV to: {args.csv}") | ||
|
|
||
| print("-" * 74) | ||
| print("Done.") | ||
| return 0 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| raise SystemExit(main()) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| """ | ||
| Lennard-Jones Potential Optimization Example | ||
| This script demonstrates how to use TensorCircuit's differentiable lattice geometries | ||
| to optimize crystal structure. It finds the equilibrium lattice constant that minimizes | ||
| the total Lennard-Jones potential energy of a 2D square lattice. | ||
| The optimization showcases the key Task 3 capability: making lattice parameters | ||
|
||
| differentiable for variational material design. | ||
| """ | ||
|
|
||
| import optax | ||
refraction-ray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import numpy as np | ||
| import matplotlib.pyplot as plt | ||
| import tensorcircuit as tc | ||
|
|
||
|
|
||
| tc.set_dtype("float64") # Use tc for universal control | ||
| K = tc.set_backend("jax") | ||
|
|
||
|
|
||
| def calculate_potential(log_a, epsilon=0.5, sigma=1.0): | ||
| """ | ||
| Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a). | ||
| This version creates the lattice inside the function to demonstrate truly differentiable geometry. | ||
| """ | ||
| lattice_constant = K.exp(log_a) | ||
|
|
||
| # Create lattice with the differentiable parameter | ||
| size = (4, 4) # Smaller size for demonstration | ||
| lattice = tc.templates.lattice.SquareLattice( | ||
| size, lattice_constant=lattice_constant, pbc=True | ||
| ) | ||
| d = lattice.distance_matrix | ||
|
|
||
| d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9)) | ||
|
|
||
| term12 = K.power(sigma / d_safe, 12) | ||
| term6 = K.power(sigma / d_safe, 6) | ||
| potential_matrix = 4 * epsilon * (term12 - term6) | ||
|
|
||
| num_sites = lattice.num_sites | ||
| # Zero out self-interactions (diagonal elements) | ||
| eye_mask = K.eye(num_sites, dtype=potential_matrix.dtype) | ||
| potential_matrix = potential_matrix * (1 - eye_mask) | ||
|
|
||
| potential_energy = K.sum(potential_matrix) / 2.0 | ||
|
|
||
| return potential_energy | ||
|
|
||
|
|
||
| # Create value and grad function for optimization | ||
| value_and_grad_fun = K.jit(K.value_and_grad(calculate_potential)) | ||
|
|
||
| optimizer = optax.adam(learning_rate=0.01) | ||
|
|
||
| log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1))) | ||
|
||
|
|
||
| opt_state = optimizer.init(log_a) | ||
|
|
||
| history = {"a": [], "energy": []} | ||
|
|
||
| print("Starting optimization of lattice constant...") | ||
| for i in range(200): | ||
| energy, grad = value_and_grad_fun(log_a) | ||
|
|
||
| history["a"].append(K.exp(log_a)) | ||
| history["energy"].append(energy) | ||
|
|
||
| updates, opt_state = optimizer.update(grad, opt_state) | ||
| log_a = optax.apply_updates(log_a, updates) | ||
|
|
||
| if (i + 1) % 20 == 0: | ||
| current_a = K.exp(log_a) | ||
| print( | ||
| f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}" | ||
| ) | ||
|
|
||
| final_a = K.exp(log_a) | ||
| final_energy = calculate_potential(log_a) | ||
|
|
||
| print("\nOptimization finished!") | ||
| print(f"Final optimized lattice constant: {final_a:.6f}") | ||
| print(f"Corresponding minimum total energy: {final_energy:.6f}") | ||
|
|
||
| # Vectorized calculation for the potential curve | ||
| a_vals = np.linspace(0.8, 1.5, 200) | ||
| log_a_vals = K.log(K.convert_to_tensor(a_vals)) | ||
|
|
||
| # Use vmap to create a vectorized version of the potential function | ||
| vmap_potential = K.vmap(lambda la: calculate_potential(la)) | ||
| potential_curve = vmap_potential(log_a_vals) | ||
|
|
||
| plt.figure(figsize=(10, 6)) | ||
| plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") | ||
| plt.scatter( | ||
| history["a"], | ||
| history["energy"], | ||
| color="red", | ||
| s=20, | ||
| zorder=5, | ||
| label="Optimization Steps", | ||
| ) | ||
| plt.scatter( | ||
| final_a, | ||
| final_energy, | ||
| color="green", | ||
| s=100, | ||
| zorder=6, | ||
| marker="*", | ||
| label="Final Optimized Point", | ||
| ) | ||
|
|
||
| plt.title("Lennard-Jones Potential Optimization") | ||
| plt.xlabel("Lattice Constant (a)") | ||
| plt.ylabel("Total Potential Energy") | ||
| plt.legend() | ||
| plt.grid(True) | ||
| plt.show() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -596,6 +596,82 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor: | |
| "Backend '{}' has not implemented `argsort`.".format(self.name) | ||
| ) | ||
|
|
||
| def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor: | ||
refraction-ray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Sort a tensor along the given axis. | ||
|
|
||
| :param a: [description] | ||
| :type a: Tensor | ||
| :param axis: [description], defaults to -1 | ||
| :type axis: int, optional | ||
| :return: [description] | ||
| :rtype: Tensor | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `sort`.".format(self.name) | ||
| ) | ||
|
|
||
| def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: | ||
| """ | ||
| Test whether all array elements along a given axis evaluate to True. | ||
|
|
||
| :param a: Input tensor | ||
| :type a: Tensor | ||
| :param axis: Axis or axes along which a logical AND reduction is performed, | ||
| defaults to None | ||
| :type axis: Optional[Sequence[int]], optional | ||
| :return: A new boolean or tensor resulting from the AND reduction | ||
| :rtype: Tensor | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `all`.".format(self.name) | ||
| ) | ||
|
|
||
| def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: | ||
| """ | ||
| Return coordinate matrices from coordinate vectors. | ||
|
||
|
|
||
| :param args: coordinate vectors | ||
| :type args: Any | ||
| :param kwargs: keyword arguments for meshgrid, typically includes 'indexing' | ||
refraction-ray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing). | ||
| - 'ij': matrix indexing, first dimension corresponds to rows (default) | ||
| - 'xy': Cartesian indexing, first dimension corresponds to columns | ||
| Example: | ||
| >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy') | ||
| Shapes: | ||
| - x.shape == (2, 2) # rows correspond to y vector length | ||
| - y.shape == (2, 2) | ||
| Values: | ||
| x = [[0, 1], | ||
| [0, 1]] | ||
| y = [[0, 0], | ||
| [2, 2]] | ||
| :type kwargs: Any | ||
| :return: list of coordinate matrices | ||
| :rtype: Any | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `meshgrid`.".format(self.name) | ||
| ) | ||
|
|
||
| def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor: | ||
| """ | ||
| Expand the shape of a tensor. | ||
| Insert a new axis that will appear at the `axis` position in the expanded | ||
| tensor shape. | ||
|
|
||
| :param a: Input tensor | ||
| :type a: Tensor | ||
| :param axis: Position in the expanded axes where the new axis is placed | ||
| :type axis: int | ||
| :return: Output tensor with the number of dimensions increased by one. | ||
| :rtype: Tensor | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `expand_dims`.".format(self.name) | ||
| ) | ||
|
|
||
| def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: | ||
| """ | ||
| Find the unique elements and their corresponding counts of the given tensor ``a``. | ||
|
|
@@ -733,6 +809,21 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor: | |
| "Backend '{}' has not implemented `cast`.".format(self.name) | ||
| ) | ||
|
|
||
| def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor: | ||
| """ | ||
| Convert input to tensor. | ||
|
|
||
| :param a: input data to be converted | ||
| :type a: Tensor | ||
| :param dtype: target dtype, optional | ||
| :type dtype: Optional[str] | ||
| :return: converted tensor | ||
| :rtype: Tensor | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name) | ||
| ) | ||
|
|
||
| def mod(self: Any, x: Tensor, y: Tensor) -> Tensor: | ||
| """ | ||
| Compute y-mod of x (negative number behavior is not guaranteed to be consistent) | ||
|
|
@@ -1404,6 +1495,28 @@ def cond( | |
| "Backend '{}' has not implemented `cond`.".format(self.name) | ||
| ) | ||
|
|
||
| def where( | ||
| self: Any, | ||
| condition: Tensor, | ||
| x: Optional[Tensor] = None, | ||
| y: Optional[Tensor] = None, | ||
| ) -> Tensor: | ||
| """ | ||
| Return a tensor of elements selected from either x or y, depending on condition. | ||
|
|
||
| :param condition: Where True, yield x, otherwise yield y. | ||
| :type condition: Tensor (bool) | ||
| :param x: Values from which to choose when condition is True. | ||
| :type x: Tensor | ||
| :param y: Values from which to choose when condition is False. | ||
| :type y: Tensor | ||
| :return: A tensor with elements from x where condition is True, and y otherwise. | ||
| :rtype: Tensor | ||
| """ | ||
| raise NotImplementedError( | ||
| "Backend '{}' has not implemented `where`.".format(self.name) | ||
| ) | ||
|
|
||
| def switch( | ||
| self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]] | ||
| ) -> Tensor: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please delete this file, and ensure the previous lattice neighbor example
lattice_neighbor_benchmark.pyis doing correct, i.e. compare kdtree and the baseline