Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,9 @@

<h3>Breaking changes 💔</h3>

* `MidMeasureMP` now inherits from `Operator` instead of `MeasurementProcess`.
[(#8166)](https://github.com/PennyLaneAI/pennylane/pull/8166)

* `DefaultQubit.eval_jaxpr` does not use `self.shots` from device anymore; instead, it takes `shots` as a keyword argument,
and the qnode primitive should process the `shots` and call `eval_jaxpr` accordingly.
[(#8161)](https://github.com/PennyLaneAI/pennylane/pull/8161)
Expand Down
3 changes: 0 additions & 3 deletions pennylane/drawer/drawable_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def drawable_layers(operations, wire_map=None, bit_map=None):

# loop over operations
for op in operations:
if isinstance(op, MidMeasureMP):
if len(op.wires) > 1:
raise ValueError("Cannot draw mid-circuit measurements with more than one wire.")

if isinstance(op, MeasurementProcess) and op.mv is not None:
# Only terminal measurements that collect mid-circuit measurement statistics have
Expand Down
56 changes: 22 additions & 34 deletions pennylane/ftqc/parametric_midmeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,18 +374,18 @@ def __init__(
):
self.batch_size = None
super().__init__(wires=Wires(wires), reset=reset, postselect=postselect, id=id)
self.plane = plane
self.angle = angle
self.hyperparameters["plane"] = plane
self.hyperparameters["angle"] = angle

def _flatten(self):
metadata = (
("angle", self.angle),
("wires", self.raw_wires),
("plane", self.plane),
("reset", self.reset),
("id", self.id),
)
return (None, None), metadata
@property
def plane(self) -> str | None:
"""The plane the measurement basis lies in. Options are "XY", "ZX" and "YZ"""
return self.hyperparameters["plane"]

@property
def angle(self):
"""The angle in radians"""
return self.hyperparameters["angle"]

@property
def hash(self):
Expand Down Expand Up @@ -416,19 +416,14 @@ def _primitive_bind_call(
cls, angle=0.0, wires=None, plane="ZX", reset=False, postselect=None, id=None
):
wires = () if wires is None else wires
return cls._wires_primitive.bind(
return cls._primitive.bind(
*wires, angle=angle, plane=plane, reset=reset, postselect=postselect, id=id
)

def __repr__(self):
"""Representation of this class."""
return f"{self._shortname}_{self.plane.lower()}(wires={self.wires.tolist()}, angle={self.angle})"

@property
def has_diagonalizing_gates(self):
"""Whether there are gates that need to be applied to diagonalize the measurement"""
return True

def diagonalizing_gates(self):
"""Decompose to a diagonalizing gate and a standard MCM in the computational basis"""
if self.plane == "XY":
Expand Down Expand Up @@ -479,6 +474,10 @@ class XMidMeasureMP(ParametricMidMeasureMP):

_shortname = "measure_x"

def _flatten(self):
metadata = (("reset", self.reset), ("postselect", self.postselect), ("id", self.id))
return (), (self.wires, metadata)

def __init__(
self,
wires: Wires | None,
Expand All @@ -490,14 +489,6 @@ def __init__(
wires=Wires(wires), angle=0, plane="XY", reset=reset, postselect=postselect, id=id
)

def _flatten(self):
metadata = (
("wires", self.raw_wires),
("reset", self.reset),
("id", self.id),
)
return (None, None), metadata

def __repr__(self):
"""Representation of this class."""
return f"{self._shortname}(wires={self.wires.tolist()})"
Expand Down Expand Up @@ -537,6 +528,10 @@ class YMidMeasureMP(ParametricMidMeasureMP):

_shortname = "measure_y"

def _flatten(self):
metadata = (("reset", self.reset), ("postselect", self.postselect), ("id", self.id))
return (), (self.wires, metadata)

def __init__(
self,
wires: Wires | None,
Expand All @@ -553,14 +548,6 @@ def __init__(
id=id,
)

def _flatten(self):
metadata = (
("wires", self.raw_wires),
("reset", self.reset),
("id", self.id),
)
return (None, None), metadata

def __repr__(self):
"""Representation of this class."""
return f"{self._shortname}(wires={self.wires.tolist()})"
Expand Down Expand Up @@ -639,9 +626,10 @@ def diagonalize_mcms(tape):

.. code-block:: python3

from pennylane.ftqc import diagonalize_mcms, ParametricMidMeasureMP
from functools import partial

from pennylane.ftqc import ParametricMidMeasureMP, diagonalize_mcms

dev = qml.device("default.qubit")

@diagonalize_mcms
Expand Down
66 changes: 21 additions & 45 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

from pennylane.capture import enabled as capture_enabled
from pennylane.exceptions import QuantumFunctionError
from pennylane.operation import Operator
from pennylane.wires import Wires

from .measurement_value import MeasurementValue
from .measurements import MeasurementProcess


def _measure_impl(
Expand Down Expand Up @@ -117,7 +117,7 @@ def find_post_processed_mcms(circuit):
return post_processed_mcms


class MidMeasureMP(MeasurementProcess):
class MidMeasureMP(Operator):
"""Mid-circuit measurement.

This class additionally stores information about unknown measurement outcomes in the qubit model.
Expand All @@ -135,11 +135,9 @@ class MidMeasureMP(MeasurementProcess):
id (str): Custom label given to a measurement instance.
"""

_shortname = "measure"

def _flatten(self):
metadata = (("wires", self.raw_wires), ("reset", self.reset), ("id", self.id))
return (None, None), metadata
num_wires = 1
num_params = 0
batch_size = None

def __init__(
self,
Expand All @@ -148,26 +146,27 @@ def __init__(
postselect: int | None = None,
id: str | None = None,
):
self.batch_size = None
super().__init__(wires=Wires(wires), id=id)
self.reset = reset
self.postselect = postselect
self._hyperparameters = {"reset": reset, "postselect": postselect, "id": id}

@property
def reset(self) -> bool | None:
"""Whether to reset the wire into the zero state after the measurement."""
return self.hyperparameters["reset"]

@property
def postselect(self) -> int | None:
"""Which basis state to postselect after a mid-circuit measurement."""
return self.hyperparameters["postselect"]

# pylint: disable=arguments-renamed, arguments-differ
@classmethod
def _primitive_bind_call(cls, wires=None, reset=False, postselect=None, id=None):
wires = () if wires is None else wires
return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect, id=id)
def _primitive_bind_call(cls, *args, **kwargs):
return type.__call__(cls, *args, **kwargs)

@classmethod
def _abstract_eval(
cls,
n_wires: int | None = None,
has_eigvals=False,
shots: int | None = None,
num_device_wires: int = 0,
) -> tuple:
return (), int
@staticmethod
def compute_diagonalizing_gates(*params, wires, **hyperparams) -> list[Operator]:
return []

def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument
r"""How the mid-circuit measurement is represented in diagrams and drawings.
Expand All @@ -191,14 +190,6 @@ def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=

return _label

@property
def samples_computational_basis(self):
return False

@property
def _queue_category(self):
return "_ops"

@property
def hash(self):
"""int: Returns an integer hash uniquely representing the measurement process"""
Expand All @@ -210,21 +201,6 @@ def hash(self):

return hash(fingerprint)

@property
def data(self):
"""The data of the measurement. Needed to match the Operator API."""
return []

@property
def name(self):
"""The name of the measurement. Needed to match the Operator API."""
return self.__class__.__name__

@property
def num_params(self):
"""The number of parameters. Needed to match the Operator API."""
return 0


def measure(
wires: Hashable | Wires, reset: bool = False, postselect: int | None = None
Expand Down
4 changes: 3 additions & 1 deletion pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def assert_valid(
skip_new_decomp=False,
skip_pickle=False,
skip_wire_mapping=False,
skip_capture=False,
heuristic_resources=False,
) -> None:
"""Runs basic validation checks on an :class:`~.operation.Operator` to make
Expand Down Expand Up @@ -528,4 +529,5 @@ def __init__(self, wires):
_check_generator(op)
if not skip_differentiation:
_check_differentiation(op)
_check_capture(op)
if not skip_capture:
_check_capture(op)
27 changes: 0 additions & 27 deletions tests/capture/test_measurements_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ClassicalShadowMP,
DensityMatrixMP,
ExpectationMP,
MidMeasureMP,
MutualInfoMP,
ProbabilityMP,
PurityMP,
Expand Down Expand Up @@ -146,7 +145,6 @@ def f():
lambda: qml.purity(wires=(0, 1)),
lambda: qml.mutual_info(wires0=(1, 3), wires1=(2, 4), log_base=2),
lambda: qml.classical_shadow(wires=(0, 1), seed=84),
lambda: MidMeasureMP(qml.wires.Wires((0, 1))),
]


Expand All @@ -162,31 +160,6 @@ def test_capture_and_eval(func):
qml.assert_equal(mp, out)


def test_mid_measure():
"""Test that mid circuit measurements can be captured and executed."""

def f(w):
return MidMeasureMP(qml.wires.Wires((w,)), reset=True, postselect=1)

jaxpr = jax.make_jaxpr(f)(2)

assert len(jaxpr.eqns) == 1
assert jaxpr.eqns[0].primitive == MidMeasureMP._wires_primitive
assert jaxpr.eqns[0].params == {"reset": True, "postselect": 1, "id": None}
mp = jaxpr.eqns[0].outvars[0].aval
assert isinstance(mp, AbstractMeasurement)
assert mp.n_wires == 1
assert mp._abstract_eval == MidMeasureMP._abstract_eval

shapes = _get_shapes_for(*jaxpr.out_avals, shots=qml.measurements.Shots(1))
assert shapes[0] == jax.core.ShapedArray(
(), jax.numpy.int64 if jax.config.jax_enable_x64 else jax.numpy.int32
)

mp = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)[0]
assert mp == f(1)


@pytest.mark.parametrize("state_wires, shape", [(None, 16), (qml.wires.Wires((0, 1, 2, 3, 4)), 32)])
def test_state(state_wires, shape):
"""Test the capture of a state measurement."""
Expand Down
22 changes: 7 additions & 15 deletions tests/drawer/test_drawable_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,12 @@ def teleport(state):
assert layers == [ops[:2]] + [[op] for op in ops[2:]]


class TestMidMeasure:
"""Tests the various changes from mid-circuit measurements."""
def test_basic_mid_measure():
"""Tests a simple case with mid-circuit measurement."""
with AnnotatedQueue() as q:
m0 = qml.measure(0)
qml.cond(m0, qml.PauliX)(1)

def test_basic_mid_measure(self):
"""Tests a simple case with mid-circuit measurement."""
with AnnotatedQueue() as q:
m0 = qml.measure(0)
qml.cond(m0, qml.PauliX)(1)
bit_map = {q.queue[0]: None}

bit_map = {q.queue[0]: None}

assert drawable_layers(q.queue, bit_map=bit_map) == [[q.queue[0]], [q.queue[1]]]

def test_cannot_draw_multi_wire_MidMeasureMP(self):
"""Tests that MidMeasureMP is only supported with one wire."""
with pytest.raises(ValueError, match="mid-circuit measurements with more than one wire."):
drawable_layers([MidMeasureMP([0, 1])])
assert drawable_layers(q.queue, bit_map=bit_map) == [[q.queue[0]], [q.queue[1]]]
8 changes: 4 additions & 4 deletions tests/measurements/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ def test_measurement_value_eigvals(self):
are correct if the internal observable is a
MeasurementValue."""
m0 = qml.measure(0)
m0.measurements[0].id = "abc"
m0.measurements[0]._id = "abc" # pylint: disable=protected-access
m1 = qml.measure(1)
m1.measurements[0].id = "def"
m1.measurements[0]._id = "def" # pylint: disable=protected-access

mp1 = qml.sample(op=[m0, m1])
assert np.all(mp1.eigvals() == [0, 1, 2, 3])
Expand Down Expand Up @@ -341,8 +341,8 @@ def test_measurement_value_map_wires(self):
m1 = qml.measure("b")
m2 = qml.measure(0)
m3 = qml.measure(1)
m2.measurements[0].id = m0.measurements[0].id
m3.measurements[0].id = m1.measurements[0].id
m2.measurements[0]._id = m0.measurements[0].id # pylint: disable=protected-access
m3.measurements[0]._id = m1.measurements[0].id # pylint: disable=protected-access

wire_map = {"a": 0, "b": 1}

Expand Down
6 changes: 0 additions & 6 deletions tests/measurements/test_mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
# pylint: disable=too-few-public-methods, too-many-public-methods


def test_samples_computational_basis():
"""Test that samples_computational_basis is always false for mid circuit measurements."""
m = MidMeasureMP(Wires(0))
assert not m.samples_computational_basis


class TestMeasure:
"""Tests for the measure function"""

Expand Down
1 change: 1 addition & 0 deletions tests/ops/functions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _trotterize_qfunc_dummy(time, theta, phi, wires, flip=False):


_INSTANCES_TO_TEST = [
(qml.measurements.MidMeasureMP(wires=0), {"skip_capture": True}),
(ChangeOpBasis(qml.PauliX(0), qml.PauliZ(0)), {}),
(qml.sum(qml.PauliX(0), qml.PauliZ(0)), {}),
(qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)), {}),
Expand Down
Loading