@@ -1272,37 +1272,39 @@ def tenpy2qop(tenpy_obj: Any) -> QuOperator:
12721272
12731273 nodes = []
12741274 if is_mpo :
1275- vr_label , vl_label = "wR" , "wL"
1276- original_tensors = [W .to_ndarray () for W in tenpy_tensors ]
1277- modified_tensors = []
1275+ original_tensors_obj = tenpy_tensors
12781276
1279- for i , (tensor , tenpy_t ) in enumerate (zip (original_tensors , tenpy_tensors )):
1280- labels = tenpy_t ._labels
1277+ for i , W_obj in enumerate (original_tensors_obj ):
1278+ arr = W_obj .to_ndarray ()
1279+ labels = W_obj .get_leg_labels ()
1280+ wL_idx = labels .index ("wL" )
1281+ p_idx = labels .index ("p" )
1282+ p_star_idx = labels .index ("p*" )
1283+ wR_idx = labels .index ("wR" )
1284+
1285+ arr_reordered = arr .transpose ((wL_idx , p_idx , p_star_idx , wR_idx ))
12811286 if nwires == 1 :
1282- tensor = np . take ( tensor , [ 0 ], axis = labels . index ( vl_label ))
1283- tensor = np . take ( tensor , [- 1 ], axis = labels . index ( vr_label ))
1287+ arr_reordered = arr_reordered [[ 0 ], :, :, :]
1288+ arr_reordered = arr_reordered [:, :, :, [- 1 ]]
12841289 else :
12851290 if i == 0 :
1286- tensor = np . take ( tensor , [ 0 ], axis = labels . index ( vl_label ))
1291+ arr_reordered = arr_reordered [[ 0 ], :, :, :]
12871292 elif i == nwires - 1 :
1288- tensor = np .take (tensor , [- 1 ], axis = labels .index (vr_label ))
1289- modified_tensors .append (tensor )
1290-
1291- for i , t in enumerate (modified_tensors ):
1292- if t .ndim == 4 :
1293- t = t .transpose ((0 , 2 , 3 , 1 ))
1294- nodes .append (
1295- Node (t , name = f"tensor_{ i } " , axis_names = ["wL" , "p" , "p*" , "wR" ])
1293+ arr_reordered = arr_reordered [:, :, :, [- 1 ]]
1294+
1295+ node = Node (
1296+ arr_reordered , name = f"mpo_{ i } " , axis_names = ["wL" , "p" , "p*" , "wR" ]
12961297 )
1298+ nodes .append (node )
12971299
1298- for i in range (nwires - 1 ):
1299- connect (nodes [i ]["wR" ], nodes [i + 1 ]["wL" ])
1300+ if nwires > 1 :
1301+ for i in range (nwires - 1 ):
1302+ nodes [i ][3 ] ^ nodes [i + 1 ][0 ]
13001303
1301- out_edges = [node [ "p*" ] for node in nodes ]
1302- in_edges = [node [ "p" ] for node in nodes ]
1303- ignore_edges = [nodes [0 ]["wL" ], nodes [- 1 ]["wR" ]]
1304+ out_edges = [n [ 2 ] for n in nodes ]
1305+ in_edges = [n [ 1 ] for n in nodes ]
1306+ ignore_edges = [nodes [0 ][0 ], nodes [- 1 ][3 ]]
13041307 else : # MPS
1305- nodes = []
13061308 for i in range (nwires ):
13071309 B_obj = tenpy_obj .get_B (i )
13081310 arr = B_obj .to_ndarray ()
0 commit comments