Skip to content

Commit b51f590

Browse files
astralcaiJerryChen97isaacdevlugtPietropaoloFrisoni
authored
Fall back to op.decompose if op is unsolved in decomposition graph (#8156)
**Context:** The DecompositionGraph currently errors out completely if any of the operators is unsolved for, and as a result, with graph enabled, if a circuit contains a single operator/template that is not yet integrated with the new decomposition system, the entire thing fails, and the user is forced to turn off graph mode and use the old decomposition system. We want to change this behaviour, make it so that the DecompositionGraph raise a warning instead of an error if certain operators are unsolved for, and fall back to using op.decomposition for those operators. **Description of the Change:** - Downgrade error to a warning when certain operators are unsolved in the graph. - Add sensible warning for when the graph failed to find a solution due to GlobalPhase. - When the new system is enabled, operators without a decomposition will be left undecomposed with a warning. **Benefits:** - The user can use the new system for the rest of the circuit even if it contains templates that are not yet integrated with the new graph-based system. **Possible Drawbacks:** **Related GitHub Issues:** [sc-98390] --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]> Co-authored-by: Isaac De Vlugt <[email protected]> Co-authored-by: Pietropaolo Frisoni <[email protected]>
1 parent d6fe071 commit b51f590

File tree

10 files changed

+141
-99
lines changed

10 files changed

+141
-99
lines changed

doc/releases/changelog-dev.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,13 @@
502502
* A :class:`~.decomposition.decomposition_graph.DecompGraphSolution` class is added to store the solution of a decomposition graph. An instance of this class is returned from the `solve` method of the :class:`~.decomposition.decomposition_graph.DecompositionGraph`.
503503
[(#8031)](https://github.com/PennyLaneAI/pennylane/pull/8031)
504504

505+
* With the graph-based decomposition system enabled (:func:`~.decomposition.enable_graph()`), if a decomposition cannot be found for an operator in the circuit, it no longer
506+
raises an error. Instead, a warning is raised, and `op.decomposition()` (the current default method for decomposing gates) is
507+
used as a fallback, while the rest of the circuit is still decomposed with
508+
the new graph-based system. Additionally, a special warning message is
509+
raised if the circuit contains a `GlobalPhase`, reminding the user that
510+
`GlobalPhase` is not assumed to have a decomposition under the new system.
511+
[(#8156)](https://github.com/PennyLaneAI/pennylane/pull/8156)
505512
<h3>Labs: a place for unified and rapid prototyping of research software 🧪</h3>
506513

507514
* The module `qml.labs.zxopt` has been removed as its functionalities are now available in the

pennylane/decomposition/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
~DecompositionRule
6767
~Resources
6868
~CompressedResourceOp
69+
~null_decomp
6970
7071
In the new decomposition system, a decomposition rule must be defined as a quantum function that
7172
accepts ``(*op.parameters, op.wires, **op.hyperparameters)`` as arguments, where ``op`` is an
@@ -209,6 +210,7 @@ def circuit():
209210
:toctree: api
210211
211212
~DecompositionGraph
213+
~DecompGraphSolution
212214
213215
The decomposition graph is a directed graph of operators and decomposition rules. Dijkstra's
214216
algorithm is used to explore the graph and find the most efficient decomposition of a given
@@ -221,12 +223,12 @@ def circuit():
221223
operations=[op],
222224
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
223225
)
224-
graph.solve()
226+
solution = graph.solve()
225227
226228
.. code-block:: pycon
227229
228230
>>> with qml.queuing.AnnotatedQueue() as q:
229-
... graph.decomposition(op)(0.5, wires=[0, 1])
231+
... solution.decomposition(op)(0.5, wires=[0, 1])
230232
...
231233
>>> q.queue
232234
[RZ(1.5707963267948966, wires=[1]),
@@ -254,7 +256,7 @@ def circuit():
254256
disable_graph,
255257
enabled_graph,
256258
)
257-
from .decomposition_graph import DecompositionGraph
259+
from .decomposition_graph import DecompositionGraph, DecompGraphSolution
258260
from .resources import (
259261
Resources,
260262
resource_rep,
@@ -267,6 +269,7 @@ def circuit():
267269
register_resources,
268270
register_condition,
269271
DecompositionRule,
272+
null_decomp,
270273
add_decomps,
271274
list_decomps,
272275
has_decomp,

pennylane/decomposition/decomposition_graph.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from __future__ import annotations
2727

28+
import warnings
2829
from collections import defaultdict
2930
from collections.abc import Iterable
3031
from dataclasses import dataclass, replace
@@ -488,14 +489,46 @@ def solve(self, num_work_wires: int | None = 0, lazy=True) -> DecompGraphSolutio
488489
if visitor.unsolved_op_indices:
489490
unsolved_ops = [self._graph[op_idx] for op_idx in visitor.unsolved_op_indices]
490491
op_names = {op_node.op.name for op_node in unsolved_ops}
491-
raise DecompositionError(
492-
f"Decomposition not found for {op_names} to the gate set {set(self._gate_set_weights)}"
492+
warnings.warn(
493+
f"The graph-based decomposition system is unable to find a decomposition for "
494+
f"{op_names} to the target gate set {set(self._gate_set_weights)}. The default "
495+
"decomposition (op.decomposition()) for these operators will be used instead.",
496+
UserWarning,
493497
)
494498
return DecompGraphSolution(visitor, self._all_op_indices, self._op_to_op_nodes)
495499

496500

497501
class DecompGraphSolution:
498-
"""A solution to a decomposition graph."""
502+
"""A solution to a decomposition graph.
503+
504+
An instance of this class is returned from :meth:`DecompositionGraph.solve`
505+
506+
**Example**
507+
508+
.. code-block:: python
509+
510+
from pennylane.decomposition import DecompositionGraph
511+
512+
op = qml.CRX(0.5, wires=[0, 1])
513+
graph = DecompositionGraph(
514+
operations=[op],
515+
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
516+
)
517+
solution = graph.solve()
518+
519+
>>> with qml.queuing.AnnotatedQueue() as q:
520+
... solution.decomposition(op)(0.5, wires=[0, 1])
521+
>>> q.queue
522+
[RZ(1.5707963267948966, wires=[1]),
523+
RY(0.25, wires=[1]),
524+
CNOT(wires=[0, 1]),
525+
RY(-0.25, wires=[1]),
526+
CNOT(wires=[0, 1]),
527+
RZ(-1.5707963267948966, wires=[1])]
528+
>>> solution.resource_estimate(op)
529+
<num_gates=10, gate_counts={RZ: 6, CNOT: 2, RX: 2}, weighted_cost=10.0>
530+
531+
"""
499532

500533
def __init__(
501534
self,

pennylane/decomposition/decomposition_rule.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,5 +602,29 @@ def has_decomp(op_type: type[Operator] | str) -> bool:
602602

603603
@register_resources({})
604604
def null_decomp(*_, **__):
605-
"""A decomposition rule that does nothing."""
605+
"""A decomposition rule that can be assigned to an operator so that the operator decomposes to nothing.
606+
607+
**Example**
608+
609+
.. code-block:: python
610+
611+
from functools import partial
612+
import pennylane as qml
613+
from pennylane.decomposition import null_decomp
614+
615+
qml.decomposition.enable_graph()
616+
617+
@partial(
618+
qml.transforms.decompose,
619+
gate_set={qml.RZ},
620+
fixed_decomps={qml.GlobalPhase: null_decomp}
621+
)
622+
@qml.qnode(qml.device("default.qubit"))
623+
def circuit():
624+
qml.Z(0)
625+
626+
>>> print(qml.draw(circuit)())
627+
0: ──RZ(3.14)─┤
628+
629+
"""
606630
return

pennylane/operation.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,10 +1446,7 @@ def resource_params(self) -> dict:
14461446
{"num_wires": 2}
14471447
14481448
"""
1449-
# For most operators, this should just be an empty dictionary, but a default
1450-
# implementation is intentionally not provided so that each operator class is
1451-
# forced to explicitly define its resource params.
1452-
raise NotImplementedError(f"{self.__class__.__name__}.resource_params undefined!")
1449+
return {}
14531450

14541451
# pylint: disable=no-self-argument, comparison-with-callable
14551452
@classproperty

pennylane/transforms/decompose.py

Lines changed: 34 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pennylane.decomposition.decomposition_graph import DecompGraphSolution
3030
from pennylane.decomposition.utils import translate_op_alias
3131
from pennylane.operation import Operator
32-
from pennylane.ops import Conditional
32+
from pennylane.ops import Conditional, GlobalPhase
3333
from pennylane.transforms.core import transform
3434

3535

@@ -115,7 +115,7 @@ def __init__(
115115

116116
gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition)
117117
self._gate_set = gate_set
118-
self._stopping_condition = stopping_condition
118+
self.stopping_condition = stopping_condition
119119

120120
def setup(self) -> None:
121121
"""Setup the environment for the interpreter by pushing a new environment frame."""
@@ -135,34 +135,6 @@ def read(self, var):
135135
"""Extract the value corresponding to a variable."""
136136
return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var]
137137

138-
def stopping_condition(self, op: Operator) -> bool:
139-
"""Function to determine whether an operator needs to be decomposed or not.
140-
141-
Args:
142-
op (Operator): Operator to check.
143-
144-
Returns:
145-
bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means
146-
that the operator does not need to be decomposed.
147-
"""
148-
149-
# If the new graph-based decomposition is enabled,
150-
# we don't rely on the has_decomposition attribute.
151-
if enabled_graph():
152-
return self._stopping_condition(op)
153-
154-
if not op.has_decomposition:
155-
if not self._stopping_condition(op):
156-
warnings.warn(
157-
f"Operator {op.name} does not define a decomposition and was not "
158-
f"found in the target gate set. To remove this warning, add the operator "
159-
f"name ({op.name}) or type ({type(op)}) to the gate set.",
160-
UserWarning,
161-
)
162-
return True
163-
164-
return self._stopping_condition(op)
165-
166138
def decompose_operation(self, op: Operator):
167139
"""Decompose a PennyLane operation instance if it does not satisfy the
168140
provided gate set.
@@ -176,7 +148,7 @@ def decompose_operation(self, op: Operator):
176148
See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
177149
"""
178150

179-
if self._stopping_condition(op):
151+
if self.stopping_condition(op):
180152
return self.interpret_operation(op)
181153

182154
max_expansion = (
@@ -198,10 +170,10 @@ def decompose_operation(self, op: Operator):
198170
def _evaluate_jaxpr_decomposition(self, op: Operator):
199171
"""Creates and evaluates a Jaxpr of the plxpr decomposition of an operator."""
200172

201-
if self._stopping_condition(op):
173+
if self.max_expansion is not None and self._current_depth >= self.max_expansion:
202174
return self.interpret_operation(op)
203175

204-
if self.max_expansion is not None and self._current_depth >= self.max_expansion:
176+
if self.stopping_condition(op):
205177
return self.interpret_operation(op)
206178

207179
if self._decomp_graph_solution and self._decomp_graph_solution.is_solved_for(op):
@@ -763,26 +735,7 @@ def circuit():
763735

764736
gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition)
765737

766-
def _stopping_condition(op):
767-
768-
# If the new graph-based decomposition is enabled,
769-
# we don't rely on the has_decomposition attribute.
770-
if enabled_graph():
771-
return stopping_condition(op)
772-
773-
if not op.has_decomposition:
774-
if not stopping_condition(op):
775-
warnings.warn(
776-
f"Operator {op.name} does not define a decomposition and was not "
777-
f"found in the target gate set. To remove this warning, add the operator name "
778-
f"({op.name}) or type ({type(op)}) to the gate set.",
779-
UserWarning,
780-
)
781-
return True
782-
783-
return stopping_condition(op)
784-
785-
if all(_stopping_condition(op) for op in tape.operations):
738+
if all(stopping_condition(op) for op in tape.operations):
786739
return (tape,), null_postprocessing
787740

788741
# If the decomposition graph is enabled, we create a DecompositionGraph instance
@@ -805,7 +758,7 @@ def _stopping_condition(op):
805758
for op in tape.operations
806759
for final_op in _operator_decomposition_gen(
807760
op,
808-
_stopping_condition,
761+
stopping_condition,
809762
max_expansion=max_expansion,
810763
num_available_work_wires=num_available_work_wires,
811764
graph_solution=decomp_graph_solution,
@@ -839,8 +792,10 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
839792
if max_expansion is not None and max_expansion <= current_depth:
840793
max_depth_reached = True
841794

842-
# Handle classically controlled operators
843-
if isinstance(op, Conditional):
795+
if isinstance(op, (Allocate, Deallocate)):
796+
yield op
797+
798+
elif isinstance(op, Conditional):
844799
if acceptance_function(op.base) or max_depth_reached:
845800
yield op
846801
else:
@@ -858,21 +813,37 @@ def _operator_decomposition_gen( # pylint: disable=too-many-arguments
858813
elif acceptance_function(op) or max_depth_reached:
859814
yield op
860815

861-
elif isinstance(op, (Allocate, Deallocate)):
862-
yield op
863-
864-
elif graph_solution is not None and graph_solution.is_solved_for(op, num_available_work_wires):
816+
elif graph_solution and graph_solution.is_solved_for(op, num_available_work_wires):
865817
op_rule = graph_solution.decomposition(op, num_available_work_wires)
866818
with queuing.AnnotatedQueue() as decomposed_ops:
867819
op_rule(*op.parameters, wires=op.wires, **op.hyperparameters)
868820
decomp = decomposed_ops.queue
869-
current_depth += 1
870821
if num_available_work_wires is not None:
871822
num_available_work_wires -= op_rule.get_work_wire_spec(**op.resource_params).total
872-
else:
823+
824+
elif enabled_graph() and isinstance(op, GlobalPhase):
825+
warnings.warn(
826+
"With qml.decomposition.enabled_graph(), GlobalPhase is not assumed to have a "
827+
"decomposition. To disable this warning, add GlobalPhase to the gate set, or "
828+
"assign a decomposition rule to GlobalPhase via the fixed_decomps keyword "
829+
"argument. To make GlobalPhase decompose to nothing, you can import null_decomp "
830+
"from pennylane.decomposition, and assign it to GlobalPhase."
831+
)
832+
yield op
833+
834+
elif op.has_decomposition:
873835
decomp = op.decomposition()
874-
current_depth += 1
875836

837+
else:
838+
warnings.warn(
839+
f"Operator {op.name} does not define a decomposition to the target gate set and was not found in the "
840+
f"target gate set. To remove this warning, add the operator name ({op.name}) or "
841+
f"type ({type(op)}) to the gate set.",
842+
UserWarning,
843+
)
844+
yield op
845+
846+
current_depth += 1
876847
for sub_op in decomp:
877848
yield from _operator_decomposition_gen(
878849
sub_op,

tests/capture/transforms/test_capture_decompose.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_init(self, gate_set, max_expansion):
5454
assert interpreter.max_expansion == max_expansion
5555
valid_op = qml.RX(1.5, 0)
5656
invalid_op = qml.RY(1.5, 0)
57-
assert interpreter._stopping_condition(valid_op)
58-
assert not interpreter._stopping_condition(invalid_op)
57+
assert interpreter.stopping_condition(valid_op)
58+
assert not interpreter.stopping_condition(invalid_op)
5959

6060
@pytest.mark.unit
6161
def test_fixed_alt_decomps_not_available_capture(self):
@@ -73,23 +73,12 @@ def my_cnot(*_, **__):
7373
DecomposeInterpreter(alt_decomps={qml.CNOT: [my_cnot]})
7474

7575
@pytest.mark.parametrize("op", [qml.RX(1.5, 0), qml.RZ(1.5, 0)])
76-
def test_stopping_condition(self, op, recwarn):
76+
def test_stopping_condition(self, op):
7777
"""Test that stopping_condition works correctly."""
7878
# pylint: disable=unnecessary-lambda-assignment
7979
gate_set = lambda op: op.name == "RX"
8080
interpreter = DecomposeInterpreter(gate_set=gate_set)
81-
82-
if gate_set(op):
83-
assert interpreter.stopping_condition(op)
84-
assert len(recwarn) == 0
85-
86-
else:
87-
if not op.has_decomposition:
88-
with pytest.warns(UserWarning, match="does not define a decomposition"):
89-
assert interpreter.stopping_condition(op)
90-
else:
91-
assert not interpreter.stopping_condition(op)
92-
assert len(recwarn) == 0
81+
assert interpreter.stopping_condition(op) == gate_set(op)
9382

9483
def test_decompose_simple(self):
9584
"""Test that a simple function can be decomposed correctly."""

tests/capture/transforms/test_capture_graph_decompose.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def test_gate_set_contains(self):
3939
"""Tests specifying the target gate set."""
4040

4141
interpreter = DecomposeInterpreter(gate_set={qml.RX, "RZ", "CNOT"})
42-
assert interpreter._stopping_condition(qml.RX(1.5, 0))
43-
assert interpreter._stopping_condition(qml.RZ(1.5, 0))
44-
assert interpreter._stopping_condition(qml.CNOT(wires=[0, 1]))
45-
assert not interpreter._stopping_condition(qml.Hadamard(0))
42+
assert interpreter.stopping_condition(qml.RX(1.5, 0))
43+
assert interpreter.stopping_condition(qml.RZ(1.5, 0))
44+
assert interpreter.stopping_condition(qml.CNOT(wires=[0, 1]))
45+
assert not interpreter.stopping_condition(qml.Hadamard(0))
4646

4747
@pytest.mark.unit
4848
def test_callable_gate_set_not_supported(self):

tests/decomposition/test_decomposition_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def test_decomposition_not_found(self, _):
256256

257257
op = qml.Hadamard(wires=[0])
258258
graph = DecompositionGraph(operations=[op], gate_set={"RX", "RY", "GlobalPhase"})
259-
with pytest.raises(DecompositionError, match="Decomposition not found for {'Hadamard'}"):
259+
with pytest.warns(UserWarning, match="unable to find a decomposition for {'Hadamard'}"):
260260
graph.solve()
261261

262262
def test_lazy_solve(self, _):

0 commit comments

Comments
 (0)