Skip to content
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7edcdfe
E.C.
PietropaoloFrisoni Nov 4, 2025
66b0b99
Merge branch 'eclipse-qrisp:main' into dynamic_state_preparation
PietropaoloFrisoni Nov 10, 2025
5f61864
First raw version
PietropaoloFrisoni Nov 10, 2025
0563a94
Avoid zero division
PietropaoloFrisoni Nov 10, 2025
634e710
First generalized version (still needs to be optimized)
PietropaoloFrisoni Nov 14, 2025
bbbe0f2
Clarification plus sostitution of `qswitch_sequential` with `qswitch`
PietropaoloFrisoni Nov 17, 2025
e7bfc8e
[ci-skip]
PietropaoloFrisoni Nov 17, 2025
aa2ef07
Minimal improvement to the `qswitch` function
PietropaoloFrisoni Nov 17, 2025
420422e
Docstring to `qswitch_sequential`
PietropaoloFrisoni Nov 17, 2025
5925c62
Outdated documentation
PietropaoloFrisoni Nov 18, 2025
2aaceb3
Merge branch 'eclipse-qrisp:main' into dynamic_state_preparation
PietropaoloFrisoni Nov 18, 2025
6c9e9ab
First bunch of tests
PietropaoloFrisoni Nov 19, 2025
6e6a3e4
isort
PietropaoloFrisoni Nov 19, 2025
37cc526
Working on capturable algorithm version
PietropaoloFrisoni Nov 20, 2025
91ff37d
Separating preprocessing step and tests for jax
PietropaoloFrisoni Nov 21, 2025
cae901a
Unifying implementations
PietropaoloFrisoni Nov 21, 2025
b221a8c
Adding a `jasp_bit_reverse` function
PietropaoloFrisoni Nov 24, 2025
ad2d4f7
Moving implementation into different file
PietropaoloFrisoni Nov 25, 2025
481cca7
isort on init file
PietropaoloFrisoni Nov 25, 2025
6d7d283
Docstring with explanations
PietropaoloFrisoni Nov 25, 2025
7a7f9d8
Re-using function to normalize vector and improving names
PietropaoloFrisoni Nov 25, 2025
239b5d5
Changed one name
PietropaoloFrisoni Nov 25, 2025
22bac43
Better names and less local variables
PietropaoloFrisoni Nov 26, 2025
92b328f
Allowing `n=1` and removing special logic for the latter case
PietropaoloFrisoni Nov 26, 2025
cf5a2d7
more JASP tests
PietropaoloFrisoni Nov 26, 2025
2010229
Better names and pylint suggestions
PietropaoloFrisoni Nov 26, 2025
f3bdf96
Docstring(s)
PietropaoloFrisoni Nov 27, 2025
7a2c986
Documentation and removing warnings from JAX
PietropaoloFrisoni Nov 27, 2025
dc44eb1
Fixing typo in comment
PietropaoloFrisoni Nov 27, 2025
4061c3d
Replacing 2** with 1 << everywhere
PietropaoloFrisoni Nov 30, 2025
b42d4a9
Dynamic `jrange` loop in state preparaton
PietropaoloFrisoni Dec 2, 2025
66b592c
Triggering CI
PietropaoloFrisoni Dec 2, 2025
56d3a19
Increasing timeout minutes on CI
PietropaoloFrisoni Dec 3, 2025
74d49c4
Trying to increase time even more and moving some imports inside func…
PietropaoloFrisoni Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _state_preparation:

Quantum State Preparation
=========================

.. currentmodule:: qrisp

.. autofunction:: state_preparation
1 change: 1 addition & 0 deletions src/qrisp/alg_primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
from qrisp.alg_primitives.prepare import *
from qrisp.alg_primitives.iterative_qpe import *
from qrisp.alg_primitives.reflection import *
from qrisp.alg_primitives.state_preparation import *
336 changes: 336 additions & 0 deletions src/qrisp/alg_primitives/state_preparation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

from typing import Callable

import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

from qrisp.core.quantum_variable import QuantumVariable

EPSILON = jnp.sqrt(jnp.finfo(jnp.float64).eps)


def _rot_params_from_state(vec: jnp.ndarray) -> tuple:
"""
Computes the rotation angles to prepare a single qubit state,
where the amplitude of the |0> basis state is real and non-negative.

Specifically, it computes the angles ``theta``, ``phi``, and ``lambda``
such that applying the U3 gate with these angles to the |0> state results in the desired state:

|0> → a|0> + b|1>, with a real ≥ 0.

Parameters
----------
vec : jnp.ndarray
A 2-dimensional complex vector representing a qubit state.

Returns
-------
theta : float
The rotation angle theta.
phi : float
The rotation angle phi.
lam : float
The rotation angle lambda.
"""
a, b = vec
# We know that a is real (and non-negative).
# This step avoids warning about casting complex to real.
a = jnp.clip(jnp.real(a), -1.0, 1.0)
theta = 2.0 * jnp.arccos(a)
phi = jnp.where(jnp.abs(b) > EPSILON, jnp.angle(b), 0.0)
lam = 0.0
return theta, phi, lam


def _normalize_with_phase(
v: jnp.ndarray, acc: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
Normalizes a given vector and adjusts its phase.

The phase of the first element of the vector is removed and added to the accumulated phase.
The vector is normalized to have a unit norm and the first element is ensured to be real and non-negative.

Parameters
----------
v : jnp.ndarray
The child vector to normalize.
acc : jnp.ndarray
The accumulated phase from previous operations.

Returns
-------
norm : jnp.ndarray
The norm of the input vector.
v_normalized : jnp.ndarray
The normalized vector with adjusted phase.
updated_acc : jnp.ndarray
The updated accumulated phase.
"""

norm = jnp.linalg.norm(v)

def branch_nonzero(_):
alpha = jnp.angle(v[0])
v_normalized = v / (norm * jnp.exp(1j * alpha))
return norm, v_normalized, acc + alpha

def branch_zero(_):
# If the norm is zero, we return a default normalized vector
# with the first element real and non-negative.
v0 = jnp.where(jnp.real(v[0]) < 0, -v[0], v[0])
v_adj = v.at[0].set(v0)
return norm, v_adj, acc

return lax.cond(
norm > EPSILON,
lambda _: branch_nonzero(None),
lambda _: branch_zero(None),
operand=None,
)


def _compute_thetas(
vec: jnp.ndarray, acc: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
For a given input vector, this function computes the rotation angles
needed for the uniformly controlled RY at this tree layer, normalizes its child vectors,
and updates the accumulated phases for each child vector.

Parameters
----------
vec : jnp.ndarray
A complex vector representing the current vector to process.
acc : jnp.ndarray
The accumulated phase from previous operations.


Returns
-------
theta : jnp.ndarray
The angle (scalar array) for the ry rotation gate.
subvecs : jnp.ndarray
A 2D array where each row corresponds to a normalized subvector.
acc_phases : jnp.ndarray
A 1D array containing the updated accumulated phases for each subvector.

"""

len_vec = vec.shape[0]
half = len_vec // 2

v0 = vec[:half]
v1 = vec[half:]

n0, v0n, acc0 = _normalize_with_phase(v0, acc)
_, v1n, acc1 = _normalize_with_phase(v1, acc)

theta = 2.0 * jnp.arccos(jnp.minimum(1.0, n0)) # shape ()
subvecs = jnp.stack([v0n, v1n], axis=0) # shape (2, half)
acc_phases = jnp.stack([acc0, acc1], axis=0) # shape (2,)

return theta, subvecs, acc_phases


def _compute_u3_params(
qubit_vec: jnp.ndarray, acc: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""
For a given length-2 vector, this function computes the U3 gate parameters needed
to prepare the corresponding state, normalizes the vector, and updates the accumulated phase.

Parameters
----------
qubit_vec : jnp.ndarray
A complex vector representing a one-qubit state.
acc : jnp.ndarray
The accumulated phase from previous operations.

Returns
-------
u_params : jnp.ndarray
A 1D array containing the rotation angles (theta, phi, lambda) for the U3 gate.
total_phase : jnp.ndarray
The updated accumulated phase after processing the leaf subvector.

"""

_, vec_n, total_phase = _normalize_with_phase(qubit_vec, acc)
theta, phi, lam = _rot_params_from_state(vec_n)
return jnp.array([theta, phi, lam]), total_phase


# Here is the explanation of the data structures used in the state preparation algorithm:
#
# - `thetas`` has shape (n - 1, 2^(n-1)), and contains the ry rotation angles for each layer:
#
# thetas = Array[[theta_0_0, 0, 0, 0, ..., 0], # layer 0
# [theta_1_0, theta_1_1, 0, 0, ..., 0], # layer 1
# [theta_2_0, theta_2_1, theta_2_2, theta_2_3, ..., 0], # layer 2
# ...
# [theta_{n-2}_0, theta_{n-2}_1, ..., theta_{n-2}_{2^(n-2)-1}, 0]] # layer n-2
#
# - `u_params` has shape (2^(n-1), 3), and contains the U3 parameters for each leaf node.
#
# u_params = Array[[theta_leaf0, phi_leaf0, lam_leaf0], # leaf 0
# [theta_leaf1, phi_leaf1, lam_leaf1], # leaf 1
# ...,
# [theta_leaf_{2^(n-1)-1}, phi_leaf_{2^(n-1)-1}, lam_leaf_{2^(n-1)-1}]] # leaf 2^(n-1)-1
#
# - `phases` has shape (2^(n-1),), and contains the global phase for each leaf node.
#
# phases = Array[phase_leaf0, phase_leaf1, ..., phase_leaf_{2^(n-1)-1}]
#
def _preprocess(
target_array: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
This preprocessing function returns three data structures needed for state preparation.

Parameters
----------
target_array : jnp.ndarray
A complex vector representing the target state to prepare.

Returns
-------
thetas : jnp.ndarray
A 2D array containing the ry rotation angles for each layer.
u_params : jnp.ndarray
A 2D array containing the U3 parameters for each leaf node.
phases : jnp.ndarray
A 1D array containing the global phase for each leaf node.

"""

n = int(np.log2(target_array.shape[0]))
max_nodes = 1 << (n - 1)

# Data structures to return
thetas = jnp.zeros((n - 1, max_nodes), dtype=jnp.float64)
u_params = jnp.zeros((max_nodes, 3), dtype=jnp.float64)
phases = jnp.zeros(max_nodes, dtype=jnp.float64)

# Data structures used during the computation (reshaped at each layer)
subvecs = target_array[jnp.newaxis, :]
acc_phases = jnp.zeros((1,), dtype=jnp.float64)
for l in range(n):

num_nodes = 1 << l
sub_len = 1 << (n - l)

if sub_len == 2:
u_params_vec, phases_vec = jax.vmap(_compute_u3_params)(subvecs, acc_phases)
u_params = u_params.at[:num_nodes, :].set(u_params_vec)
phases = phases.at[:num_nodes].set(phases_vec)
break

theta_vec, subvecs, acc_phases = jax.vmap(_compute_thetas)(subvecs, acc_phases)
thetas = thetas.at[l, :num_nodes].set(theta_vec)
subvecs = subvecs.reshape((2 * num_nodes, sub_len // 2))
acc_phases = acc_phases.reshape((2 * num_nodes,))

return thetas, u_params, phases


def state_preparation(
qv: QuantumVariable, target_array: jnp.ndarray, method: str = "auto"
) -> None:
"""
Prepare the quantum state encoded in ``qv`` so that it matches the given
``target_array`` by constructing a binary-tree decomposition of the target
amplitudes and applying a sequence of uniformly controlled rotations via
the ``qswitch`` primitive.

This routine implements a standard state-preparation algorithm based on
recursively splitting the target statevector.
The classical preprocessing stage extracts RY angles for internal tree nodes
and U3 parameters for the leaf nodes.
The quantum stage applies them using ``qswitch``, which replaces
explicit multiplexers and conditionals in both static execution and Jasp mode.

.. note::

During the quantum stage, ``qswitch`` enumerates control patterns in
little-endian order, so each index is bit-reversed before accessing
the parameters computed in the classical preprocessing stage.

Parameters
----------
qv : QuantumVariable
The quantum variable representing the qubits to be prepared.
target_array : jnp.ndarray
A normalized complex vector representing the target state to prepare.
method : str, optional
The dispatch strategy for ``qswitch``. Default is "auto".

"""

# These imports are here to avoid circular dependencies
from qrisp import gphase, qswitch, ry, u3
from qrisp.misc.utility import bit_reverse

target_array = jnp.asarray(target_array, dtype=jnp.complex128)
# n is static (known at compile time), so we can use normal numpy here
n = int(np.log2(target_array.shape[0]))

thetas, u_params, phases = _preprocess(target_array)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This classical preprocessing function for n=20 takes 0.3 seconds on my laptop (each layer is vectorized with vmap). Therefore, most of the time is spent in the quantum part, which basically just calls qswitch (with a bit reverse operation, as we cannot slice a DynamicQubitArray backward)


def make_case_fn(layer_size: int, is_final: bool = False) -> Callable:
"""Create a case function for qswitch at a given layer."""

def case_fn(i, qb):
rev_idx = bit_reverse(i, layer_size)
if is_final:
theta_i, phi_i, lam_i = u_params[rev_idx]
u3(theta_i, phi_i, lam_i, qb)
gphase(phases[rev_idx], qb)
else:
ry(thetas[layer_size][rev_idx], qb)

return case_fn

if n == 1:
theta, phi, lam = u_params[0]
u3(theta, phi, lam, qv[0])
gphase(phases[0], qv[0])
return

ry(thetas[0][0], qv[0])

for layer_size in range(1, n - 1):

qswitch(
operand=qv[layer_size],
case=qv[:layer_size],
case_function=make_case_fn(layer_size),
method=method,
)

qswitch(
operand=qv[n - 1],
case=qv[: n - 1],
case_function=make_case_fn(n - 1, is_final=True),
method=method,
)
Loading