Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d17c340
Fall back to op.decompose if op is unsolved in decomposition graph
astralcai Aug 28, 2025
116dcf9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Aug 28, 2025
daa19b2
changelog
astralcai Aug 28, 2025
2e7176b
add tests
astralcai Aug 28, 2025
cef4506
ooops
astralcai Aug 28, 2025
706251a
oooops
astralcai Aug 28, 2025
a0ee55a
fix test
astralcai Aug 28, 2025
8382255
fix more tests
astralcai Aug 28, 2025
98cb1fe
Merge branch 'master' into decomp-fallback
astralcai Aug 28, 2025
592c548
default value for resource_params
astralcai Aug 29, 2025
77fcace
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Aug 29, 2025
8f9ce87
Merge branch 'master' into decomp-fallback
JerryChen97 Aug 29, 2025
5f4ceea
Merge branch 'master' into decomp-fallback
JerryChen97 Sep 2, 2025
c822529
Apply suggestions from code review
astralcai Sep 2, 2025
f08be18
Merge branch 'master' into decomp-fallback
astralcai Sep 2, 2025
1cf9389
Update doc/releases/changelog-dev.md
JerryChen97 Sep 3, 2025
b54581a
Merge branch 'master' into decomp-fallback
JerryChen97 Sep 3, 2025
8307d5f
Merge branch 'master' into decomp-fallback
astralcai Sep 3, 2025
34cdd48
[WIP] Unified internal implementation for decompositions
astralcai Sep 3, 2025
b8aa9c2
unused import
astralcai Sep 3, 2025
7f41082
Merge branch 'master' into unified-decomp
astralcai Sep 4, 2025
bca42f9
fix something
astralcai Sep 4, 2025
78703f1
Merge branch 'master' into decomp-fallback
astralcai Sep 4, 2025
4f5926f
more updates
astralcai Sep 4, 2025
6b34749
Merge branch 'decomp-fallback' of https://github.com/PennyLaneAI/penn…
astralcai Sep 4, 2025
a8c656a
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 4, 2025
bd56fa4
coverage
astralcai Sep 4, 2025
54f874a
Merge branch 'master' into decomp-fallback
JerryChen97 Sep 4, 2025
2b2e682
changelog
astralcai Sep 4, 2025
9e2c2c1
Merge branch 'decomp-fallback' into unified-decomp
JerryChen97 Sep 4, 2025
9763858
Merge branch 'master' into unified-decomp
JerryChen97 Sep 4, 2025
81a53c8
Update pennylane/transforms/decompose.py
astralcai Sep 4, 2025
c3b0b41
Merge branch 'master' into unified-decomp
JerryChen97 Sep 5, 2025
7e057e0
Update pennylane/transforms/decompose.py
astralcai Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,9 @@
execution on null devices.
[(#8090)](https://github.com/PennyLaneAI/pennylane/pull/8090)

* :func:`.transforms.decompose` and :func:`.preprocess.decompose` now have a unified internal implementation.
[(#8193)](https://github.com/PennyLaneAI/pennylane/pull/8193)

<h3>Documentation 📝</h3>

* Rename `ancilla` to `auxiliary` in internal documentation.
Expand Down
90 changes: 29 additions & 61 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@

import os
import warnings
from collections.abc import Callable, Generator, Sequence
from collections.abc import Callable, Sequence
from copy import copy

import pennylane as qml
from pennylane.exceptions import AllocationError, DeviceError, QuantumFunctionError, WireError
from pennylane.decomposition.decomposition_graph import DecompGraphSolution
from pennylane.exceptions import (
AllocationError,
DecompositionUndefinedError,
DeviceError,
QuantumFunctionError,
WireError,
)
from pennylane.math import requires_grad
from pennylane.measurements import SampleMeasurement, StateMeasurement
from pennylane.operation import StatePrepBase
from pennylane.operation import Operator, StatePrepBase
from pennylane.ops import Snapshot
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import resolve_dynamic_wires
from pennylane.transforms.core import transform
from pennylane.transforms.decompose import _operator_decomposition_gen
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires

Expand All @@ -43,45 +51,6 @@ def null_postprocessing(results):
return results[0]


def _operator_decomposition_gen( # pylint: disable = too-many-positional-arguments
op: qml.operation.Operator,
acceptance_function: Callable[[qml.operation.Operator], bool],
decomposer: Callable[[qml.operation.Operator], Sequence[qml.operation.Operator]],
max_expansion: int | None = None,
current_depth=0,
name: str = "device",
error: type[Exception] | None = None,
) -> Generator[qml.operation.Operator, None, None]:
"""A generator that yields the next operation that is accepted."""
if error is None:
error = DeviceError

max_depth_reached = False
if max_expansion is not None and max_expansion <= current_depth:
max_depth_reached = True
if acceptance_function(op) or max_depth_reached:
yield op
else:
try:
decomp = decomposer(op)
current_depth += 1
except qml.operation.DecompositionUndefinedError as e:
raise error(
f"Operator {op} not supported with {name} and does not provide a decomposition."
) from e

for sub_op in decomp:
yield from _operator_decomposition_gen(
sub_op,
acceptance_function,
decomposer=decomposer,
max_expansion=max_expansion,
current_depth=current_depth,
name=name,
error=error,
)


#######################


Expand Down Expand Up @@ -132,7 +101,7 @@ def no_analytic(

@transform
def validate_device_wires(
tape: QuantumScript, wires: qml.wires.Wires | None = None, name: str = "device"
tape: QuantumScript, wires: Wires | None = None, name: str = "device"
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
"""Validates that all wires present in the tape are in the set of provided wires. Adds the
device wires to measurement processes like :class:`~.measurements.StateMP` that are broadcasted
Expand Down Expand Up @@ -321,12 +290,12 @@ def validate_adjoint_trainable_params(
@transform
def decompose( # pylint: disable = too-many-positional-arguments
tape: QuantumScript,
stopping_condition: Callable[[qml.operation.Operator], bool],
stopping_condition_shots: Callable[[qml.operation.Operator], bool] = None,
stopping_condition: Callable[[Operator], bool],
stopping_condition_shots: Callable[[Operator], bool] | None = None,
skip_initial_state_prep: bool = True,
decomposer: None | (
Callable[[qml.operation.Operator], Sequence[qml.operation.Operator]]
) = None,
decomposer: Callable[[Operator], Sequence[Operator]] | None = None,
graph_solution: DecompGraphSolution | None = None,
num_available_work_wires: int | None = 0,
name: str = "device",
error: type[Exception] | None = None,
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
Expand Down Expand Up @@ -403,13 +372,7 @@ def decompose( # pylint: disable = too-many-positional-arguments

"""

if error is None:
error = DeviceError

if decomposer is None:

def decomposer(op):
return op.decomposition()
error = error or DeviceError

if stopping_condition_shots is not None and tape.shots:
stopping_condition = stopping_condition_shots
Expand All @@ -421,17 +384,18 @@ def decomposer(op):

if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
return (tape,), null_postprocessing
try:

try:
new_ops = [
final_op
for op in tape.operations[len(prep_op) :]
for final_op in _operator_decomposition_gen(
op,
stopping_condition,
decomposer=decomposer,
name=name,
error=error,
num_available_work_wires=num_available_work_wires,
graph_solution=graph_solution,
custom_decomposer=decomposer,
strict=True,
)
]
except RecursionError as e:
Expand All @@ -440,6 +404,10 @@ def decomposer(op):
"Operator decomposition may have entered an infinite loop."
) from e

except DecompositionUndefinedError as e:
message = str(e).replace("not supported", f"not supported with {name}")
raise error(message) from e

tape = tape.copy(operations=prep_op + new_ops)

return (tape,), null_postprocessing
Expand All @@ -448,8 +416,8 @@ def decomposer(op):
@transform
def validate_observables(
tape: QuantumScript,
stopping_condition: Callable[[qml.operation.Operator], bool],
stopping_condition_shots: Callable[[qml.operation.Operator], bool] = None,
stopping_condition: Callable[[Operator], bool],
stopping_condition_shots: Callable[[Operator], bool] | None = None,
name: str = "device",
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
"""Validates the observables and measurements for a circuit.
Expand Down
43 changes: 40 additions & 3 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pennylane.decomposition import DecompositionGraph, enabled_graph
from pennylane.decomposition.decomposition_graph import DecompGraphSolution
from pennylane.decomposition.utils import translate_op_alias
from pennylane.exceptions import DecompositionUndefinedError
from pennylane.operation import Operator
from pennylane.ops import Conditional, GlobalPhase
from pennylane.transforms.core import transform
Expand Down Expand Up @@ -791,15 +792,34 @@ def circuit():
return (tape,), null_postprocessing


def _operator_decomposition_gen( # pylint: disable=too-many-arguments
def _operator_decomposition_gen( # pylint: disable=too-many-arguments,too-many-branches
op: Operator,
acceptance_function: Callable[[Operator], bool],
max_expansion: int | None = None,
current_depth=0,
current_depth: int = 0,
num_available_work_wires: int | None = 0,
graph_solution: DecompGraphSolution | None = None,
custom_decomposer: Callable[[Operator], Sequence[Operator]] | None = None,
strict: bool = False,
) -> Generator[Operator]:
"""A generator that yields the next operation that is accepted."""
"""A generator that yields the next operation that is accepted.

Args:
op: The operator to decompose
acceptance_function: Returns True if the operator does not need further decomposition.
max_expansion: The maximum level of expansion.
current_depth: The current depth of expansion.
num_available_work_wires: The number of available work wires at the top level.
graph_solution: The solution to the decomposition graph.
custom_decomposer: A custom function that decomposes an operator. This is only relevant
with the graph enabled, and only used by ``preprocess.decompose``.
strict: If True, an error will be raised when an operator does not provide a decomposition
and does not meet the stopping criteria.

Returns:
A generator of Operators

"""

max_depth_reached = False
decomp = []
Expand All @@ -822,6 +842,8 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
max_expansion=max_expansion,
current_depth=current_depth,
graph_solution=graph_solution,
custom_decomposer=custom_decomposer,
strict=strict,
)
)

Expand All @@ -846,9 +868,22 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
)
yield op

elif custom_decomposer is not None:
try:
decomp = custom_decomposer(op)
except DecompositionUndefinedError as e:
raise DecompositionUndefinedError(
f"Operator {op} not supported and does not provide a decomposition."
) from e

elif op.has_decomposition:
decomp = op.decomposition()

elif strict:
raise DecompositionUndefinedError(
f"Operator {op} not supported and does not provide a decomposition."
)

else:
warnings.warn(
f"Operator {op.name} does not define a decomposition to the target gate set and was not found in the "
Expand All @@ -867,6 +902,8 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
current_depth=current_depth,
num_available_work_wires=num_available_work_wires,
graph_solution=graph_solution,
custom_decomposer=custom_decomposer,
strict=strict,
)


Expand Down
31 changes: 18 additions & 13 deletions tests/devices/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def test_operator_decomposition_gen_accepted_operator(self, op):
def stopping_condition(op):
return op.has_matrix

casted_to_list = list(_operator_decomposition_gen(op, stopping_condition, self.decomposer))
casted_to_list = list(
_operator_decomposition_gen(op, stopping_condition, custom_decomposer=self.decomposer)
)
assert len(casted_to_list) == 1
assert casted_to_list[0] is op

Expand All @@ -100,7 +102,9 @@ def stopping_condition(op):
return op.has_matrix

op = NoMatOp("a")
casted_to_list = list(_operator_decomposition_gen(op, stopping_condition, self.decomposer))
casted_to_list = list(
_operator_decomposition_gen(op, stopping_condition, custom_decomposer=self.decomposer)
)
assert len(casted_to_list) == 2
qml.assert_equal(casted_to_list[0], qml.PauliX("a"))
qml.assert_equal(casted_to_list[1], qml.PauliY("a"))
Expand All @@ -120,24 +124,16 @@ def decomposition(self):
return [NoMatOp(self.wires), qml.S(self.wires), qml.adjoint(NoMatOp(self.wires))]

op = RaggedDecompositionOp("a")
final_decomp = list(_operator_decomposition_gen(op, stopping_condition, self.decomposer))
final_decomp = list(
_operator_decomposition_gen(op, stopping_condition, custom_decomposer=self.decomposer)
)
assert len(final_decomp) == 5
qml.assert_equal(final_decomp[0], qml.PauliX("a"))
qml.assert_equal(final_decomp[1], qml.PauliY("a"))
qml.assert_equal(final_decomp[2], qml.S("a"))
qml.assert_equal(final_decomp[3], qml.adjoint(qml.PauliY("a")))
qml.assert_equal(final_decomp[4], qml.adjoint(qml.PauliX("a")))

def test_error_from_unsupported_operation(self):
"""Test that a device error is raised if the operator cant be decomposed and doesn't have a matrix."""
op = NoMatNoDecompOp("a")
with pytest.raises(DeviceError, match=r"not supported with abc and does"):
tuple(
_operator_decomposition_gen(
op, lambda op: op.has_matrix, self.decomposer, name="abc"
)
)


def test_no_sampling():
"""Tests for the no_sampling transform."""
Expand Down Expand Up @@ -263,6 +259,15 @@ def test_error_if_invalid_op(self):
with pytest.raises(DeviceError, match="not supported with abc"):
decompose(tape, lambda op: op.has_matrix, name="abc")

def test_error_if_invalid_op_decomposer(self):
"""Test that expand_fn throws an error when an operation does not define a matrix or decomposition."""

tape = QuantumScript(ops=[NoMatNoDecompOp(0)], measurements=[qml.expval(qml.Hadamard(0))])
with pytest.raises(DeviceError, match="not supported with abc"):
decompose(
tape, lambda op: op.has_matrix, decomposer=lambda op: op.decomposition(), name="abc"
)

def test_decompose(self):
"""Test that expand_fn doesn't throw any errors for a valid circuit"""
tape = QuantumScript(
Expand Down