5555from .utils import translate_op_alias
5656
5757
58- class DecompositionGraph : # pylint: disable=too-many-instance-attributes
58+ @dataclass (frozen = True )
59+ class _DecompositionNode :
60+ """A node that represents a decomposition rule."""
61+
62+ rule : DecompositionRule
63+ decomp_resource : Resources
64+
65+ def count (self , op : CompressedResourceOp ):
66+ """Find the number of occurrences of an operator in the decomposition."""
67+ return self .decomp_resource .gate_counts .get (op , 0 )
68+
69+
70+ class DecompositionGraph : # pylint: disable=too-many-instance-attributes,too-few-public-methods
5971 """A graph that models a decomposition problem.
6072
6173 The decomposition graph contains two types of nodes: operator nodes and decomposition nodes.
@@ -117,18 +129,18 @@ def my_cz(wires):
117129 operations=[op],
118130 gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
119131 )
120- graph.solve()
132+ solution = graph.solve()
121133
122134 >>> with qml.queuing.AnnotatedQueue() as q:
123- ... graph .decomposition(op)(0.5, wires=[0, 1])
135+ ... solution .decomposition(op)(0.5, wires=[0, 1])
124136 >>> q.queue
125137 [RZ(1.5707963267948966, wires=[1]),
126138 RY(0.25, wires=[1]),
127139 CNOT(wires=[0, 1]),
128140 RY(-0.25, wires=[1]),
129141 CNOT(wires=[0, 1]),
130142 RZ(-1.5707963267948966, wires=[1])]
131- >>> graph .resource_estimate(op)
143+ >>> solution .resource_estimate(op)
132144 <num_gates=10, gate_counts={RZ: 6, CNOT: 2, RX: 2}, weighted_cost=10.0>
133145
134146 """
@@ -159,53 +171,11 @@ def __init__(
159171
160172 # Initializes the graph.
161173 self ._graph = rx .PyDiGraph ()
162- self ._visitor = None
163174
164175 # Construct the decomposition graph
165176 self ._start = self ._graph .add_node (None )
166177 self ._construct_graph (operations )
167178
168- def _get_decompositions (self , op_node : CompressedResourceOp ) -> list [DecompositionRule ]:
169- """Helper function to get a list of decomposition rules."""
170-
171- op_name = _to_name (op_node )
172-
173- if op_name in self ._fixed_decomps :
174- return [self ._fixed_decomps [op_name ]]
175-
176- decomps = self ._alt_decomps .get (op_name , []) + list_decomps (op_name )
177-
178- if (
179- issubclass (op_node .op_type , qml .ops .Adjoint )
180- and self_adjoint not in decomps
181- and adjoint_rotation not in decomps
182- ):
183- # In general, we decompose the adjoint of an operator by applying adjoint to the
184- # decompositions of the operator. However, this is not necessary if the operator
185- # is self-adjoint or if it has a single rotation angle which can be trivially
186- # inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
187- # would've already been retrieved as a potential decomposition rule for this
188- # operator, so there is no need to consider the general case.
189- decomps .extend (self ._get_adjoint_decompositions (op_node ))
190-
191- elif (
192- issubclass (op_node .op_type , qml .ops .Pow )
193- and pow_rotation not in decomps
194- and pow_involutory not in decomps
195- ):
196- # Similar to the adjoint case, the `_get_pow_decompositions` contains the general
197- # approach we take to decompose powers of operators. However, if the operator is
198- # involutory or if it has a single rotation angle that can be trivially multiplied
199- # with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
200- # as a potential decomposition rule for this operator, so there is no need to consider
201- # the general case.
202- decomps .extend (self ._get_pow_decompositions (op_node ))
203-
204- elif op_node .op_type in (qml .ops .Controlled , qml .ops .ControlledOp ):
205- decomps .extend (self ._get_controlled_decompositions (op_node ))
206-
207- return decomps
208-
209179 def _construct_graph (self , operations ):
210180 """Constructs the decomposition graph."""
211181 for op in operations :
@@ -254,6 +224,47 @@ def _add_decomp(self, rule: DecompositionRule, op_node: CompressedResourceOp, op
254224 self ._graph .add_edge (op_node_idx , d_node_idx , (op_node_idx , d_node_idx ))
255225 self ._graph .add_edge (d_node_idx , op_idx , 0 )
256226
227+ def _get_decompositions (self , op_node : CompressedResourceOp ) -> list [DecompositionRule ]:
228+ """Helper function to get a list of decomposition rules."""
229+
230+ op_name = _to_name (op_node )
231+
232+ if op_name in self ._fixed_decomps :
233+ return [self ._fixed_decomps [op_name ]]
234+
235+ decomps = self ._alt_decomps .get (op_name , []) + list_decomps (op_name )
236+
237+ if (
238+ issubclass (op_node .op_type , qml .ops .Adjoint )
239+ and self_adjoint not in decomps
240+ and adjoint_rotation not in decomps
241+ ):
242+ # In general, we decompose the adjoint of an operator by applying adjoint to the
243+ # decompositions of the operator. However, this is not necessary if the operator
244+ # is self-adjoint or if it has a single rotation angle which can be trivially
245+ # inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
246+ # would've already been retrieved as a potential decomposition rule for this
247+ # operator, so there is no need to consider the general case.
248+ decomps .extend (self ._get_adjoint_decompositions (op_node ))
249+
250+ elif (
251+ issubclass (op_node .op_type , qml .ops .Pow )
252+ and pow_rotation not in decomps
253+ and pow_involutory not in decomps
254+ ):
255+ # Similar to the adjoint case, the `_get_pow_decompositions` contains the general
256+ # approach we take to decompose powers of operators. However, if the operator is
257+ # involutory or if it has a single rotation angle that can be trivially multiplied
258+ # with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
259+ # as a potential decomposition rule for this operator, so there is no need to consider
260+ # the general case.
261+ decomps .extend (self ._get_pow_decompositions (op_node ))
262+
263+ elif op_node .op_type in (qml .ops .Controlled , qml .ops .ControlledOp ):
264+ decomps .extend (self ._get_controlled_decompositions (op_node ))
265+
266+ return decomps
267+
257268 def _get_adjoint_decompositions (self , op_node : CompressedResourceOp ) -> list [DecompositionRule ]:
258269 """Gets the decomposition rules for the adjoint of an operator."""
259270
@@ -315,16 +326,19 @@ def _get_controlled_decompositions(
315326
316327 return rules
317328
318- def solve (self , lazy = True ):
329+ def solve (self , lazy = True ) -> DecompGraphSolution :
319330 """Solves the graph using the Dijkstra search algorithm.
320331
321332 Args:
322333 lazy (bool): If True, the Dijkstra search will stop once optimal decompositions are
323334 found for all operations that the graph was initialized with. Otherwise, the
324335 entire graph will be explored.
325336
337+ Returns:
338+ DecompGraphSolution
339+
326340 """
327- self . _visitor = _DecompositionSearchVisitor (
341+ visitor = DecompositionSearchVisitor (
328342 self ._graph ,
329343 self ._weights ,
330344 self ._original_ops_indices ,
@@ -333,15 +347,29 @@ def solve(self, lazy=True):
333347 rx .dijkstra_search (
334348 self ._graph ,
335349 source = [self ._start ],
336- weight_fn = self . _visitor .edge_weight ,
337- visitor = self . _visitor ,
350+ weight_fn = visitor .edge_weight ,
351+ visitor = visitor ,
338352 )
339- if self . _visitor .unsolved_op_indices :
340- unsolved_ops = [self ._graph [op_idx ] for op_idx in self . _visitor .unsolved_op_indices ]
353+ if visitor .unsolved_op_indices :
354+ unsolved_ops = [self ._graph [op_idx ] for op_idx in visitor .unsolved_op_indices ]
341355 op_names = {op .name for op in unsolved_ops }
342356 raise DecompositionError (
343357 f"Decomposition not found for { op_names } to the gate set { set (self ._weights )} "
344358 )
359+ return DecompGraphSolution (visitor , self ._all_op_indices )
360+
361+
362+ class DecompGraphSolution :
363+ """A solution to a decomposition graph."""
364+
365+ def __init__ (
366+ self ,
367+ visitor : DecompositionSearchVisitor ,
368+ all_op_indices : dict [CompressedResourceOp , int ],
369+ ) -> None :
370+ self ._visitor = visitor
371+ self ._graph = visitor ._graph # pylint: disable=protected-access
372+ self ._all_op_indices = all_op_indices
345373
346374 def is_solved_for (self , op ):
347375 """Tests whether the decomposition graph is solved for a given operator."""
@@ -372,10 +400,10 @@ def resource_estimate(self, op) -> Resources:
372400 operations=[op],
373401 gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
374402 )
375- graph.solve()
403+ solution = graph.solve()
376404
377405 >>> with qml.queuing.AnnotatedQueue() as q:
378- ... graph .decomposition(op)(0.5, wires=[0, 1])
406+ ... solution .decomposition(op)(0.5, wires=[0, 1])
379407 >>> q.queue
380408 [RZ(1.5707963267948966, wires=[1]),
381409 RY(0.25, wires=[1]),
@@ -415,8 +443,8 @@ def decomposition(self, op: Operator) -> DecompositionRule:
415443 operations=[op],
416444 gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
417445 )
418- graph.solve()
419- rule = graph .decomposition(op)
446+ solution = graph.solve()
447+ rule = solution .decomposition(op)
420448
421449 >>> with qml.queuing.AnnotatedQueue() as q:
422450 ... rule(*op.parameters, wires=op.wires, **op.hyperparameters)
@@ -436,7 +464,7 @@ def decomposition(self, op: Operator) -> DecompositionRule:
436464 return self ._graph [d_node_idx ].rule
437465
438466
439- class _DecompositionSearchVisitor (DijkstraVisitor ):
467+ class DecompositionSearchVisitor (DijkstraVisitor ):
440468 """The visitor used in the Dijkstra search for the optimal decomposition."""
441469
442470 def __init__ (
@@ -505,18 +533,6 @@ def edge_relaxed(self, edge):
505533 self .distances [target_idx ] = self .distances [src_idx ]
506534
507535
508- @dataclass (frozen = True )
509- class _DecompositionNode :
510- """A node that represents a decomposition rule."""
511-
512- rule : DecompositionRule
513- decomp_resource : Resources
514-
515- def count (self , op : CompressedResourceOp ):
516- """Find the number of occurrences of an operator in the decomposition."""
517- return self .decomp_resource .gate_counts .get (op , 0 )
518-
519-
520536def _to_name (op ):
521537 if isinstance (op , type ):
522538 return op .__name__
0 commit comments