Skip to content

Commit ba998e9

Browse files
committed
Optimize tenpy2qop function
1 parent 2cf3eef commit ba998e9

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

tensorcircuit/quantum.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,37 +1272,39 @@ def tenpy2qop(tenpy_obj: Any) -> QuOperator:
12721272

12731273
nodes = []
12741274
if is_mpo:
1275-
vr_label, vl_label = "wR", "wL"
1276-
original_tensors = [W.to_ndarray() for W in tenpy_tensors]
1277-
modified_tensors = []
1275+
original_tensors_obj = tenpy_tensors
12781276

1279-
for i, (tensor, tenpy_t) in enumerate(zip(original_tensors, tenpy_tensors)):
1280-
labels = tenpy_t._labels
1277+
for i, W_obj in enumerate(original_tensors_obj):
1278+
arr = W_obj.to_ndarray()
1279+
labels = W_obj.get_leg_labels()
1280+
wL_idx = labels.index("wL")
1281+
p_idx = labels.index("p")
1282+
p_star_idx = labels.index("p*")
1283+
wR_idx = labels.index("wR")
1284+
1285+
arr_reordered = arr.transpose((wL_idx, p_idx, p_star_idx, wR_idx))
12811286
if nwires == 1:
1282-
tensor = np.take(tensor, [0], axis=labels.index(vl_label))
1283-
tensor = np.take(tensor, [-1], axis=labels.index(vr_label))
1287+
arr_reordered = arr_reordered[[0], :, :, :]
1288+
arr_reordered = arr_reordered[:, :, :, [-1]]
12841289
else:
12851290
if i == 0:
1286-
tensor = np.take(tensor, [0], axis=labels.index(vl_label))
1291+
arr_reordered = arr_reordered[[0], :, :, :]
12871292
elif i == nwires - 1:
1288-
tensor = np.take(tensor, [-1], axis=labels.index(vr_label))
1289-
modified_tensors.append(tensor)
1290-
1291-
for i, t in enumerate(modified_tensors):
1292-
if t.ndim == 4:
1293-
t = t.transpose((0, 2, 3, 1))
1294-
nodes.append(
1295-
Node(t, name=f"tensor_{i}", axis_names=["wL", "p", "p*", "wR"])
1293+
arr_reordered = arr_reordered[:, :, :, [-1]]
1294+
1295+
node = Node(
1296+
arr_reordered, name=f"mpo_{i}", axis_names=["wL", "p", "p*", "wR"]
12961297
)
1298+
nodes.append(node)
12971299

1298-
for i in range(nwires - 1):
1299-
connect(nodes[i]["wR"], nodes[i + 1]["wL"])
1300+
if nwires > 1:
1301+
for i in range(nwires - 1):
1302+
nodes[i][3] ^ nodes[i + 1][0]
13001303

1301-
out_edges = [node["p*"] for node in nodes]
1302-
in_edges = [node["p"] for node in nodes]
1303-
ignore_edges = [nodes[0]["wL"], nodes[-1]["wR"]]
1304+
out_edges = [n[2] for n in nodes]
1305+
in_edges = [n[1] for n in nodes]
1306+
ignore_edges = [nodes[0][0], nodes[-1][3]]
13041307
else: # MPS
1305-
nodes = []
13061308
for i in range(nwires):
13071309
B_obj = tenpy_obj.get_B(i)
13081310
arr = B_obj.to_ndarray()

0 commit comments

Comments
 (0)