Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
89a90e0
E.C>
PietropaoloFrisoni Sep 3, 2025
b099c5f
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 3, 2025
4d25cff
Testing rendering doc and tests failures
PietropaoloFrisoni Sep 4, 2025
62013b8
Documentation
PietropaoloFrisoni Sep 4, 2025
ea354aa
Simple test for all interfaces
PietropaoloFrisoni Sep 4, 2025
2e7d78b
Combining tests into one
PietropaoloFrisoni Sep 4, 2025
f930a90
Extending argument to `process_sample`
PietropaoloFrisoni Sep 4, 2025
e33372b
Changelog
PietropaoloFrisoni Sep 4, 2025
5514f46
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Sep 4, 2025
1055050
Chaning name of parameter and remove duplicated changelog number
PietropaoloFrisoni Sep 5, 2025
b395c90
Removing type hints
PietropaoloFrisoni Sep 5, 2025
c874bde
Removing type hints
PietropaoloFrisoni Sep 5, 2025
79bf5ff
Test for jittability
PietropaoloFrisoni Sep 6, 2025
2041295
jittability test improved
PietropaoloFrisoni Sep 6, 2025
6e27922
Merge branch 'master' into precision_optional_parameter
PietropaoloFrisoni Sep 6, 2025
43402ac
Merge branch 'master' into precision_optional_parameter
PietropaoloFrisoni Sep 8, 2025
8da2a0e
Suggestions from code review
PietropaoloFrisoni Sep 8, 2025
62c9606
Specifying in the doc that MCMs are not supported with `dtype` argument
PietropaoloFrisoni Sep 9, 2025
c8ddaea
Merge branch 'master' into precision_optional_parameter
PietropaoloFrisoni Sep 9, 2025
7d20573
Update pennylane/measurements/sample.py
PietropaoloFrisoni Sep 11, 2025
7beebdf
Merge branch 'master' into precision_optional_parameter
PietropaoloFrisoni Sep 11, 2025
ab37fe5
Merge conflict
PietropaoloFrisoni Sep 11, 2025
001b10a
Removing `dtype` from `process_sample`
PietropaoloFrisoni Sep 11, 2025
d14f2f1
Merge branch 'master' into precision_optional_parameter
PietropaoloFrisoni Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

<h3>New features since last release</h3>

* The `qml.sample` function can now receive an optional `dtype` parameter
which sets the type and precision of the samples returned by this measurement process.
[(#8189)](https://github.com/PennyLaneAI/pennylane/pull/8189)

* Dynamic wire allocation with `qml.allocation.allocate` can now be executed on `default.qubit`.
[(#7718)](https://github.com/PennyLaneAI/pennylane/pull/7718)

Expand Down Expand Up @@ -111,7 +115,7 @@
<h3>Improvements 🛠</h3>

* `allocate` and `deallocate` can now be accessed as `qml.allocate` and `qml.deallocate`.
[(#8189)](https://github.com/PennyLaneAI/pennylane/pull/8198))
[(#8198)](https://github.com/PennyLaneAI/pennylane/pull/8198)

* `allocate` now takes `state: Literal["zero", "any"] = "zero"` instead of `require_zeros=True`.
[(#8163)](https://github.com/PennyLaneAI/pennylane/pull/8163)
Expand Down
17 changes: 13 additions & 4 deletions pennylane/measurements/process_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,40 @@

from pennylane import math
from pennylane.operation import EigvalsUndefinedError
from pennylane.typing import TensorLike
from pennylane.typing import Sequence, TensorLike
from pennylane.wires import WiresLike

from .measurement_value import MeasurementValue
from .measurements import MeasurementProcess


# pylint: disable=too-many-arguments
def process_raw_samples(
mp: MeasurementProcess, samples: TensorLike, wire_order, shot_range, bin_size
):
mp: MeasurementProcess,
samples: TensorLike,
wire_order: WiresLike,
shot_range: Sequence[int],
bin_size: int,
dtype=None,
) -> TensorLike:
"""Slice the samples for a measurement process.
Args:
mp (MeasurementProcess): the measurement process containing the wires, observable, and mcms for the processing
samples (TensorLike): the raw samples
wire_order: the wire order for the raw samples
wire_order (WiresLike): the wire order for the raw samples
shot_range (tuple[int]): 2-tuple of integers specifying the range of samples
to use. If not specified, all samples are used.
bin_size (int): Divides the shot range into bins of size ``bin_size``, and
returns the measurement statistic separately over each bin. If not
provided, the entire shot range is treated as a single bin.
dtype: The dtype of the samples returned by this measurement process.
This function matches `SampleMP.process_samples`, but does not have a dependence on the measurement process.
"""

samples = samples.astype(dtype) if dtype is not None else samples
wire_map = dict(zip(wire_order, range(len(wire_order))))
mapped_wires = [wire_map[w] for w in mp.wires]
# Select the samples from samples that correspond to ``shot_range`` if provided
Expand Down
97 changes: 81 additions & 16 deletions pennylane/measurements/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from pennylane.exceptions import MeasurementShapeError, QuantumFunctionError
from pennylane.operation import Operator
from pennylane.queuing import QueuingManager
from pennylane.wires import Wires
from pennylane.typing import TensorLike
from pennylane.wires import Wires, WiresLike

from .counts import CountsMP
from .measurements import SampleMeasurement
Expand All @@ -46,11 +47,15 @@ class SampleMP(SampleMeasurement):
This can only be specified if an observable was not provided.
id (str): custom label given to a measurement instance, can be useful for some applications
where the instance has to be identified
dtype (str or None): The dtype of the samples returned by this measurement process.
"""

_shortname = "sample"

def __init__(self, obs=None, wires=None, eigvals=None, id=None):
# pylint: disable=too-many-arguments
def __init__(self, obs=None, wires=None, eigvals=None, id=None, dtype=None):

self._dtype = dtype

if isinstance(obs, MeasurementValue):
super().__init__(obs=obs)
Expand Down Expand Up @@ -86,7 +91,7 @@ def _abstract_eval(
has_eigvals=False,
shots: int | None = None,
num_device_wires: int = 0,
):
) -> tuple[tuple[int, ...], type]:
if shots is None:
raise ValueError("finite shots are required to use SampleMP")
sample_eigvals = n_wires is None or has_eigvals
Expand All @@ -99,6 +104,8 @@ def _abstract_eval(

@property
def numeric_type(self):
if self._dtype is not None:
return self._dtype
if self.obs is None:
# Computational basis samples
return int
Expand All @@ -123,15 +130,22 @@ def shape(self, shots: int | None = None, num_device_wires: int = 0) -> tuple:
def process_samples(
self,
samples: Sequence[complex],
wire_order: Wires,
wire_order: WiresLike,
shot_range: None | tuple[int, ...] = None,
bin_size: None | int = None,
):
dtype=None,
) -> TensorLike:

return process_raw_samples(
self, samples, wire_order, shot_range=shot_range, bin_size=bin_size
self,
samples,
wire_order,
shot_range=shot_range,
bin_size=bin_size,
dtype=self._dtype if dtype is None else dtype,
)

def process_counts(self, counts: dict, wire_order: Wires):
def process_counts(self, counts: dict, wire_order: WiresLike) -> np.ndarray:
samples = []
mapped_counts = self._map_counts(counts, wire_order)
for outcome, count in mapped_counts.items():
Expand All @@ -143,11 +157,11 @@ def process_counts(self, counts: dict, wire_order: Wires):

return np.array(samples)

def _map_counts(self, counts_to_map, wire_order) -> dict:
def _map_counts(self, counts_to_map: dict, wire_order: WiresLike) -> dict:
"""
Args:
counts_to_map: Dictionary where key is binary representation of the outcome and value is its count
wire_order: Order of wires to which counts_to_map should be ordered in
counts_to_map (dict): Dictionary where key is binary representation of the outcome and value is its count
wire_order (WiresLike): Order of wires to which counts_to_map should be ordered in

Returns:
Dictionary where counts_to_map has been reordered according to wire_order
Expand All @@ -156,7 +170,7 @@ def _map_counts(self, counts_to_map, wire_order) -> dict:
helper_counts = CountsMP(wires=self.wires, all_outcomes=False)
return helper_counts.process_counts(counts_to_map, wire_order)

def _compute_outcome_sample(self, outcome) -> list:
def _compute_outcome_sample(self, outcome: str) -> list:
"""
Args:
outcome (str): The binary string representation of the measurement outcome.
Expand All @@ -174,11 +188,12 @@ def _compute_outcome_sample(self, outcome) -> list:

def sample(
op: Operator | MeasurementValue | Sequence[MeasurementValue] | None = None,
wires=None,
wires: WiresLike = None,
dtype=None,
) -> SampleMP:
r"""Sample from the supplied observable, with the number of shots
determined from QNode,
returning raw samples. If no observable is provided then basis state samples are returned
returning raw samples. If no observable is provided, then basis state samples are returned
directly from the device.

Note that the output shape of this measurement process depends on the shots
Expand All @@ -189,6 +204,7 @@ def sample(
for mid-circuit measurements, ``op`` should be a ``MeasurementValue``.
wires (Sequence[int] or int or None): the wires we wish to sample from; ONLY set wires if
op is ``None``.
dtype: The dtype of the samples returned by this measurement process.

Returns:
SampleMP: Measurement process instance
Expand Down Expand Up @@ -292,8 +308,8 @@ def circuit(x):
array([ 1., 1., 1., -1.])

If no observable is provided, then the raw basis state samples obtained
from device are returned (e.g., for a qubit device, samples from the
computational device are returned). In this case, ``wires`` can be specified
from the device are returned (e.g., for a qubit device, samples from the
computational basis are returned). In this case, ``wires`` can be specified
so that sample results only include measurement results of the qubits of interest.

.. code-block:: python3
Expand All @@ -317,5 +333,54 @@ def circuit(x):
[1, 1],
[0, 0]])

.. details::
:title: Setting the precision of the samples

The ``dtype`` argument can be used to set the type and precision of the samples returned by this measurement process
when the ``op`` argument does not contain mid-circuit measurements. Otherwise, the ``dtype`` argument is ignored.

By default, the samples will be returned as floating point numbers if an observable is provided,
and as integers if no observable is provided. The ``dtype`` argument can be used to override this default behavior,
and set the precision to any valid interface-like dtype, e.g. ``'float32'``, ``'int8'``, ``'uint16'``, etc.

We show two examples below using the JAX and PyTorch interfaces.
This argument is compatible with all interfaces currently supported by PennyLane.

**Example:**

.. code-block:: python3

@qml.set_shots(1000000)
@qml.qnode(qml.device("default.qubit", wires=1), interface="jax")
def circuit():
qml.Hadamard(0)
return qml.sample(dtype="int8")

Executing this QNode, we get:

>>> samples = circuit()
>>> samples.dtype
dtype('int8')
>>> type(samples)
jaxlib._jax.ArrayImpl

If an observable is provided, the samples will be floating point numbers:

.. code-block:: python3

@qml.set_shots(1000000)
@qml.qnode(qml.device("default.qubit", wires=1), interface="torch")
def circuit():
qml.Hadamard(0)
return qml.sample(qml.Z(0), dtype="float32")

Executing this QNode, we get:

>>> samples = circuit()
>>> samples.dtype
torch.float32
>>> type(samples)
torch.Tensor

"""
return SampleMP(obs=op, wires=None if wires is None else Wires(wires))
return SampleMP(obs=op, wires=None if wires is None else Wires(wires), dtype=dtype)
63 changes: 63 additions & 0 deletions tests/measurements/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,19 @@ def test_new_sample_with_operator_with_no_eigvals(self):
"""Test that calling process with an operator that has no eigvals defined raises an error."""

class DummyOp(Operator): # pylint: disable=too-few-public-methods
"""Dummy operator with no eigenvalues defined."""

num_wires = 1

with pytest.raises(EigvalsUndefinedError, match="Cannot compute samples of"):
qml.sample(op=DummyOp(0)).process_samples(samples=np.array([[1, 0]]), wire_order=[0])

def test_process_samples_dtype(self):
"""Test that the dtype argument changes the dtype of the returned samples."""
samples = np.zeros(10, dtype="int64")
processed_samples = qml.sample().process_samples(samples, wire_order=[0], dtype="int8")
assert processed_samples.dtype == np.dtype("int8")

def test_sample_allowed_with_parameter_shift(self):
"""Test that qml.sample doesn't raise an error with parameter-shift and autograd."""
dev = qml.device("default.qubit")
Expand Down Expand Up @@ -437,9 +445,38 @@ def circuit(angle):
angle = jax.numpy.array(0.1)
_ = jax.jacobian(circuit)(angle)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["autograd", "torch", "jax"])
@pytest.mark.parametrize(
"dtype, obs",
[
("int8", None),
("int16", None),
("int32", None),
("int64", None),
("float16", qml.Z(0)),
("float32", qml.Z(0)),
("float64", qml.Z(0)),
],
)
def test_sample_dtype_combined(self, interface, dtype, obs):
"""Test that the dtype argument changes the dtype of the returned samples,
both with and without an observable."""

@qml.set_shots(10)
@qml.qnode(device=qml.device("default.qubit", wires=1), interface=interface)
def circuit():
qml.Hadamard(wires=0)
return qml.sample(obs, dtype=dtype) if obs is not None else qml.sample(dtype=dtype)

samples = circuit()
assert qml.math.get_interface(samples) == interface
assert qml.math.get_dtype_name(samples) == dtype


@pytest.mark.jax
class TestJAXCompatibility:
"""Tests for JAX compatibility"""

@pytest.mark.parametrize("samples", (1, 10))
def test_jitting_with_sampling_on_subset_of_wires(self, samples):
Expand Down Expand Up @@ -510,6 +547,32 @@ def circuit(x):
assert results.dtype == jax.numpy.float64
assert np.all([r in [1, -1] for r in results])

@pytest.mark.parametrize(
"dtype, obs",
[
("int8", None),
("int16", None),
("int32", None),
("int64", None),
("float16", qml.Z(0)),
("float32", qml.Z(0)),
("float64", qml.Z(0)),
],
)
def test_jitting_with_dtype(self, dtype, obs):
"""Test that jitting works when the dtype argument is provided"""
import jax

@qml.set_shots(10)
@qml.qnode(device=qml.device("default.qubit", wires=1), interface="jax")
def circuit(x):
qml.RX(x, wires=0)
return qml.sample(obs, dtype=dtype) if obs is not None else qml.sample(dtype=dtype)

samples = jax.jit(circuit)(jax.numpy.array(0.123))
assert qml.math.get_interface(samples) == "jax"
assert qml.math.get_dtype_name(samples) == dtype

def test_process_samples_with_jax_tracer(self):
"""Test that qml.sample can be used when samples is a JAX Tracer"""

Expand Down