1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from collections import Counter
1516from functools import cached_property
16- from typing import cast , Dict , List , Tuple , Union
17+ from typing import cast , Dict , List , Sequence , Set , Tuple , Union
1718
18- from attrs import evolve , field , frozen , validators
19+ from attrs import field , frozen , validators
1920from numpy .typing import NDArray
2021from typing_extensions import Self
2122
2223from qualtran import (
24+ Bloq ,
2325 bloq_example ,
2426 BloqBuilder ,
2527 BloqDocSpec ,
3840from qualtran .bloqs .mcmt import MultiControlX
3941from qualtran .bloqs .reflections .prepare_identity import PrepareIdentity
4042from qualtran .bloqs .state_preparation .black_box_prepare import BlackBoxPrepare
41- from qualtran .symbolics import is_symbolic , prod , smax , ssum , SymbolicFloat , SymbolicInt
43+ from qualtran .resource_counting import BloqCountT , SympySymbolAllocator
44+ from qualtran .resource_counting .generalizers import ignore_split_join
45+ from qualtran .symbolics import HasLength , is_symbolic , prod , smax , ssum , SymbolicFloat , SymbolicInt
46+ from qualtran .symbolics .math_funcs import is_zero
4247
4348
4449@frozen
@@ -132,6 +137,54 @@ def epsilon(self) -> SymbolicFloat:
132137 def signal_state (self ) -> BlackBoxPrepare :
133138 return BlackBoxPrepare (PrepareIdentity .from_bitsizes ([self .ancilla_bitsize ]))
134139
140+ @property
141+ def anc_part (self ) -> Partition :
142+ n = len (self .block_encodings )
143+ anc_regs = []
144+ if n - 1 > 0 :
145+ anc_regs .append (Register ("flag_bits" , dtype = QBit (), shape = (n - 1 ,)))
146+ anc_bits = self .ancilla_bitsize - (n - 1 )
147+ if not is_zero (anc_bits ):
148+ anc_regs .append (Register ("ancilla" , dtype = QAny (anc_bits )))
149+ return Partition (cast (int , self .ancilla_bitsize ), tuple (anc_regs ))
150+
151+ @property
152+ def constituents (self ) -> Sequence [Bloq ]:
153+ n = len (self .block_encodings )
154+ anc_bits = self .ancilla_bitsize - (n - 1 )
155+ ret = []
156+ for u in reversed (self .block_encodings ):
157+ partition : List [Tuple [Register , List [Union [str , Unused ]]]] = [
158+ (Register ("system" , dtype = QAny (u .system_bitsize )), ["system" ])
159+ ]
160+ if is_symbolic (u .ancilla_bitsize ) or u .ancilla_bitsize > 0 :
161+ regs : List [Union [str , Unused ]] = ["ancilla" ]
162+ if (
163+ is_symbolic (anc_bits )
164+ or is_symbolic (u .ancilla_bitsize )
165+ or anc_bits > u .ancilla_bitsize
166+ ):
167+ regs .append (Unused (anc_bits - u .ancilla_bitsize ))
168+ partition .append ((Register ("ancilla" , dtype = QAny (anc_bits )), regs ))
169+ if not is_zero (u .resource_bitsize ):
170+ regs = ["resource" ]
171+ if is_symbolic (self .resource_bitsize ) or self .resource_bitsize > u .resource_bitsize :
172+ regs .append (Unused (self .resource_bitsize - u .resource_bitsize ))
173+ partition .append ((Register ("resource" , dtype = QAny (u .resource_bitsize )), regs ))
174+ ret .append (AutoPartition (u , partition , left_only = False ))
175+ return ret
176+
177+ def build_call_graph (self , ssa : SympySymbolAllocator ) -> Set [BloqCountT ]:
178+ counts = Counter [Bloq ]()
179+ for bloq in self .constituents :
180+ counts [bloq ] += 1
181+ n = len (self .block_encodings )
182+ for i , u in enumerate (reversed (self .block_encodings )):
183+ if not is_zero (u .ancilla_bitsize ) and n - 1 > 0 and i != n - 1 :
184+ counts [MultiControlX (HasLength (u .ancilla_bitsize ))] += 1
185+ counts [XGate ()] += 1
186+ return set (counts .items ())
187+
135188 def build_composite_bloq (
136189 self , bb : BloqBuilder , system : SoquetT , ** soqs : SoquetT
137190 ) -> Dict [str , SoquetT ]:
@@ -145,14 +198,8 @@ def build_composite_bloq(
145198
146199 if self .ancilla_bitsize > 0 :
147200 # partition ancilla into flag and inner ancilla
148- anc_regs = []
149- if n - 1 > 0 :
150- anc_regs .append (Register ("flag_bits" , dtype = QBit (), shape = (n - 1 ,)))
151201 anc_bits = self .ancilla_bitsize - (n - 1 )
152- if anc_bits > 0 :
153- anc_regs .append (Register ("ancilla" , dtype = QAny (anc_bits )))
154- anc_part = Partition (self .ancilla_bitsize , tuple (anc_regs ))
155- anc_part_soqs = bb .add_d (anc_part , x = soqs .pop ("ancilla" ))
202+ anc_part_soqs = bb .add_d (self .anc_part , x = soqs .pop ("ancilla" ))
156203 if n - 1 > 0 :
157204 flag_bits_soq = cast (NDArray , anc_part_soqs .pop ("flag_bits" ))
158205 if anc_bits > 0 :
@@ -169,22 +216,11 @@ def build_composite_bloq(
169216 assert not is_symbolic (u .ancilla_bitsize )
170217 assert not is_symbolic (u .resource_bitsize )
171218 u_soqs = {"system" : system }
172- partition : List [Tuple [Register , List [Union [str , Unused ]]]] = [
173- (Register ("system" , dtype = QAny (u .system_bitsize )), ["system" ])
174- ]
175219 if u .ancilla_bitsize > 0 :
176220 u_soqs ["ancilla" ] = anc_soq
177- regs : List [Union [str , Unused ]] = ["ancilla" ]
178- if anc_bits > u .ancilla_bitsize :
179- regs .append (Unused (anc_bits - u .ancilla_bitsize ))
180- partition .append ((Register ("ancilla" , dtype = QAny (anc_bits )), regs ))
181221 if u .resource_bitsize > 0 :
182222 u_soqs ["resource" ] = res_soq
183- regs = ["resource" ]
184- if self .resource_bitsize > u .resource_bitsize :
185- regs .append (Unused (self .resource_bitsize - u .resource_bitsize ))
186- partition .append ((Register ("resource" , dtype = QAny (u .resource_bitsize )), regs ))
187- u_out_soqs = bb .add_d (AutoPartition (u , partition , left_only = False ), ** u_soqs )
223+ u_out_soqs = bb .add_d (self .constituents [i ], ** u_soqs )
188224 system = u_out_soqs .pop ("system" )
189225 if u .ancilla_bitsize > 0 :
190226 anc_soq = u_out_soqs .pop ("ancilla" )
@@ -211,11 +247,11 @@ def build_composite_bloq(
211247 anc_soqs ["flag_bits" ] = flag_bits_soq
212248 if anc_bits > 0 :
213249 anc_soqs ["ancilla" ] = anc_soq
214- out ["ancilla" ] = cast (Soquet , bb .add (evolve ( anc_part , partition = False ), ** anc_soqs ))
250+ out ["ancilla" ] = cast (Soquet , bb .add (self . anc_part . adjoint ( ), ** anc_soqs ))
215251 return out
216252
217253
218- @bloq_example
254+ @bloq_example ( generalizer = ignore_split_join )
219255def _product_block_encoding () -> Product :
220256 from qualtran .bloqs .basic_gates import Hadamard , TGate
221257 from qualtran .bloqs .block_encoding .unitary import Unitary
0 commit comments