Skip to content

Commit 2aad7be

Browse files
add u1 helper functions in quantum
1 parent 872c957 commit 2aad7be

File tree

4 files changed

+127
-9
lines changed

4 files changed

+127
-9
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
### Added
66

7-
- Add jax jitted function load/save utilities in experimental module
7+
- Add jax jitted function load/save utilities in experimental.py
88

99
- Add `circuit.to_openqasm_file` function for compatibility of qiskit>1
1010

1111
- Add `tc.cite()` to get citation information
1212

13+
- Add `u1_inds`, `u1_mask`, `u1_project`, and `u1_enlarge` functions in quantum.py as utilities in charged conservation systems
14+
1315
### Fixed
1416

1517
- Fix customized jax eigh operator by noting the return is a namedtuple

tensorcircuit/backends/jax_backend.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -609,14 +609,16 @@ def f_jax(*args: Any, **kws: Any) -> Any:
609609
return carry
610610

611611
def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
612-
rank = len(operand.shape)
613-
dnums = libjax.lax.ScatterDimensionNumbers(
614-
update_window_dims=(),
615-
inserted_window_dims=tuple([i for i in range(rank)]),
616-
scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
617-
)
618-
r = libjax.lax.scatter(operand, indices, updates, dnums)
619-
return r
612+
updates = jnp.reshape(updates, indices.shape)
613+
return operand.at[indices].set(updates)
614+
# rank = len(operand.shape)
615+
# dnums = libjax.lax.ScatterDimensionNumbers(
616+
# update_window_dims=(),
617+
# inserted_window_dims=tuple([i for i in range(rank)]),
618+
# scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
619+
# )
620+
# r = libjax.lax.scatter(operand, indices, updates, dnums)
621+
# return r
620622

621623
def coo_sparse_matrix(
622624
self, indices: Tensor, values: Tensor, shape: Tensor

tensorcircuit/quantum.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# pylint: disable=invalid-name
1212

1313
import logging
14+
import math
1415
import os
1516
from functools import partial, reduce
1617
from operator import matmul, mul, or_
@@ -1160,6 +1161,96 @@ def quimb2qop(qb_mpo: Any) -> QuOperator:
11601161
return qop
11611162

11621163

1164+
def u1_inds(n: int, m: int) -> Tensor:
1165+
"""
1166+
Generate all the combination index of m down spins in n sites.
1167+
1168+
.. code-block:: python
1169+
1170+
print(u1_inds(5, 1))
1171+
# [1, 2, 4, 8, 16]
1172+
1173+
1174+
:param n: number of total sites
1175+
:type n: int
1176+
:param m: number of down spins (1 in 0, 1)
1177+
:type m: int
1178+
:return: index tensor
1179+
:rtype: Tensor
1180+
"""
1181+
# m down spins
1182+
num_combinations = math.comb(n, m)
1183+
inds = np.zeros([num_combinations], dtype="int64")
1184+
if m == 0:
1185+
inds[0] = 0
1186+
return inds
1187+
combination = (1 << m) - 1
1188+
1189+
for i in range(num_combinations):
1190+
inds[i] = combination
1191+
1192+
# Find the next combination using Gosper's Hack
1193+
u = combination & -combination
1194+
v = u + combination
1195+
combination = v + (((v ^ combination) // u) >> 2)
1196+
return backend.convert_to_tensor(inds)
1197+
1198+
1199+
def u1_mask(n: int, m: int) -> Tensor:
1200+
"""
1201+
Return the 1d array of size 2**n filled with zero,
1202+
one only in elements corresponding to the m down spins
1203+
1204+
:param n: number of total sites
1205+
:type n: int
1206+
:param m: number of down spins (1 in 0, 1)
1207+
:type m: int
1208+
:return: _description_
1209+
:rtype: Tensor
1210+
"""
1211+
inds = u1_inds(n, m)
1212+
m = backend.scatter(
1213+
backend.zeros([2**n]),
1214+
backend.reshape(inds, [-1, 1]),
1215+
backend.ones([math.comb(n, m)]),
1216+
)
1217+
return m
1218+
1219+
1220+
def u1_project(s: Tensor, n: int, m: int) -> Tensor:
1221+
"""
1222+
Project a state s to the subspace with m down spins in n sites
1223+
1224+
:param s: input state of size 2**n
1225+
:type s: Tensor
1226+
:param n: number of total sites
1227+
:type n: int
1228+
:param m: number of down spins (1 in 0, 1)
1229+
:type m: int
1230+
:return: projected state of size C_n^m
1231+
:rtype: Tensor
1232+
"""
1233+
return backend.gather1d(s, u1_inds(n, m))
1234+
1235+
1236+
def u1_enlarge(s: Tensor, n: int, m: int) -> Tensor:
1237+
"""
1238+
Enlarge a state s in the subspace with m down spins in n sites to
1239+
the full Hilbert space wavefunction of size 2**n
1240+
1241+
:param s: input state of size C_n^m
1242+
:type s: Tensor
1243+
:param n: number of total sites
1244+
:type n: int
1245+
:param m: number of down spins (1 in 0, 1)
1246+
:type m: int
1247+
:return: enlarged state of size 2**n
1248+
:rtype: Tensor
1249+
"""
1250+
inds = u1_inds(n, m)
1251+
return backend.scatter(backend.zeros([2**n]), backend.reshape(inds, [-1, 1]), s)
1252+
1253+
11631254
def heisenberg_hamiltonian(
11641255
g: Graph,
11651256
hzz: float = 1.0,

tests/test_quantum.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,26 @@ def test_reduced_wavefunction(backend):
524524
c1.h(0)
525525
c1.cnot(0, 1)
526526
np.testing.assert_allclose(s1, c1.state(), atol=1e-5)
527+
528+
529+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
530+
def test_u1_mask(backend):
531+
g = tc.templates.graphs.Line1D(8)
532+
sumz = tc.quantum.heisenberg_hamiltonian(g, hzz=0, hxx=0, hyy=0, hz=1)
533+
for i in range(9):
534+
s = tc.quantum.u1_mask(8, i)
535+
s /= tc.backend.norm(s)
536+
c = tc.Circuit(8, inputs=s)
537+
zexp = tc.templates.measurements.operator_expectation(c, sumz)
538+
np.testing.assert_allclose(zexp, 8 - 2 * i, atol=1e-6)
539+
540+
541+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
542+
def test_u1_project(backend):
543+
c = tc.Circuit(8)
544+
c.x([0, 2, 4])
545+
c.exp1(0, 1, unitary=tc.gates._swap_matrix, theta=0.6)
546+
s = c.state()
547+
s1 = tc.quantum.u1_project(s, 8, 3)
548+
assert s1.shape[-1] == 56
549+
np.testing.assert_allclose(tc.quantum.u1_enlarge(s1, 8, 3), s)

0 commit comments

Comments
 (0)