Skip to content

Commit 8630f5b

Browse files
authored
Refactor DecompositionGraph to two classes (#8031)
**Context:** **Description of the Change:** Splitting the solution part of a `DecompositionGraph` to a separate `DecompGraphSolution` class, returned by the `solve` method. **Benefits:** Cleaner code **Possible Drawbacks:** **Related GitHub Issues:** [sc-97093]
1 parent 22febc8 commit 8630f5b

File tree

6 files changed

+181
-157
lines changed

6 files changed

+181
-157
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,11 @@
203203
<h4>Resource-efficient decompositions 🔎</h4>
204204

205205
* With :func:`~.decomposition.enable_graph()`, dynamically allocated wires are now supported in decomposition rules. This provides a smoother overall experience when decomposing operators in a way that requires auxiliary/work wires.
206-
207206
[(#7861)](https://github.com/PennyLaneAI/pennylane/pull/7861)
207+
208+
* A :class:`~.decomposition.decomposition_graph.DecompGraphSolution` class is added to store the solution of a decomposition graph. An instance of this class is returned from the `solve` method of the :class:`~.decomposition.decomposition_graph.DecompositionGraph`.
209+
[(#8031)](https://github.com/PennyLaneAI/pennylane/pull/8031)
210+
208211
<h3>Labs: a place for unified and rapid prototyping of research software 🧪</h3>
209212

210213
* Added state of the art resources for the `ResourceSelectPauliRot` template and the

pennylane/decomposition/decomposition_graph.py

Lines changed: 85 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,19 @@
5555
from .utils import translate_op_alias
5656

5757

58-
class DecompositionGraph: # pylint: disable=too-many-instance-attributes
58+
@dataclass(frozen=True)
59+
class _DecompositionNode:
60+
"""A node that represents a decomposition rule."""
61+
62+
rule: DecompositionRule
63+
decomp_resource: Resources
64+
65+
def count(self, op: CompressedResourceOp):
66+
"""Find the number of occurrences of an operator in the decomposition."""
67+
return self.decomp_resource.gate_counts.get(op, 0)
68+
69+
70+
class DecompositionGraph: # pylint: disable=too-many-instance-attributes,too-few-public-methods
5971
"""A graph that models a decomposition problem.
6072
6173
The decomposition graph contains two types of nodes: operator nodes and decomposition nodes.
@@ -117,18 +129,18 @@ def my_cz(wires):
117129
operations=[op],
118130
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
119131
)
120-
graph.solve()
132+
solution = graph.solve()
121133
122134
>>> with qml.queuing.AnnotatedQueue() as q:
123-
... graph.decomposition(op)(0.5, wires=[0, 1])
135+
... solution.decomposition(op)(0.5, wires=[0, 1])
124136
>>> q.queue
125137
[RZ(1.5707963267948966, wires=[1]),
126138
RY(0.25, wires=[1]),
127139
CNOT(wires=[0, 1]),
128140
RY(-0.25, wires=[1]),
129141
CNOT(wires=[0, 1]),
130142
RZ(-1.5707963267948966, wires=[1])]
131-
>>> graph.resource_estimate(op)
143+
>>> solution.resource_estimate(op)
132144
<num_gates=10, gate_counts={RZ: 6, CNOT: 2, RX: 2}, weighted_cost=10.0>
133145
134146
"""
@@ -159,53 +171,11 @@ def __init__(
159171

160172
# Initializes the graph.
161173
self._graph = rx.PyDiGraph()
162-
self._visitor = None
163174

164175
# Construct the decomposition graph
165176
self._start = self._graph.add_node(None)
166177
self._construct_graph(operations)
167178

168-
def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
169-
"""Helper function to get a list of decomposition rules."""
170-
171-
op_name = _to_name(op_node)
172-
173-
if op_name in self._fixed_decomps:
174-
return [self._fixed_decomps[op_name]]
175-
176-
decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name)
177-
178-
if (
179-
issubclass(op_node.op_type, qml.ops.Adjoint)
180-
and self_adjoint not in decomps
181-
and adjoint_rotation not in decomps
182-
):
183-
# In general, we decompose the adjoint of an operator by applying adjoint to the
184-
# decompositions of the operator. However, this is not necessary if the operator
185-
# is self-adjoint or if it has a single rotation angle which can be trivially
186-
# inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
187-
# would've already been retrieved as a potential decomposition rule for this
188-
# operator, so there is no need to consider the general case.
189-
decomps.extend(self._get_adjoint_decompositions(op_node))
190-
191-
elif (
192-
issubclass(op_node.op_type, qml.ops.Pow)
193-
and pow_rotation not in decomps
194-
and pow_involutory not in decomps
195-
):
196-
# Similar to the adjoint case, the `_get_pow_decompositions` contains the general
197-
# approach we take to decompose powers of operators. However, if the operator is
198-
# involutory or if it has a single rotation angle that can be trivially multiplied
199-
# with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
200-
# as a potential decomposition rule for this operator, so there is no need to consider
201-
# the general case.
202-
decomps.extend(self._get_pow_decompositions(op_node))
203-
204-
elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
205-
decomps.extend(self._get_controlled_decompositions(op_node))
206-
207-
return decomps
208-
209179
def _construct_graph(self, operations):
210180
"""Constructs the decomposition graph."""
211181
for op in operations:
@@ -254,6 +224,47 @@ def _add_decomp(self, rule: DecompositionRule, op_node: CompressedResourceOp, op
254224
self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx))
255225
self._graph.add_edge(d_node_idx, op_idx, 0)
256226

227+
def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
228+
"""Helper function to get a list of decomposition rules."""
229+
230+
op_name = _to_name(op_node)
231+
232+
if op_name in self._fixed_decomps:
233+
return [self._fixed_decomps[op_name]]
234+
235+
decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name)
236+
237+
if (
238+
issubclass(op_node.op_type, qml.ops.Adjoint)
239+
and self_adjoint not in decomps
240+
and adjoint_rotation not in decomps
241+
):
242+
# In general, we decompose the adjoint of an operator by applying adjoint to the
243+
# decompositions of the operator. However, this is not necessary if the operator
244+
# is self-adjoint or if it has a single rotation angle which can be trivially
245+
# inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
246+
# would've already been retrieved as a potential decomposition rule for this
247+
# operator, so there is no need to consider the general case.
248+
decomps.extend(self._get_adjoint_decompositions(op_node))
249+
250+
elif (
251+
issubclass(op_node.op_type, qml.ops.Pow)
252+
and pow_rotation not in decomps
253+
and pow_involutory not in decomps
254+
):
255+
# Similar to the adjoint case, the `_get_pow_decompositions` contains the general
256+
# approach we take to decompose powers of operators. However, if the operator is
257+
# involutory or if it has a single rotation angle that can be trivially multiplied
258+
# with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
259+
# as a potential decomposition rule for this operator, so there is no need to consider
260+
# the general case.
261+
decomps.extend(self._get_pow_decompositions(op_node))
262+
263+
elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
264+
decomps.extend(self._get_controlled_decompositions(op_node))
265+
266+
return decomps
267+
257268
def _get_adjoint_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
258269
"""Gets the decomposition rules for the adjoint of an operator."""
259270

@@ -315,16 +326,19 @@ def _get_controlled_decompositions(
315326

316327
return rules
317328

318-
def solve(self, lazy=True):
329+
def solve(self, lazy=True) -> DecompGraphSolution:
319330
"""Solves the graph using the Dijkstra search algorithm.
320331
321332
Args:
322333
lazy (bool): If True, the Dijkstra search will stop once optimal decompositions are
323334
found for all operations that the graph was initialized with. Otherwise, the
324335
entire graph will be explored.
325336
337+
Returns:
338+
DecompGraphSolution
339+
326340
"""
327-
self._visitor = _DecompositionSearchVisitor(
341+
visitor = DecompositionSearchVisitor(
328342
self._graph,
329343
self._weights,
330344
self._original_ops_indices,
@@ -333,15 +347,29 @@ def solve(self, lazy=True):
333347
rx.dijkstra_search(
334348
self._graph,
335349
source=[self._start],
336-
weight_fn=self._visitor.edge_weight,
337-
visitor=self._visitor,
350+
weight_fn=visitor.edge_weight,
351+
visitor=visitor,
338352
)
339-
if self._visitor.unsolved_op_indices:
340-
unsolved_ops = [self._graph[op_idx] for op_idx in self._visitor.unsolved_op_indices]
353+
if visitor.unsolved_op_indices:
354+
unsolved_ops = [self._graph[op_idx] for op_idx in visitor.unsolved_op_indices]
341355
op_names = {op.name for op in unsolved_ops}
342356
raise DecompositionError(
343357
f"Decomposition not found for {op_names} to the gate set {set(self._weights)}"
344358
)
359+
return DecompGraphSolution(visitor, self._all_op_indices)
360+
361+
362+
class DecompGraphSolution:
363+
"""A solution to a decomposition graph."""
364+
365+
def __init__(
366+
self,
367+
visitor: DecompositionSearchVisitor,
368+
all_op_indices: dict[CompressedResourceOp, int],
369+
) -> None:
370+
self._visitor = visitor
371+
self._graph = visitor._graph # pylint: disable=protected-access
372+
self._all_op_indices = all_op_indices
345373

346374
def is_solved_for(self, op):
347375
"""Tests whether the decomposition graph is solved for a given operator."""
@@ -372,10 +400,10 @@ def resource_estimate(self, op) -> Resources:
372400
operations=[op],
373401
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
374402
)
375-
graph.solve()
403+
solution = graph.solve()
376404
377405
>>> with qml.queuing.AnnotatedQueue() as q:
378-
... graph.decomposition(op)(0.5, wires=[0, 1])
406+
... solution.decomposition(op)(0.5, wires=[0, 1])
379407
>>> q.queue
380408
[RZ(1.5707963267948966, wires=[1]),
381409
RY(0.25, wires=[1]),
@@ -415,8 +443,8 @@ def decomposition(self, op: Operator) -> DecompositionRule:
415443
operations=[op],
416444
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
417445
)
418-
graph.solve()
419-
rule = graph.decomposition(op)
446+
solution = graph.solve()
447+
rule = solution.decomposition(op)
420448
421449
>>> with qml.queuing.AnnotatedQueue() as q:
422450
... rule(*op.parameters, wires=op.wires, **op.hyperparameters)
@@ -436,7 +464,7 @@ def decomposition(self, op: Operator) -> DecompositionRule:
436464
return self._graph[d_node_idx].rule
437465

438466

439-
class _DecompositionSearchVisitor(DijkstraVisitor):
467+
class DecompositionSearchVisitor(DijkstraVisitor):
440468
"""The visitor used in the Dijkstra search for the optimal decomposition."""
441469

442470
def __init__(
@@ -505,18 +533,6 @@ def edge_relaxed(self, edge):
505533
self.distances[target_idx] = self.distances[src_idx]
506534

507535

508-
@dataclass(frozen=True)
509-
class _DecompositionNode:
510-
"""A node that represents a decomposition rule."""
511-
512-
rule: DecompositionRule
513-
decomp_resource: Resources
514-
515-
def count(self, op: CompressedResourceOp):
516-
"""Find the number of occurrences of an operator in the decomposition."""
517-
return self.decomp_resource.gate_counts.get(op, 0)
518-
519-
520536
def _to_name(op):
521537
if isinstance(op, type):
522538
return op.__name__

0 commit comments

Comments
 (0)