Skip to content
176 changes: 136 additions & 40 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import abc
import copy
import enum
import html
import itertools
Expand All @@ -35,7 +36,6 @@
Mapping,
MutableSequence,
Sequence,
Set,
)
from types import NotImplementedType
from typing import Any, cast, overload, Self, TYPE_CHECKING, TypeVar, Union
Expand Down Expand Up @@ -1335,10 +1335,10 @@ def _is_parameterized_(self) -> bool:
protocols.is_parameterized(tag) for tag in self.tags
)

def _parameter_names_(self) -> Set[str]:
def _parameter_names_(self) -> frozenset[str]:
op_params = {name for op in self.all_operations() for name in protocols.parameter_names(op)}
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
return op_params | tag_params
return frozenset(op_params | tag_params)

def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Self:
changed = False
Expand Down Expand Up @@ -1839,7 +1839,7 @@ def __init__(
self._frozen: cirq.FrozenCircuit | None = None
self._is_measurement: bool | None = None
self._is_parameterized: bool | None = None
self._parameter_names: Set[str] | None = None
self._parameter_names: frozenset[str] | None = None
if not contents:
return
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
Expand Down Expand Up @@ -1948,7 +1948,7 @@ def _is_parameterized_(self) -> bool:
self._is_parameterized = super()._is_parameterized_()
return self._is_parameterized

def _parameter_names_(self) -> Set[str]:
def _parameter_names_(self) -> frozenset[str]:
if self._parameter_names is None:
self._parameter_names = super()._parameter_names_()
return self._parameter_names
Expand All @@ -1957,10 +1957,30 @@ def copy(self) -> Circuit:
"""Return a copy of this circuit."""
copied_circuit = Circuit()
copied_circuit._moments[:] = self._moments
copied_circuit._placement_cache = None
copied_circuit._tags = self.tags
copied_circuit._all_qubits = self._all_qubits
copied_circuit._frozen = self._frozen
copied_circuit._is_measurement = self._is_measurement
copied_circuit._is_parameterized = self._is_parameterized
copied_circuit._parameter_names = self._parameter_names
copied_circuit._placement_cache = copy.copy(self._placement_cache)
return copied_circuit

def _copy_from_shallow(self, other: Circuit) -> None:
"""Copies the contents of another circuit into this one.

This performs a shallow copy from another circuit. It is primarily intended for reimporting
data from temporary copies that were created during multistep mutations to allow them to be
performed atomically."""
self._moments = other._moments
self._tags = other.tags
self._all_qubits = other._all_qubits
self._frozen = other._frozen
self._is_measurement = other._is_measurement
self._is_parameterized = other._is_parameterized
self._parameter_names = other._parameter_names
self._placement_cache = other._placement_cache

@overload
def __setitem__(self, key: int, value: cirq.Moment):
pass
Expand Down Expand Up @@ -2002,7 +2022,7 @@ def __radd__(self, other):
return NotImplemented
# Auto wrap OP_TREE inputs into a circuit.
result = self.copy()
result._moments[:0] = Circuit(other)._moments
result._insert_moments(0, *Circuit(other)._moments)
return result

# Needed for numpy to handle multiplication by np.int64 correctly.
Expand All @@ -2011,19 +2031,19 @@ def __radd__(self, other):
def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
num_moments_added = len(self._moments) * (repetitions - 1)
self._moments *= int(repetitions)
self._mutated()
if self._placement_cache:
# Shift everything `num_moments_added` to the right.
self._placement_cache.insert_moments(0, num_moments_added)
self._frozen = None # All other cache values are resilient to mul.
return self

def __mul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return Circuit(self._moments * int(repetitions), tags=self.tags)
return self.copy().__imul__(repetitions)

def __rmul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return self * int(repetitions)
return self.copy().__imul__(repetitions)

def __pow__(self, exponent: int) -> cirq.Circuit:
"""A circuit raised to a power, only valid for exponent -1, the inverse.
Expand Down Expand Up @@ -2186,10 +2206,9 @@ def insert(
"""
# limit index to 0..len(self._moments), also deal with indices smaller 0
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
if strategy != InsertStrategy.EARLIEST or k != len(self._moments):
self._placement_cache = None
appending = strategy == InsertStrategy.EARLIEST and k == len(self._moments)
mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree))
if self._placement_cache:
if self._placement_cache and appending:
batches = [mops] # Any grouping would work here; this just happens to be the fastest.
elif strategy is InsertStrategy.NEW:
batches = [[mop] for mop in mops] # Each op goes into its own moment.
Expand All @@ -2198,7 +2217,7 @@ def insert(
for batch in batches:
# Insert a moment if inline/earliest and _any_ op in the batch requires it.
if (
not self._placement_cache
not appending
and not isinstance(batch[0], Moment)
and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST)
and not all(
Expand All @@ -2207,39 +2226,86 @@ def insert(
for op in cast(list[cirq.Operation], batch)
)
):
self._moments.insert(k, Moment())
self._insert_moments(k)
if strategy is InsertStrategy.INLINE:
k += 1
max_p = 0
for moment_or_op in batch:
# Determine Placement
if self._placement_cache:
cache_updated = False
if self._placement_cache and appending:
# This updates the cache and returns placement in a single step. It would be
# cleaner to "check" placement here and avoid the special `skip_cache_update`
# args below, but that adds about 15% latency to this perf-critical case.
p = self._placement_cache.append(moment_or_op)
cache_updated = True
elif isinstance(moment_or_op, Moment):
p = k
elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE):
self._moments.insert(k, Moment())
self._insert_moments(k)
p = k
elif strategy is InsertStrategy.INLINE:
p = k - 1
else: # InsertStrategy.EARLIEST:
p = self.earliest_available_moment(moment_or_op, end_moment_index=k)
# Place
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
elif p == len(self._moments):
self._moments.append(Moment(moment_or_op))
self._insert_moments(p, moment_or_op, skip_cache_update=cache_updated)
else:
self._moments[p] = self._moments[p].with_operation(moment_or_op)
self._put_ops(p, moment_or_op, skip_cache_update=cache_updated)
# Iterate
max_p = max(p, max_p)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
k += 1
k = max(k, max_p + 1)
self._mutated(preserve_placement_cache=True)
return k

def _insert_moments(
self, index: int, *moments: Moment, count: int = 1, skip_cache_update: bool = False
):
"""Inserts moments directly before circuit[index] and updates caches.

Args:
index: The moment index to insert the moment. If greater than the circuit length, the
moments will be appended.
moments: The moments to insert. If none are provided, a single empty moment will be
assumed.
count: The number of moments to insert. If both `moments` and `count` are provided,
the provided moments will be inserted `count` times.
skip_cache_update: Skips updates to the placement cache. Only use if the placement cache
has already been updated.
"""
if not moments:
moments = (Moment(),)
moments *= count
self._moments[index:index] = moments
if self._placement_cache and not skip_cache_update:
self._placement_cache.insert_moments(index, len(moments))
for i, m in enumerate(moments):
self._placement_cache.put(index + i, m)
self._mutated(preserve_placement_cache=True)

def _put_ops(self, index: int, *ops: cirq.Operation, skip_cache_update: bool = False):
"""Adds operations directly to circuit[index] and updates caches.

This is intended to be low-level and will fail if the moment does not exist or already has
conflicting operations.

Args:
index: The moment index to add operations to.
ops: The operations to add.
skip_cache_update: Skips updates to the placement cache. Only use if the placement cache
has already been updated.
"""
if index == len(self._moments):
self._moments.append(Moment.from_ops(*ops))
else:
self._moments[index] = self._moments[index].with_operations(*ops)
if self._placement_cache and not skip_cache_update:
self._placement_cache.put(index, *ops)
self._mutated(preserve_placement_cache=True)

def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int:
"""Writes operations inline into an area of the circuit.

Expand Down Expand Up @@ -2272,9 +2338,8 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i
if i >= end:
break

self._moments[i] = self._moments[i].with_operation(op)
self._put_ops(i, op)
op_index += 1
self._mutated()

if op_index >= len(flat_ops):
return end
Expand Down Expand Up @@ -2319,8 +2384,7 @@ def _push_frontier(
)
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
self._mutated()
self._insert_moments(insert_index, count=n_new_moments)
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2346,13 +2410,12 @@ def _insert_operations(
"""
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._mutated()
self._insert_moments(len(self), count=1 + max(insertion_indices) - len(self))
moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
self._put_ops(moment_index, *new_ops)

def insert_at_frontier(
self, operations: cirq.OP_TREE, start: int, frontier: dict[cirq.Qid, int] | None = None
Expand Down Expand Up @@ -2455,9 +2518,8 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) ->
"""
copy = self.copy()
for i, insertions in insert_intos:
copy._moments[i] = copy._moments[i].with_operations(insertions)
self._moments = copy._moments
self._mutated()
copy._put_ops(i, *ops.flatten_to_ops(insertions))
self._copy_from_shallow(copy)

def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None:
"""Applies a batched insert operation to the circuit.
Expand Down Expand Up @@ -2491,8 +2553,7 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None:
next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST)
if next_index > insert_index:
shift += next_index - insert_index
self._moments = copy._moments
self._mutated()
self._copy_from_shallow(copy)

def append(
self,
Expand Down Expand Up @@ -2537,8 +2598,8 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit:
"""Creates a new tagged `Circuit` with `self.tags` and `new_tags` combined."""
if not new_tags:
return self
new_circuit = Circuit(tags=self.tags + new_tags)
new_circuit._moments[:] = self._moments
new_circuit = self.copy()
new_circuit._tags = self.tags + new_tags
return new_circuit

def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit:
Expand Down Expand Up @@ -3062,3 +3123,38 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int:
)
self._length = max(self._length, index + 1)
return index

def insert_moments(self, index: int, count: int = 1) -> None:
"""Updates cache to account for empty moments inserted at circuit[index]."""
self._insert_moments(self._qubit_indices, index, count)
self._insert_moments(self._mkey_indices, index, count)
self._insert_moments(self._ckey_indices, index, count)
self._length += count

def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None:
"""Updates cache to account for ops added to circuit[index]."""
for mop in moments_or_operations:
self._put(self._qubit_indices, mop.qubits, index)
self._put(self._mkey_indices, protocols.measurement_key_objs(mop), index)
self._put(self._ckey_indices, protocols.control_keys(mop), index)
self._length = max(self._length, index + 1)

@staticmethod
def _put(key_indices: dict[_TKey, int], mop_keys: Iterable[_TKey], mop_index: int) -> None:
for key in mop_keys:
key_indices[key] = max(mop_index, key_indices.get(key, -1))

@staticmethod
def _insert_moments(key_indices: dict[_TKey, int], index: int, count: int) -> None:
for key in key_indices:
key_index = key_indices[key]
if key_index >= index:
key_indices[key] = key_index + count

def __copy__(self) -> _PlacementCache:
cache = _PlacementCache()
cache._qubit_indices = self._qubit_indices.copy()
cache._mkey_indices = self._mkey_indices.copy()
cache._ckey_indices = self._ckey_indices.copy()
cache._length = self._length
return cache
Loading