Skip to content

Commit be5bf21

Browse files
mpscircuit now try first convert quvector directly into mps if possible
1 parent ba998e9 commit be5bf21

File tree

5 files changed

+68
-11
lines changed

5 files changed

+68
-11
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
- Add `ode_evol_local` and `ode_evol_global` methods for local and global ODE evolution.
2020

21+
- Add transformation method between tensornetwork, quimb, tenpy and QuOperator in tc-ng including `qop2tenpy`, `qop2quimb`, `qop2tn`, `tenpy2qop`, support both MPS and MPO formats.
22+
2123
### Fixed
2224

2325
- Fixed `one_hot` in numpy backend.
@@ -26,6 +28,8 @@
2628

2729
- Fix potential np.matrix return from `PaulistringSum2Dense`.
2830

31+
- `MPSCircuit` now will first try to transform `QuVector` input to tensors directly instead of evaluating it to dense wavefunction first.
32+
2933
### Changed
3034

3135
- The order of arguments of `tc.timeevol.ed_evol` are changed for consistent interface with other evolution methods.

examples/tenpy_sz_convention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120

121121
print("--- Scenario 3: Correcting XXZChain DMRG state with X-gates ---")
122122

123-
L = 10
123+
L = 30
124124
xxz_model_params = {"L": L, "bc_MPS": "finite", "Jxx": 1.0, "Jz": 1.0, "conserve": None}
125125
xxz_M = XXZChain(xxz_model_params)
126126
psi0_xxz = MPS.from_product_state(

tensorcircuit/mpscircuit.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,24 @@
44

55
# pylint: disable=invalid-name
66

7-
from functools import reduce
7+
from functools import reduce, partial
88
from typing import Any, List, Optional, Sequence, Tuple, Dict, Union
99
from copy import copy
10+
import logging
1011

1112
import numpy as np
1213
import tensornetwork as tn
13-
from tensorcircuit.quantum import QuOperator, QuVector
1414

1515
from . import gates
1616
from .cons import backend, npdtype, contractor, rdtypestr, dtypestr
17+
from .quantum import QuOperator, QuVector, extract_tensors_from_qop
1718
from .mps_base import FiniteMPS
1819
from .abstractcircuit import AbstractCircuit
20+
from .utils import arg_alias
1921

2022
Gate = gates.Gate
2123
Tensor = Any
24+
logger = logging.getLogger(__name__)
2225

2326

2427
def split_tensor(
@@ -77,6 +80,10 @@ class MPSCircuit(AbstractCircuit):
7780

7881
is_mps = True
7982

83+
@partial(
84+
arg_alias,
85+
alias_dict={"wavefunction": ["inputs"]},
86+
)
8087
def __init__(
8188
self,
8289
nqubits: int,
@@ -118,8 +125,19 @@ def __init__(
118125
), "tensors and wavefunction cannot be used at input simutaneously"
119126
# TODO(@SUSYUSTC): find better way to address QuVector
120127
if isinstance(wavefunction, QuVector):
121-
wavefunction = wavefunction.eval()
122-
tensors = self.wavefunction_to_tensors(wavefunction, split=self.split)
128+
try:
129+
nodes, is_mps, _ = extract_tensors_from_qop(wavefunction)
130+
if not is_mps:
131+
raise ValueError("wavefunction is not a valid MPS")
132+
tensors = [node.tensor for node in nodes]
133+
except ValueError as e:
134+
logger.warning(repr(e))
135+
wavefunction = wavefunction.eval()
136+
tensors = self.wavefunction_to_tensors(
137+
wavefunction, split=self.split
138+
)
139+
else: # full wavefunction
140+
tensors = self.wavefunction_to_tensors(wavefunction, split=self.split)
123141
assert len(tensors) == nqubits
124142
self._mps = FiniteMPS(tensors, canonicalize=False)
125143
self._mps.center_position = 0

tensorcircuit/quantum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1181,7 +1181,6 @@ def extract_tensors_from_qop(qop: QuOperator) -> Tuple[List[Node], bool, int]:
11811181
# Find endpoint nodes
11821182
endpoint_nodes = set()
11831183
physical_edges = set(qop.out_edges) if is_mps else set(qop.in_edges + qop.out_edges)
1184-
11851184
if is_mps:
11861185
rank_2_nodes = {node for node in nodes_for_sorting if len(node.edges) == 2}
11871186
if len(rank_2_nodes) == 2:
@@ -1340,6 +1339,7 @@ def qop2tenpy(qop: QuOperator) -> Any:
13401339
- Cyclic boundary conditions NOT supported
13411340
13421341
:param qop: The corresponding state/operator as a QuOperator.
1342+
:type qop: QuOperator
13431343
:return: MPO or MPS object from the TeNPy package.
13441344
:rtype: Union[tenpy.networks.mpo.MPO, tenpy.networks.mps.MPS]
13451345
"""
@@ -1480,6 +1480,7 @@ def qop2quimb(qop: QuOperator) -> Any:
14801480
- Cyclic boundary conditions NOT supported
14811481
14821482
:param qop: MPO in the form of QuOperator
1483+
:type qop: QuOperator
14831484
:return: MPO in the form of Quimb package
14841485
:rtype: quimb.tensor.tensor_gen.MatrixProductOperator
14851486
"""

tests/test_mpscircuit.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def test_circuits(backend, dtype):
309309
do_test_measure(circuits)
310310

311311

312+
# TODO(@refraction-ray): fails (lf("jaxb"), lf("highp"))
312313
@pytest.mark.parametrize("backend, dtype", [(lf("tfb"), lf("highp"))])
313314
def test_circuits_jit(backend, dtype):
314315
def expec(params):
@@ -325,15 +326,48 @@ def expec(params):
325326
expec_vg_jit = tc.backend.jit(expec_vg)
326327
exp = expec(params)
327328
exp_jit, exp_grad_jit = expec_vg_jit(params)
328-
dir = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0], dtype=tc.dtypestr))
329+
dir_ = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0], dtype=tc.dtypestr))
329330
epsilon = 1e-6
330-
exp_p = expec(params + dir * epsilon)
331-
exp_m = expec(params - dir * epsilon)
331+
exp_p = expec(params + dir_ * epsilon)
332+
exp_m = expec(params - dir_ * epsilon)
332333
exp_grad_dir_numerical = (exp_p - exp_m) / (epsilon * 2)
333-
exp_grad_dir_jit = tc.backend.real(tc.backend.sum(exp_grad_jit * dir))
334+
exp_grad_dir_jit = tc.backend.real(tc.backend.sum(exp_grad_jit * dir_))
335+
np.testing.assert_allclose(exp, exp_jit, atol=1e-10)
334336
np.testing.assert_allclose(
335-
tc.backend.numpy(exp), tc.backend.numpy(exp_jit), atol=1e-10
337+
tc.backend.numpy(exp_grad_dir_numerical),
338+
tc.backend.numpy(exp_grad_dir_jit),
339+
atol=1e-6,
336340
)
341+
342+
343+
@pytest.mark.parametrize(
344+
"backend, dtype", [(lf("tfb"), lf("highp")), (lf("jaxb"), lf("highp"))]
345+
)
346+
def test_simple_circuits_ad(backend, dtype):
347+
def expec(params):
348+
mps = tc.MPSCircuit(N, split=split)
349+
mps.rx(0, theta=params[0])
350+
mps.cx(0, 1)
351+
mps.cx(1, 2)
352+
mps.ry(2, theta=params[1])
353+
mps.rzz(1, 3, theta=params[2])
354+
x = [0, 2]
355+
z = [1]
356+
exp = mps.expectation_ps(x=x, z=z)
357+
return tc.backend.real(exp)
358+
359+
params = tc.backend.ones((3,), dtype=tc.dtypestr)
360+
expec_vg = tc.backend.value_and_grad(expec)
361+
expec_vg_jit = tc.backend.jit(expec_vg)
362+
exp = expec(params)
363+
exp_jit, exp_grad_jit = expec_vg_jit(params)
364+
dir_ = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0], dtype=tc.dtypestr))
365+
epsilon = 1e-6
366+
exp_p = expec(params + dir_ * epsilon)
367+
exp_m = expec(params - dir_ * epsilon)
368+
exp_grad_dir_numerical = (exp_p - exp_m) / (epsilon * 2)
369+
exp_grad_dir_jit = tc.backend.real(tc.backend.sum(exp_grad_jit * dir_))
370+
np.testing.assert_allclose(exp, exp_jit, atol=1e-6)
337371
np.testing.assert_allclose(
338372
tc.backend.numpy(exp_grad_dir_numerical),
339373
tc.backend.numpy(exp_grad_dir_jit),

0 commit comments

Comments
 (0)