Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,13 @@
* A :class:`~.decomposition.decomposition_graph.DecompGraphSolution` class is added to store the solution of a decomposition graph. An instance of this class is returned from the `solve` method of the :class:`~.decomposition.decomposition_graph.DecompositionGraph`.
[(#8031)](https://github.com/PennyLaneAI/pennylane/pull/8031)

* With the graph-based decomposition system enabled (:func:`~.decomposition.enable_graph()`), if a decomposition cannot be found for an operator in the circuit, it no longer
raises an error. Instead, a warning is raised, and `op.decomposition()` (the current default method for decomposing gates) is
used as a fallback, while the rest of the circuit is still decomposed with
the new graph-based system. Additionally, a special warning message is
raised if the circuit contains a `GlobalPhase`, reminding the user that
`GlobalPhase` is not assumed to have a decomposition under the new system.
[(#8156)](https://github.com/PennyLaneAI/pennylane/pull/8156)
<h3>Labs: a place for unified and rapid prototyping of research software 🧪</h3>

* Added state of the art resources for the `ResourceSelectPauliRot` template and the
Expand Down
8 changes: 6 additions & 2 deletions pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from __future__ import annotations

import warnings
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, replace
Expand Down Expand Up @@ -488,8 +489,11 @@ def solve(self, num_work_wires: int | None = 0, lazy=True) -> DecompGraphSolutio
if visitor.unsolved_op_indices:
unsolved_ops = [self._graph[op_idx] for op_idx in visitor.unsolved_op_indices]
op_names = {op_node.op.name for op_node in unsolved_ops}
raise DecompositionError(
f"Decomposition not found for {op_names} to the gate set {set(self._gate_set_weights)}"
warnings.warn(
f"The graph-based decomposition system is unable to find a decomposition for "
f"{op_names} to the target gate set {set(self._gate_set_weights)}. The default "
"decomposition for these operators will be used instead.",
UserWarning,
)
return DecompGraphSolution(visitor, self._all_op_indices, self._op_to_op_nodes)

Expand Down
5 changes: 1 addition & 4 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,10 +1446,7 @@ def resource_params(self) -> dict:
{"num_wires": 2}

"""
# For most operators, this should just be an empty dictionary, but a default
# implementation is intentionally not provided so that each operator class is
# forced to explicitly define its resource params.
raise NotImplementedError(f"{self.__class__.__name__}.resource_params undefined!")
return {}

# pylint: disable=no-self-argument, comparison-with-callable
@classproperty
Expand Down
97 changes: 34 additions & 63 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pennylane.decomposition.decomposition_graph import DecompGraphSolution
from pennylane.decomposition.utils import translate_op_alias
from pennylane.operation import Operator
from pennylane.ops import Conditional
from pennylane.ops import Conditional, GlobalPhase
from pennylane.transforms.core import transform


Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(

gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition)
self._gate_set = gate_set
self._stopping_condition = stopping_condition
self.stopping_condition = stopping_condition

def setup(self) -> None:
"""Setup the environment for the interpreter by pushing a new environment frame."""
Expand All @@ -135,34 +135,6 @@ def read(self, var):
"""Extract the value corresponding to a variable."""
return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var]

def stopping_condition(self, op: Operator) -> bool:
"""Function to determine whether an operator needs to be decomposed or not.

Args:
op (Operator): Operator to check.

Returns:
bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means
that the operator does not need to be decomposed.
"""

# If the new graph-based decomposition is enabled,
# we don't rely on the has_decomposition attribute.
if enabled_graph():
return self._stopping_condition(op)

if not op.has_decomposition:
if not self._stopping_condition(op):
warnings.warn(
f"Operator {op.name} does not define a decomposition and was not "
f"found in the target gate set. To remove this warning, add the operator "
f"name ({op.name}) or type ({type(op)}) to the gate set.",
UserWarning,
)
return True

return self._stopping_condition(op)

def decompose_operation(self, op: Operator):
"""Decompose a PennyLane operation instance if it does not satisfy the
provided gate set.
Expand All @@ -176,7 +148,7 @@ def decompose_operation(self, op: Operator):
See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
"""

if self._stopping_condition(op):
if self.stopping_condition(op):
return self.interpret_operation(op)

max_expansion = (
Expand All @@ -198,10 +170,10 @@ def decompose_operation(self, op: Operator):
def _evaluate_jaxpr_decomposition(self, op: Operator):
"""Creates and evaluates a Jaxpr of the plxpr decomposition of an operator."""

if self._stopping_condition(op):
if self.max_expansion is not None and self._current_depth >= self.max_expansion:
return self.interpret_operation(op)

if self.max_expansion is not None and self._current_depth >= self.max_expansion:
if self.stopping_condition(op):
return self.interpret_operation(op)

if self._decomp_graph_solution and self._decomp_graph_solution.is_solved_for(op):
Expand Down Expand Up @@ -763,26 +735,7 @@ def circuit():

gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition)

def _stopping_condition(op):

# If the new graph-based decomposition is enabled,
# we don't rely on the has_decomposition attribute.
if enabled_graph():
return stopping_condition(op)

if not op.has_decomposition:
if not stopping_condition(op):
warnings.warn(
f"Operator {op.name} does not define a decomposition and was not "
f"found in the target gate set. To remove this warning, add the operator name "
f"({op.name}) or type ({type(op)}) to the gate set.",
UserWarning,
)
return True

return stopping_condition(op)

if all(_stopping_condition(op) for op in tape.operations):
if all(stopping_condition(op) for op in tape.operations):
return (tape,), null_postprocessing

# If the decomposition graph is enabled, we create a DecompositionGraph instance
Expand All @@ -805,7 +758,7 @@ def _stopping_condition(op):
for op in tape.operations
for final_op in _operator_decomposition_gen(
op,
_stopping_condition,
stopping_condition,
max_expansion=max_expansion,
num_available_work_wires=num_available_work_wires,
graph_solution=decomp_graph_solution,
Expand Down Expand Up @@ -839,8 +792,10 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
if max_expansion is not None and max_expansion <= current_depth:
max_depth_reached = True

# Handle classically controlled operators
if isinstance(op, Conditional):
if isinstance(op, (Allocate, Deallocate)):
yield op

elif isinstance(op, Conditional):
if acceptance_function(op.base) or max_depth_reached:
yield op
else:
Expand All @@ -858,21 +813,37 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
elif acceptance_function(op) or max_depth_reached:
yield op

elif isinstance(op, (Allocate, Deallocate)):
yield op

elif graph_solution is not None and graph_solution.is_solved_for(op, num_available_work_wires):
elif graph_solution and graph_solution.is_solved_for(op, num_available_work_wires):
op_rule = graph_solution.decomposition(op, num_available_work_wires)
with queuing.AnnotatedQueue() as decomposed_ops:
op_rule(*op.parameters, wires=op.wires, **op.hyperparameters)
decomp = decomposed_ops.queue
current_depth += 1
if num_available_work_wires is not None:
num_available_work_wires -= op_rule.get_work_wire_spec(**op.resource_params).total
else:

elif enabled_graph() and isinstance(op, GlobalPhase):
warnings.warn(
"With qml.decomposition.enabled_graph(), GlobalPhase is not assumed to have a "
"decomposition. To disable this warning, add `GlobalPhase` to the gate set, or "
"assign a decomposition rule to `GlobalPhase` via the `fixed_decomps` keyword "
"argument. To make GlobalPhase decompose to nothing, you can import `null_decomp` "
"from pennylane.decomposition.decomposition_rule, and assign it to GlobalPhase."
)
yield op

elif op.has_decomposition:
decomp = op.decomposition()
current_depth += 1

else:
warnings.warn(
f"Operator {op.name} does not define a decomposition to the target gate set and was not found in the "
f"target gate set. To remove this warning, add the operator name ({op.name}) or "
f"type ({type(op)}) to the gate set.",
UserWarning,
)
yield op

current_depth += 1
for sub_op in decomp:
yield from _operator_decomposition_gen(
sub_op,
Expand Down
19 changes: 4 additions & 15 deletions tests/capture/transforms/test_capture_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_init(self, gate_set, max_expansion):
assert interpreter.max_expansion == max_expansion
valid_op = qml.RX(1.5, 0)
invalid_op = qml.RY(1.5, 0)
assert interpreter._stopping_condition(valid_op)
assert not interpreter._stopping_condition(invalid_op)
assert interpreter.stopping_condition(valid_op)
assert not interpreter.stopping_condition(invalid_op)

@pytest.mark.unit
def test_fixed_alt_decomps_not_available_capture(self):
Expand All @@ -73,23 +73,12 @@ def my_cnot(*_, **__):
DecomposeInterpreter(alt_decomps={qml.CNOT: [my_cnot]})

@pytest.mark.parametrize("op", [qml.RX(1.5, 0), qml.RZ(1.5, 0)])
def test_stopping_condition(self, op, recwarn):
def test_stopping_condition(self, op):
"""Test that stopping_condition works correctly."""
# pylint: disable=unnecessary-lambda-assignment
gate_set = lambda op: op.name == "RX"
interpreter = DecomposeInterpreter(gate_set=gate_set)

if gate_set(op):
assert interpreter.stopping_condition(op)
assert len(recwarn) == 0

else:
if not op.has_decomposition:
with pytest.warns(UserWarning, match="does not define a decomposition"):
assert interpreter.stopping_condition(op)
else:
assert not interpreter.stopping_condition(op)
assert len(recwarn) == 0
assert interpreter.stopping_condition(op) == gate_set(op)

def test_decompose_simple(self):
"""Test that a simple function can be decomposed correctly."""
Expand Down
8 changes: 4 additions & 4 deletions tests/capture/transforms/test_capture_graph_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def test_gate_set_contains(self):
"""Tests specifying the target gate set."""

interpreter = DecomposeInterpreter(gate_set={qml.RX, "RZ", "CNOT"})
assert interpreter._stopping_condition(qml.RX(1.5, 0))
assert interpreter._stopping_condition(qml.RZ(1.5, 0))
assert interpreter._stopping_condition(qml.CNOT(wires=[0, 1]))
assert not interpreter._stopping_condition(qml.Hadamard(0))
assert interpreter.stopping_condition(qml.RX(1.5, 0))
assert interpreter.stopping_condition(qml.RZ(1.5, 0))
assert interpreter.stopping_condition(qml.CNOT(wires=[0, 1]))
assert not interpreter.stopping_condition(qml.Hadamard(0))

@pytest.mark.unit
def test_callable_gate_set_not_supported(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/decomposition/test_decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_decomposition_not_found(self, _):

op = qml.Hadamard(wires=[0])
graph = DecompositionGraph(operations=[op], gate_set={"RX", "RY", "GlobalPhase"})
with pytest.raises(DecompositionError, match="Decomposition not found for {'Hadamard'}"):
with pytest.warns(UserWarning, match="unable to find a decomposition for {'Hadamard'}"):
graph.solve()

def test_lazy_solve(self, _):
Expand Down
28 changes: 23 additions & 5 deletions tests/transforms/test_decompose_transform_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,34 @@ def resource_params(self):
def decomposition(self):
return [qml.H(self.wires[1]), qml.CNOT(self.wires), qml.H(self.wires[1])]

@qml.register_resources({qml.CZ: 1})
@qml.register_resources({qml.CRZ: 1})
def my_decomp(wires, **__):
qml.CZ(wires=wires)
qml.CRZ(np.pi, wires=wires)

tape = qml.tape.QuantumScript([CustomOp(wires=[0, 1])])
[new_tape], _ = qml.transforms.decompose(
tape, gate_set={"CNOT", "Hadamard"}, fixed_decomps={CustomOp: my_decomp}
)

with pytest.warns(UserWarning, match="The graph-based decomposition system is unable"):
[new_tape], _ = qml.transforms.decompose(
[tape],
gate_set={"CNOT", "Hadamard"},
fixed_decomps={CustomOp: my_decomp},
)

assert new_tape.operations == [qml.H(1), qml.CNOT(wires=[0, 1]), qml.H(1)]

@pytest.mark.integration
def test_global_phase_warning(self):
"""Tests that a sensible warning is raised when the graph fails to find a solution
due to GlobalPhase not being part of the gate set."""

tape = qml.tape.QuantumScript([qml.X(0)])

with pytest.warns(UserWarning, match="GlobalPhase is not assumed"):
with pytest.warns(UserWarning, match="The graph-based decomposition system is unable"):
[new_tape], _ = qml.transforms.decompose([tape], gate_set={"RX"})

assert new_tape.operations == [qml.RX(np.pi, wires=0), qml.GlobalPhase(-np.pi / 2, wires=0)]

@pytest.mark.integration
def test_controlled_decomp(self):
"""Tests decomposing a controlled operation."""
Expand Down