Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 81 additions & 64 deletions pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,19 @@
from .utils import translate_op_alias


class DecompositionGraph: # pylint: disable=too-many-instance-attributes
@dataclass(frozen=True)
class _DecompositionNode:
"""A node that represents a decomposition rule."""

rule: DecompositionRule
decomp_resource: Resources

def count(self, op: CompressedResourceOp):
"""Find the number of occurrences of an operator in the decomposition."""
return self.decomp_resource.gate_counts.get(op, 0)


class DecompositionGraph: # pylint: disable=too-many-instance-attributes,too-few-public-methods
"""A graph that models a decomposition problem.

The decomposition graph contains two types of nodes: operator nodes and decomposition nodes.
Expand Down Expand Up @@ -117,18 +129,18 @@ def my_cz(wires):
operations=[op],
gate_set={"RZ", "RX", "CNOT", "GlobalPhase"},
)
graph.solve()
solution = graph.solve()

>>> with qml.queuing.AnnotatedQueue() as q:
... graph.decomposition(op)(0.5, wires=[0, 1])
... solution.decomposition(op)(0.5, wires=[0, 1])
>>> q.queue
[RZ(1.5707963267948966, wires=[1]),
RY(0.25, wires=[1]),
CNOT(wires=[0, 1]),
RY(-0.25, wires=[1]),
CNOT(wires=[0, 1]),
RZ(-1.5707963267948966, wires=[1])]
>>> graph.resource_estimate(op)
>>> solution.resource_estimate(op)
<num_gates=10, gate_counts={RZ: 6, CNOT: 2, RX: 2}, weighted_cost=10.0>

"""
Expand Down Expand Up @@ -159,53 +171,11 @@ def __init__(

# Initializes the graph.
self._graph = rx.PyDiGraph()
self._visitor = None

# Construct the decomposition graph
self._start = self._graph.add_node(None)
self._construct_graph(operations)

def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Helper function to get a list of decomposition rules."""

op_name = _to_name(op_node)

if op_name in self._fixed_decomps:
return [self._fixed_decomps[op_name]]

decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name)

if (
issubclass(op_node.op_type, qml.ops.Adjoint)
and self_adjoint not in decomps
and adjoint_rotation not in decomps
):
# In general, we decompose the adjoint of an operator by applying adjoint to the
# decompositions of the operator. However, this is not necessary if the operator
# is self-adjoint or if it has a single rotation angle which can be trivially
# inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
# would've already been retrieved as a potential decomposition rule for this
# operator, so there is no need to consider the general case.
decomps.extend(self._get_adjoint_decompositions(op_node))

elif (
issubclass(op_node.op_type, qml.ops.Pow)
and pow_rotation not in decomps
and pow_involutory not in decomps
):
# Similar to the adjoint case, the `_get_pow_decompositions` contains the general
# approach we take to decompose powers of operators. However, if the operator is
# involutory or if it has a single rotation angle that can be trivially multiplied
# with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
# as a potential decomposition rule for this operator, so there is no need to consider
# the general case.
decomps.extend(self._get_pow_decompositions(op_node))

elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
decomps.extend(self._get_controlled_decompositions(op_node))

return decomps

def _construct_graph(self, operations):
"""Constructs the decomposition graph."""
for op in operations:
Expand Down Expand Up @@ -254,6 +224,47 @@ def _add_decomp(self, rule: DecompositionRule, op_node: CompressedResourceOp, op
self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx))
self._graph.add_edge(d_node_idx, op_idx, 0)

def _get_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Helper function to get a list of decomposition rules."""

op_name = _to_name(op_node)

if op_name in self._fixed_decomps:
return [self._fixed_decomps[op_name]]

decomps = self._alt_decomps.get(op_name, []) + list_decomps(op_name)

if (
issubclass(op_node.op_type, qml.ops.Adjoint)
and self_adjoint not in decomps
and adjoint_rotation not in decomps
):
# In general, we decompose the adjoint of an operator by applying adjoint to the
# decompositions of the operator. However, this is not necessary if the operator
# is self-adjoint or if it has a single rotation angle which can be trivially
# inverted to obtain its adjoint. In this case, `self_adjoint` or `adjoint_rotation`
# would've already been retrieved as a potential decomposition rule for this
# operator, so there is no need to consider the general case.
decomps.extend(self._get_adjoint_decompositions(op_node))

elif (
issubclass(op_node.op_type, qml.ops.Pow)
and pow_rotation not in decomps
and pow_involutory not in decomps
):
# Similar to the adjoint case, the `_get_pow_decompositions` contains the general
# approach we take to decompose powers of operators. However, if the operator is
# involutory or if it has a single rotation angle that can be trivially multiplied
# with the power, we would've already retrieved `pow_involutory` or `pow_rotation`
# as a potential decomposition rule for this operator, so there is no need to consider
# the general case.
decomps.extend(self._get_pow_decompositions(op_node))

elif op_node.op_type in (qml.ops.Controlled, qml.ops.ControlledOp):
decomps.extend(self._get_controlled_decompositions(op_node))

return decomps

def _get_adjoint_decompositions(self, op_node: CompressedResourceOp) -> list[DecompositionRule]:
"""Gets the decomposition rules for the adjoint of an operator."""

Expand Down Expand Up @@ -315,16 +326,19 @@ def _get_controlled_decompositions(

return rules

def solve(self, lazy=True):
def solve(self, lazy=True) -> DecompGraphSolution:
"""Solves the graph using the Dijkstra search algorithm.

Args:
lazy (bool): If True, the Dijkstra search will stop once optimal decompositions are
found for all operations that the graph was initialized with. Otherwise, the
entire graph will be explored.

Returns:
DecompGraphSolution

"""
self._visitor = _DecompositionSearchVisitor(
visitor = _DecompositionSearchVisitor(
self._graph,
self._weights,
self._original_ops_indices,
Expand All @@ -333,15 +347,30 @@ def solve(self, lazy=True):
rx.dijkstra_search(
self._graph,
source=[self._start],
weight_fn=self._visitor.edge_weight,
visitor=self._visitor,
weight_fn=visitor.edge_weight,
visitor=visitor,
)
if self._visitor.unsolved_op_indices:
unsolved_ops = [self._graph[op_idx] for op_idx in self._visitor.unsolved_op_indices]
if visitor.unsolved_op_indices:
unsolved_ops = [self._graph[op_idx] for op_idx in visitor.unsolved_op_indices]
op_names = {op.name for op in unsolved_ops}
raise DecompositionError(
f"Decomposition not found for {op_names} to the gate set {set(self._weights)}"
)
return DecompGraphSolution(visitor, self._graph, self._all_op_indices)


class DecompGraphSolution:
"""A solution to a decomposition graph."""

def __init__(
self,
visitor: _DecompositionSearchVisitor,
graph: rx.PyDiGraph,
all_op_indices: dict[CompressedResourceOp, int],
) -> None:
self._visitor = visitor
self._graph = graph
self._all_op_indices = all_op_indices

def is_solved_for(self, op):
"""Tests whether the decomposition graph is solved for a given operator."""
Expand Down Expand Up @@ -505,18 +534,6 @@ def edge_relaxed(self, edge):
self.distances[target_idx] = self.distances[src_idx]


@dataclass(frozen=True)
class _DecompositionNode:
"""A node that represents a decomposition rule."""

rule: DecompositionRule
decomp_resource: Resources

def count(self, op: CompressedResourceOp):
"""Find the number of occurrences of an operator in the decomposition."""
return self.decomp_resource.gate_counts.get(op, 0)


def _to_name(op):
if isinstance(op, type):
return op.__name__
Expand Down
39 changes: 22 additions & 17 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import pennylane as qml
from pennylane.decomposition import DecompositionGraph
from pennylane.decomposition.decomposition_graph import DecompGraphSolution
from pennylane.decomposition.utils import translate_op_alias
from pennylane.operation import Operator
from pennylane.transforms.core import transform
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(
"to enable the new system."
)

self._decomp_graph = None
self._decomp_graph_solution = None
self._target_gate_names = None
self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps

Expand Down Expand Up @@ -183,7 +184,7 @@ def decompose_operation(self, op: qml.operation.Operator):
op,
self.stopping_condition,
max_expansion=max_expansion,
decomp_graph=self._decomp_graph,
decomp_graph_solution=self._decomp_graph_solution,
)
)

Expand All @@ -198,9 +199,9 @@ def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator):
if self.max_expansion is not None and self._current_depth >= self.max_expansion:
return self.interpret_operation(op)

if qml.decomposition.enabled_graph() and self._decomp_graph.is_solved_for(op):
if qml.decomposition.enabled_graph() and self._decomp_graph_solution.is_solved_for(op):

rule = self._decomp_graph.decomposition(op)
rule = self._decomp_graph_solution.decomposition(op)
num_wires = len(op.wires)

def compute_qfunc_decomposition(*_args, **_kwargs):
Expand Down Expand Up @@ -242,7 +243,7 @@ def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list:
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env_map[constvar] = const

if qml.decomposition.enabled_graph() and not self._decomp_graph:
if qml.decomposition.enabled_graph() and not self._decomp_graph_solution:

with qml.capture.pause():

Expand All @@ -251,7 +252,7 @@ def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list:
operations = collector.state["ops"]

if operations:
self._decomp_graph = _construct_and_solve_decomp_graph(
self._decomp_graph_solution = _construct_and_solve_decomp_graph(
operations,
self._gate_set,
self._fixed_decomps,
Expand Down Expand Up @@ -322,7 +323,7 @@ def interpret_operation_eqn(self, eqn: jax.extend.core.JaxprEqn):
if (
op.has_qfunc_decomposition
or qml.decomposition.enabled_graph()
and self._decomp_graph.is_solved_for(op)
and self._decomp_graph_solution.is_solved_for(op)
):
return self._evaluate_jaxpr_decomposition(op)

Expand Down Expand Up @@ -778,11 +779,11 @@ def _stopping_condition(op):

# If the decomposition graph is enabled, we create a DecompositionGraph instance
# to optimize the decomposition.
decomp_graph = None
decomp_graph_solution = None

if qml.decomposition.enabled_graph():

decomp_graph = _construct_and_solve_decomp_graph(
decomp_graph_solution = _construct_and_solve_decomp_graph(
tape.operations,
gate_set,
fixed_decomps=fixed_decomps,
Expand All @@ -794,7 +795,10 @@ def _stopping_condition(op):
final_op
for op in tape.operations
for final_op in _operator_decomposition_gen(
op, _stopping_condition, max_expansion=max_expansion, decomp_graph=decomp_graph
op,
_stopping_condition,
max_expansion=max_expansion,
decomp_graph_solution=decomp_graph_solution,
)
]
except RecursionError as e:
Expand All @@ -814,7 +818,7 @@ def _operator_decomposition_gen(
acceptance_function: Callable[[qml.operation.Operator], bool],
max_expansion: int | None = None,
current_depth=0,
decomp_graph: DecompositionGraph = None,
decomp_graph_solution: DecompGraphSolution | None = None,
) -> Generator[qml.operation.Operator]:
"""A generator that yields the next operation that is accepted."""

Expand All @@ -826,8 +830,8 @@ def _operator_decomposition_gen(

if acceptance_function(op) or max_depth_reached:
yield op
elif decomp_graph is not None and decomp_graph.is_solved_for(op):
op_rule = decomp_graph.decomposition(op)
elif decomp_graph_solution is not None and decomp_graph_solution.is_solved_for(op):
op_rule = decomp_graph_solution.decomposition(op)
with qml.queuing.AnnotatedQueue() as decomposed_ops:
op_rule(*op.parameters, wires=op.wires, **op.hyperparameters)
decomp = decomposed_ops.queue
Expand All @@ -842,7 +846,7 @@ def _operator_decomposition_gen(
acceptance_function,
max_expansion=max_expansion,
current_depth=current_depth,
decomp_graph=decomp_graph,
decomp_graph_solution=decomp_graph_solution,
)


Expand Down Expand Up @@ -929,7 +933,9 @@ def _stopping_condition(op):
return gate_set, _stopping_condition


def _construct_and_solve_decomp_graph(operations, target_gates, fixed_decomps, alt_decomps):
def _construct_and_solve_decomp_graph(
operations, target_gates, fixed_decomps, alt_decomps
) -> DecompGraphSolution:
"""Create and solve a DecompositionGraph instance to optimize the decomposition."""

# Create the decomposition graph
Expand All @@ -941,5 +947,4 @@ def _construct_and_solve_decomp_graph(operations, target_gates, fixed_decomps, a
)

# Find the efficient pathways to the target gate set
decomp_graph.solve()
return decomp_graph
return decomp_graph.solve()
Loading