Skip to content

Commit d5b615c

Browse files
Add symbolic call graph for Product block encoding (#1352)
* Add symbolic call graph for `Product` block encoding * Address comments * Remove partition from counts
1 parent f28481c commit d5b615c

File tree

2 files changed

+65
-25
lines changed

2 files changed

+65
-25
lines changed

qualtran/bloqs/block_encoding/product.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import Counter
1516
from functools import cached_property
16-
from typing import cast, Dict, List, Tuple, Union
17+
from typing import cast, Dict, List, Sequence, Set, Tuple, Union
1718

18-
from attrs import evolve, field, frozen, validators
19+
from attrs import field, frozen, validators
1920
from numpy.typing import NDArray
2021
from typing_extensions import Self
2122

2223
from qualtran import (
24+
Bloq,
2325
bloq_example,
2426
BloqBuilder,
2527
BloqDocSpec,
@@ -38,7 +40,10 @@
3840
from qualtran.bloqs.mcmt import MultiControlX
3941
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
4042
from qualtran.bloqs.state_preparation.black_box_prepare import BlackBoxPrepare
41-
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt
43+
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
44+
from qualtran.resource_counting.generalizers import ignore_split_join
45+
from qualtran.symbolics import HasLength, is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt
46+
from qualtran.symbolics.math_funcs import is_zero
4247

4348

4449
@frozen
@@ -132,6 +137,54 @@ def epsilon(self) -> SymbolicFloat:
132137
def signal_state(self) -> BlackBoxPrepare:
133138
return BlackBoxPrepare(PrepareIdentity.from_bitsizes([self.ancilla_bitsize]))
134139

140+
@property
141+
def anc_part(self) -> Partition:
142+
n = len(self.block_encodings)
143+
anc_regs = []
144+
if n - 1 > 0:
145+
anc_regs.append(Register("flag_bits", dtype=QBit(), shape=(n - 1,)))
146+
anc_bits = self.ancilla_bitsize - (n - 1)
147+
if not is_zero(anc_bits):
148+
anc_regs.append(Register("ancilla", dtype=QAny(anc_bits)))
149+
return Partition(cast(int, self.ancilla_bitsize), tuple(anc_regs))
150+
151+
@property
152+
def constituents(self) -> Sequence[Bloq]:
153+
n = len(self.block_encodings)
154+
anc_bits = self.ancilla_bitsize - (n - 1)
155+
ret = []
156+
for u in reversed(self.block_encodings):
157+
partition: List[Tuple[Register, List[Union[str, Unused]]]] = [
158+
(Register("system", dtype=QAny(u.system_bitsize)), ["system"])
159+
]
160+
if is_symbolic(u.ancilla_bitsize) or u.ancilla_bitsize > 0:
161+
regs: List[Union[str, Unused]] = ["ancilla"]
162+
if (
163+
is_symbolic(anc_bits)
164+
or is_symbolic(u.ancilla_bitsize)
165+
or anc_bits > u.ancilla_bitsize
166+
):
167+
regs.append(Unused(anc_bits - u.ancilla_bitsize))
168+
partition.append((Register("ancilla", dtype=QAny(anc_bits)), regs))
169+
if not is_zero(u.resource_bitsize):
170+
regs = ["resource"]
171+
if is_symbolic(self.resource_bitsize) or self.resource_bitsize > u.resource_bitsize:
172+
regs.append(Unused(self.resource_bitsize - u.resource_bitsize))
173+
partition.append((Register("resource", dtype=QAny(u.resource_bitsize)), regs))
174+
ret.append(AutoPartition(u, partition, left_only=False))
175+
return ret
176+
177+
def build_call_graph(self, ssa: SympySymbolAllocator) -> Set[BloqCountT]:
178+
counts = Counter[Bloq]()
179+
for bloq in self.constituents:
180+
counts[bloq] += 1
181+
n = len(self.block_encodings)
182+
for i, u in enumerate(reversed(self.block_encodings)):
183+
if not is_zero(u.ancilla_bitsize) and n - 1 > 0 and i != n - 1:
184+
counts[MultiControlX(HasLength(u.ancilla_bitsize))] += 1
185+
counts[XGate()] += 1
186+
return set(counts.items())
187+
135188
def build_composite_bloq(
136189
self, bb: BloqBuilder, system: SoquetT, **soqs: SoquetT
137190
) -> Dict[str, SoquetT]:
@@ -145,14 +198,8 @@ def build_composite_bloq(
145198

146199
if self.ancilla_bitsize > 0:
147200
# partition ancilla into flag and inner ancilla
148-
anc_regs = []
149-
if n - 1 > 0:
150-
anc_regs.append(Register("flag_bits", dtype=QBit(), shape=(n - 1,)))
151201
anc_bits = self.ancilla_bitsize - (n - 1)
152-
if anc_bits > 0:
153-
anc_regs.append(Register("ancilla", dtype=QAny(anc_bits)))
154-
anc_part = Partition(self.ancilla_bitsize, tuple(anc_regs))
155-
anc_part_soqs = bb.add_d(anc_part, x=soqs.pop("ancilla"))
202+
anc_part_soqs = bb.add_d(self.anc_part, x=soqs.pop("ancilla"))
156203
if n - 1 > 0:
157204
flag_bits_soq = cast(NDArray, anc_part_soqs.pop("flag_bits"))
158205
if anc_bits > 0:
@@ -169,22 +216,11 @@ def build_composite_bloq(
169216
assert not is_symbolic(u.ancilla_bitsize)
170217
assert not is_symbolic(u.resource_bitsize)
171218
u_soqs = {"system": system}
172-
partition: List[Tuple[Register, List[Union[str, Unused]]]] = [
173-
(Register("system", dtype=QAny(u.system_bitsize)), ["system"])
174-
]
175219
if u.ancilla_bitsize > 0:
176220
u_soqs["ancilla"] = anc_soq
177-
regs: List[Union[str, Unused]] = ["ancilla"]
178-
if anc_bits > u.ancilla_bitsize:
179-
regs.append(Unused(anc_bits - u.ancilla_bitsize))
180-
partition.append((Register("ancilla", dtype=QAny(anc_bits)), regs))
181221
if u.resource_bitsize > 0:
182222
u_soqs["resource"] = res_soq
183-
regs = ["resource"]
184-
if self.resource_bitsize > u.resource_bitsize:
185-
regs.append(Unused(self.resource_bitsize - u.resource_bitsize))
186-
partition.append((Register("resource", dtype=QAny(u.resource_bitsize)), regs))
187-
u_out_soqs = bb.add_d(AutoPartition(u, partition, left_only=False), **u_soqs)
223+
u_out_soqs = bb.add_d(self.constituents[i], **u_soqs)
188224
system = u_out_soqs.pop("system")
189225
if u.ancilla_bitsize > 0:
190226
anc_soq = u_out_soqs.pop("ancilla")
@@ -211,11 +247,11 @@ def build_composite_bloq(
211247
anc_soqs["flag_bits"] = flag_bits_soq
212248
if anc_bits > 0:
213249
anc_soqs["ancilla"] = anc_soq
214-
out["ancilla"] = cast(Soquet, bb.add(evolve(anc_part, partition=False), **anc_soqs))
250+
out["ancilla"] = cast(Soquet, bb.add(self.anc_part.adjoint(), **anc_soqs))
215251
return out
216252

217253

218-
@bloq_example
254+
@bloq_example(generalizer=ignore_split_join)
219255
def _product_block_encoding() -> Product:
220256
from qualtran.bloqs.basic_gates import Hadamard, TGate
221257
from qualtran.bloqs.block_encoding.unitary import Unitary

qualtran/bloqs/block_encoding/product_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from qualtran.bloqs.state_preparation.black_box_prepare import BlackBoxPrepare
4545
from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle
4646
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
47-
from qualtran.testing import execute_notebook
47+
from qualtran.testing import assert_equivalent_bloq_example_counts, execute_notebook
4848

4949

5050
def test_product(bloq_autotester):
@@ -215,6 +215,10 @@ def test_product_signal_state():
215215
_ = _product_block_encoding().signal_state.decompose_bloq()
216216

217217

218+
def test_product_counts():
219+
assert_equivalent_bloq_example_counts(_product_block_encoding)
220+
221+
218222
@pytest.mark.notebook
219223
def test_notebook():
220224
execute_notebook('product')

0 commit comments

Comments
 (0)