Skip to content

Commit 488b9a5

Browse files
authored
Change build_call_graph in bloqs to return dict (#1392)
* Change build_call_graph in bloqs to return dict - This changes the build_call_graph function within bloqs in qualtran to return a dictionary of cost counts rather than a set. - This will allow the ordering of cost counts to be deterministic Note that this requires some slight code changes for bloqs that have multiple set items since (Toffoli(), 1) and (Toffoli(), 2) would have two different items in a set, but share an index in the dictionary. This also may alter counts (i.e. fix a bug) where set items clobber each other. For instance, adding (Toffoli(), self.bits_a) and (Toffoli(), self.bits_b) will previously give the wrong count if bits_a == bits_b since the two items would be the same in the set.
1 parent 098f7ea commit 488b9a5

File tree

100 files changed

+768
-801
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

100 files changed

+768
-801
lines changed

qualtran/_infra/composite_bloq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
from qualtran.bloqs.bookkeeping.auto_partition import Unused
5353
from qualtran.cirq_interop._cirq_to_bloq import CirqQuregInT, CirqQuregT
54-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
54+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
5555
from qualtran.simulation.classical_sim import ClassicalValT
5656

5757
# NDArrays must be bound to np.generic
@@ -237,7 +237,7 @@ def decompose_bloq(self) -> 'CompositeBloq':
237237
"Consider using the composite bloq directly or using `.flatten()`."
238238
)
239239

240-
def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']:
240+
def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> 'BloqCountDictT':
241241
"""Return the bloq counts by counting up all the subbloqs."""
242242
from qualtran.resource_counting import build_cbloq_call_graph
243243

qualtran/bloqs/arithmetic/_shims.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919
will be fleshed out and moved to their final organizational location soon (written: 2024-05-06).
2020
"""
2121
from functools import cached_property
22-
from typing import Set
2322

2423
from attrs import frozen
2524

2625
from qualtran import Bloq, QBit, QUInt, Register, Signature
2726
from qualtran.bloqs.basic_gates import Toffoli
28-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
27+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
2928

3029

3130
@frozen
@@ -36,8 +35,8 @@ class MultiCToffoli(Bloq):
3635
def signature(self) -> 'Signature':
3736
return Signature([Register('ctrl', QBit(), shape=(self.n,)), Register('target', QBit())])
3837

39-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
40-
return {(Toffoli(), self.n - 2)}
38+
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
39+
return {Toffoli(): self.n - 2}
4140

4241

4342
@frozen
@@ -51,9 +50,9 @@ def signature(self) -> 'Signature':
5150
[Register('x', QUInt(self.n)), Register('y', QUInt(self.n)), Register('out', QBit())]
5251
)
5352

54-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
53+
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
5554
# litinski
56-
return {(Toffoli(), self.n)}
55+
return {Toffoli(): self.n}
5756

5857

5958
@frozen

qualtran/bloqs/arithmetic/addition.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@
6262

6363
if TYPE_CHECKING:
6464
from qualtran.drawing import WireSymbol
65-
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
65+
from qualtran.resource_counting import (
66+
BloqCountDictT,
67+
BloqCountT,
68+
MutableBloqCountDictT,
69+
SympySymbolAllocator,
70+
)
6671
from qualtran.simulation.classical_sim import ClassicalValT
6772
from qualtran.symbolics import SymbolicInt
6873

@@ -209,10 +214,10 @@ def decompose_from_registers(
209214
yield CNOT().on(input_bits[0], output_bits[0])
210215
context.qubit_manager.qfree(ancillas)
211216

212-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
217+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
213218
n = self.b_dtype.bitsize
214219
n_cnot = (n - 2) * 6 + 3
215-
return {(And(), n - 1), (And().adjoint(), n - 1), (CNOT(), n_cnot)}
220+
return {And(): n - 1, And().adjoint(): n - 1, CNOT(): n_cnot}
216221

217222

218223
@bloq_example(generalizer=ignore_split_join)
@@ -327,8 +332,8 @@ def decompose_from_registers(
327332
]
328333
return cirq.inverse(optree) if self.is_adjoint else optree
329334

330-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
331-
return {(And(uncompute=self.is_adjoint), self.bitsize), (CNOT(), 5 * self.bitsize)}
335+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
336+
return {And(uncompute=self.is_adjoint): self.bitsize, CNOT(): 5 * self.bitsize}
332337

333338
def __pow__(self, power: int):
334339
if power == 1:
@@ -505,16 +510,16 @@ def build_composite_bloq(
505510
def build_call_graph(
506511
self, ssa: 'SympySymbolAllocator'
507512
) -> Union['BloqCountDictT', Set['BloqCountT']]:
508-
loading_cost: Tuple[Bloq, SymbolicInt]
513+
loading_cost: MutableBloqCountDictT
509514
if len(self.cvs) == 0:
510-
loading_cost = (XGate(), self.bitsize) # upper bound; depends on the data.
515+
loading_cost = {XGate(): self.bitsize} # upper bound; depends on the data.
511516
elif len(self.cvs) == 1:
512-
loading_cost = (CNOT(), self.bitsize) # upper bound; depends on the data.
517+
loading_cost = {CNOT(): self.bitsize} # upper bound; depends on the data.
513518
else:
514519
# Otherwise, use the decomposition
515520
return super().build_call_graph(ssa=ssa)
516-
517-
return {loading_cost, (Add(QUInt(self.bitsize)), 1)}
521+
loading_cost[Add(QUInt(self.bitsize))] = 1
522+
return loading_cost
518523

519524
def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']:
520525
if self.cvs:

qualtran/bloqs/arithmetic/bitwise.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from qualtran.symbolics import is_symbolic, SymbolicInt
4040

4141
if TYPE_CHECKING:
42-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
42+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
4343
from qualtran.simulation.classical_sim import ClassicalValT
4444

4545

@@ -90,9 +90,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq
9090

9191
return {'x': x}
9292

93-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
93+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
9494
num_flips = self.bitsize if self.is_symbolic() else sum(self._bits_k)
95-
return {(XGate(), num_flips)}
95+
return {XGate(): num_flips}
9696

9797
def on_classical_vals(self, x: 'ClassicalValT') -> dict[str, 'ClassicalValT']:
9898
if isinstance(self.k, sympy.Expr):
@@ -156,8 +156,8 @@ def build_composite_bloq(self, bb: BloqBuilder, x: Soquet, y: Soquet) -> dict[st
156156

157157
return {'x': bb.join(xs, dtype=self.dtype), 'y': bb.join(ys, dtype=self.dtype)}
158158

159-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
160-
return {(CNOT(), self.dtype.num_qubits)}
159+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
160+
return {CNOT(): self.dtype.num_qubits}
161161

162162
def on_classical_vals(
163163
self, x: 'ClassicalValT', y: 'ClassicalValT'

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 51 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,7 @@
1414

1515
from collections import defaultdict
1616
from functools import cached_property
17-
from typing import (
18-
Dict,
19-
Iterable,
20-
Iterator,
21-
List,
22-
Optional,
23-
Sequence,
24-
Set,
25-
Tuple,
26-
TYPE_CHECKING,
27-
Union,
28-
)
17+
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
2918

3019
import attrs
3120
import cirq
@@ -65,7 +54,11 @@
6554

6655
if TYPE_CHECKING:
6756
from qualtran import BloqBuilder
68-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
57+
from qualtran.resource_counting import (
58+
BloqCountDictT,
59+
MutableBloqCountDictT,
60+
SympySymbolAllocator,
61+
)
6962
from qualtran.simulation.classical_sim import ClassicalValT
7063

7164

@@ -183,22 +176,22 @@ def decompose_from_registers(
183176
def _has_unitary_(self):
184177
return True
185178

186-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
179+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
187180
if (
188181
not is_symbolic(self.less_than_val, self.bitsize)
189182
and self.less_than_val >= 2**self.bitsize
190183
):
191-
return {(XGate(), 1)}
184+
return {XGate(): 1}
192185
num_set_bits = (
193186
int(self.less_than_val).bit_count()
194187
if not is_symbolic(self.less_than_val)
195188
else self.bitsize
196189
)
197190
return {
198-
(And(), self.bitsize),
199-
(And().adjoint(), self.bitsize),
200-
(CNOT(), num_set_bits + 2 * self.bitsize),
201-
(XGate(), 2 * (1 + num_set_bits)),
191+
And(): self.bitsize,
192+
And().adjoint(): self.bitsize,
193+
CNOT(): num_set_bits + 2 * self.bitsize,
194+
XGate(): 2 * (1 + num_set_bits),
202195
}
203196

204197

@@ -307,8 +300,8 @@ def __pow__(self, power: int) -> 'BiQubitsMixer':
307300
return self.adjoint()
308301
return NotImplemented # pragma: no cover
309302

310-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
311-
return {(XGate(), 1), (CNOT(), 9), (And(uncompute=self.is_adjoint), 2)}
303+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
304+
return {XGate(): 1, CNOT(): 9, And(uncompute=self.is_adjoint): 2}
312305

313306
def _has_unitary_(self):
314307
return not self.is_adjoint
@@ -380,8 +373,8 @@ def __pow__(self, power: int) -> Union['SingleQubitCompare', cirq.Gate]:
380373
return self.adjoint()
381374
return self
382375

383-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
384-
return {(XGate(), 1), (CNOT(), 4), (And(uncompute=self.is_adjoint), 1)}
376+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
377+
return {XGate(): 1, CNOT(): 4, And(uncompute=self.is_adjoint): 1}
385378

386379

387380
@bloq_example
@@ -575,13 +568,13 @@ def decompose_from_registers(
575568
all_ancilla = set([q for op in adjoint for q in op.qubits if q not in input_qubits])
576569
context.qubit_manager.qfree(all_ancilla)
577570

578-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
571+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
579572
if is_symbolic(self.x_bitsize, self.y_bitsize):
580573
return {
581-
(BiQubitsMixer(), self.x_bitsize),
582-
(BiQubitsMixer().adjoint(), self.x_bitsize),
583-
(SingleQubitCompare(), 1),
584-
(SingleQubitCompare().adjoint(), 1),
574+
BiQubitsMixer(): self.x_bitsize,
575+
BiQubitsMixer().adjoint(): self.x_bitsize,
576+
SingleQubitCompare(): 1,
577+
SingleQubitCompare().adjoint(): 1,
585578
}
586579

587580
n = min(self.x_bitsize, self.y_bitsize)
@@ -613,7 +606,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
613606
ret[And(1, 0).adjoint()] += 1
614607
ret[CNOT()] += 1
615608

616-
return set(ret.items())
609+
return ret
617610

618611
def _has_unitary_(self):
619612
return True
@@ -691,8 +684,8 @@ def build_composite_bloq(
691684
target = bb.add(XGate(), q=target)
692685
return {'a': a, 'b': b, 'target': target}
693686

694-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
695-
return {(LessThanEqual(self.a_bitsize, self.b_bitsize), 1), (XGate(), 1)}
687+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
688+
return {LessThanEqual(self.a_bitsize, self.b_bitsize): 1, XGate(): 1}
696689

697690

698691
@bloq_example
@@ -885,23 +878,23 @@ def wire_symbol(
885878
return TextBox('t⨁(a>b)')
886879
raise ValueError(f'Unknown register name {reg.name}')
887880

888-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
881+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
889882
if self.bitsize == 1:
890-
return {(MultiControlX(cvs=(1, 0)), 1)}
883+
return {MultiControlX(cvs=(1, 0)): 1}
891884

892885
if self.signed:
893886
return {
894-
(CNOT(), 6 * self.bitsize - 7),
895-
(XGate(), 2 * self.bitsize + 2),
896-
(And(), self.bitsize - 1),
897-
(And(uncompute=True), self.bitsize - 1),
887+
CNOT(): 6 * self.bitsize - 7,
888+
XGate(): 2 * self.bitsize + 2,
889+
And(): self.bitsize - 1,
890+
And(uncompute=True): self.bitsize - 1,
898891
}
899892

900893
return {
901-
(CNOT(), 6 * self.bitsize - 1),
902-
(XGate(), 2 * self.bitsize + 4),
903-
(And(), self.bitsize),
904-
(And(uncompute=True), self.bitsize),
894+
CNOT(): 6 * self.bitsize - 1,
895+
XGate(): 2 * self.bitsize + 4,
896+
And(): self.bitsize,
897+
And(uncompute=True): self.bitsize,
905898
}
906899

907900

@@ -941,8 +934,8 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
941934
return TextBox(f"⨁(x > {self.val})")
942935
raise ValueError(f'Unknown register symbol {reg.name}')
943936

944-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
945-
return {(LessThanConstant(self.bitsize, less_than_val=self.val), 1)}
937+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
938+
return {LessThanConstant(self.bitsize, less_than_val=self.val): 1}
946939

947940

948941
@bloq_example
@@ -1007,8 +1000,8 @@ def build_composite_bloq(
10071000
x = bb.join(xs)
10081001
return {'x': x, 'target': target}
10091002

1010-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
1011-
return {(MultiControlX(self.bits_k), 1)}
1003+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
1004+
return {MultiControlX(self.bits_k): 1}
10121005

10131006

10141007
def _make_equals_a_constant():
@@ -1134,21 +1127,22 @@ def on_classical_vals(
11341127
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target ^ (a > b)}
11351128
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target}
11361129

1137-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
1138-
signed_ops = []
1130+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
1131+
signed_ops: 'MutableBloqCountDictT' = {}
11391132
if isinstance(self.dtype, QInt):
1140-
signed_ops = [
1141-
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), 2),
1142-
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), 2),
1143-
]
1133+
signed_ops = {
1134+
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2,
1135+
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2,
1136+
}
11441137
dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1)
11451138
return {
1146-
(BitwiseNot(dtype), 2),
1147-
(BitwiseNot(QUInt(dtype.bitsize + 1)), 2),
1148-
(OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(), 1),
1149-
(OutOfPlaceAdder(self.dtype.bitsize + 1), 1),
1150-
(MultiControlX((self.cv, 1)), 1),
1151-
}.union(signed_ops)
1139+
BitwiseNot(dtype): 2,
1140+
BitwiseNot(QUInt(dtype.bitsize + 1)): 2,
1141+
OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(): 1,
1142+
OutOfPlaceAdder(self.dtype.bitsize + 1): 1,
1143+
MultiControlX((self.cv, 1)): 1,
1144+
**signed_ops,
1145+
}
11521146

11531147

11541148
@bloq_example(generalizer=ignore_split_join)

qualtran/bloqs/arithmetic/controlled_addition.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, Set, TYPE_CHECKING, Union
15+
from typing import Dict, TYPE_CHECKING, Union
1616

1717
import numpy as np
1818
import sympy
@@ -42,7 +42,7 @@
4242
import quimb.tensor as qtn
4343

4444
from qualtran.drawing import WireSymbol
45-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
45+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
4646
from qualtran.simulation.classical_sim import ClassicalValT
4747

4848

@@ -155,11 +155,11 @@ def build_composite_bloq(
155155
ctrl = bb.join(np.array([ctrl_q]))
156156
return {'ctrl': ctrl, 'a': a, 'b': b}
157157

158-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
158+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
159159
return {
160-
(And(self.cv, 1), self.a_dtype.bitsize),
161-
(Add(self.a_dtype, self.b_dtype), 1),
162-
(And(self.cv, 1).adjoint(), self.a_dtype.bitsize),
160+
And(self.cv, 1): self.a_dtype.bitsize,
161+
Add(self.a_dtype, self.b_dtype): 1,
162+
And(self.cv, 1).adjoint(): self.a_dtype.bitsize,
163163
}
164164

165165

0 commit comments

Comments
 (0)