Skip to content

Commit 73a4e78

Browse files
SampleMP include a dtype optional parameter (#8189)
**Context:** This PR adds the `dtype` optional parameter for the `SampleMP` measurement process. **Description of the Change:** As above. **Benefits:** More flexibility and very small (mostly symbolic) speed-up improvement **Possible Drawbacks:** None that I can think of. **Related GitHub Issues:** None. [sc-98358]
1 parent 9dc776e commit 73a4e78

File tree

4 files changed

+156
-21
lines changed

4 files changed

+156
-21
lines changed

doc/releases/changelog-dev.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
<h3>New features since last release</h3>
55

6+
* The `qml.sample` function can now receive an optional `dtype` parameter
7+
which sets the type and precision of the samples returned by this measurement process.
8+
[(#8189)](https://github.com/PennyLaneAI/pennylane/pull/8189)
9+
610
* The Resource estimation toolkit was upgraded and has migrated from
711
:mod:`~.labs` to PennyLane as the :mod:`~.estimator` module.
812

@@ -139,7 +143,7 @@
139143
[(#8229)](https://github.com/PennyLaneAI/pennylane/pull/8229)
140144

141145
* `allocate` and `deallocate` can now be accessed as `qml.allocate` and `qml.deallocate`.
142-
[(#8189)](https://github.com/PennyLaneAI/pennylane/pull/8198)
146+
[(#8198)](https://github.com/PennyLaneAI/pennylane/pull/8198)
143147

144148
* `allocate` now takes `state: Literal["zero", "any"] = "zero"` instead of `require_zeros=True`.
145149
[(#8163)](https://github.com/PennyLaneAI/pennylane/pull/8163)

pennylane/measurements/process_samples.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,40 @@
1717

1818
from pennylane import math
1919
from pennylane.operation import EigvalsUndefinedError
20-
from pennylane.typing import TensorLike
20+
from pennylane.typing import Sequence, TensorLike
21+
from pennylane.wires import WiresLike
2122

2223
from .measurement_value import MeasurementValue
2324
from .measurements import MeasurementProcess
2425

2526

27+
# pylint: disable=too-many-arguments
2628
def process_raw_samples(
27-
mp: MeasurementProcess, samples: TensorLike, wire_order, shot_range, bin_size
28-
):
29+
mp: MeasurementProcess,
30+
samples: TensorLike,
31+
wire_order: WiresLike,
32+
shot_range: Sequence[int],
33+
bin_size: int,
34+
dtype=None,
35+
) -> TensorLike:
2936
"""Slice the samples for a measurement process.
3037
3138
Args:
3239
mp (MeasurementProcess): the measurement process containing the wires, observable, and mcms for the processing
3340
samples (TensorLike): the raw samples
34-
wire_order: the wire order for the raw samples
41+
wire_order (WiresLike): the wire order for the raw samples
3542
shot_range (tuple[int]): 2-tuple of integers specifying the range of samples
3643
to use. If not specified, all samples are used.
3744
bin_size (int): Divides the shot range into bins of size ``bin_size``, and
3845
returns the measurement statistic separately over each bin. If not
3946
provided, the entire shot range is treated as a single bin.
47+
dtype: The dtype of the samples returned by this measurement process.
4048
4149
This function matches `SampleMP.process_samples`, but does not have a dependence on the measurement process.
4250
4351
"""
4452

53+
samples = samples.astype(dtype) if dtype is not None else samples
4554
wire_map = dict(zip(wire_order, range(len(wire_order))))
4655
mapped_wires = [wire_map[w] for w in mp.wires]
4756
# Select the samples from samples that correspond to ``shot_range`` if provided

pennylane/measurements/sample.py

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from pennylane.exceptions import MeasurementShapeError, QuantumFunctionError
2323
from pennylane.operation import Operator
2424
from pennylane.queuing import QueuingManager
25-
from pennylane.wires import Wires
25+
from pennylane.typing import TensorLike
26+
from pennylane.wires import Wires, WiresLike
2627

2728
from .counts import CountsMP
2829
from .measurements import SampleMeasurement
@@ -46,11 +47,15 @@ class SampleMP(SampleMeasurement):
4647
This can only be specified if an observable was not provided.
4748
id (str): custom label given to a measurement instance, can be useful for some applications
4849
where the instance has to be identified
50+
dtype (str or None): The dtype of the samples returned by this measurement process.
4951
"""
5052

5153
_shortname = "sample"
5254

53-
def __init__(self, obs=None, wires=None, eigvals=None, id=None):
55+
# pylint: disable=too-many-arguments
56+
def __init__(self, obs=None, wires=None, eigvals=None, id=None, dtype=None):
57+
58+
self._dtype = dtype
5459

5560
if isinstance(obs, MeasurementValue):
5661
super().__init__(obs=obs)
@@ -86,7 +91,7 @@ def _abstract_eval(
8691
has_eigvals=False,
8792
shots: int | None = None,
8893
num_device_wires: int = 0,
89-
):
94+
) -> tuple[tuple[int, ...], type]:
9095
if shots is None:
9196
raise ValueError("finite shots are required to use SampleMP")
9297
sample_eigvals = n_wires is None or has_eigvals
@@ -99,6 +104,8 @@ def _abstract_eval(
99104

100105
@property
101106
def numeric_type(self):
107+
if self._dtype is not None:
108+
return self._dtype
102109
if self.obs is None:
103110
# Computational basis samples
104111
return int
@@ -123,15 +130,16 @@ def shape(self, shots: int | None = None, num_device_wires: int = 0) -> tuple:
123130
def process_samples(
124131
self,
125132
samples: Sequence[complex],
126-
wire_order: Wires,
133+
wire_order: WiresLike,
127134
shot_range: None | tuple[int, ...] = None,
128135
bin_size: None | int = None,
129-
):
136+
) -> TensorLike:
137+
130138
return process_raw_samples(
131-
self, samples, wire_order, shot_range=shot_range, bin_size=bin_size
139+
self, samples, wire_order, shot_range=shot_range, bin_size=bin_size, dtype=self._dtype
132140
)
133141

134-
def process_counts(self, counts: dict, wire_order: Wires):
142+
def process_counts(self, counts: dict, wire_order: WiresLike) -> np.ndarray:
135143
samples = []
136144
mapped_counts = self._map_counts(counts, wire_order)
137145
for outcome, count in mapped_counts.items():
@@ -143,11 +151,11 @@ def process_counts(self, counts: dict, wire_order: Wires):
143151

144152
return np.array(samples)
145153

146-
def _map_counts(self, counts_to_map, wire_order) -> dict:
154+
def _map_counts(self, counts_to_map: dict, wire_order: WiresLike) -> dict:
147155
"""
148156
Args:
149-
counts_to_map: Dictionary where key is binary representation of the outcome and value is its count
150-
wire_order: Order of wires to which counts_to_map should be ordered in
157+
counts_to_map (dict): Dictionary where key is binary representation of the outcome and value is its count
158+
wire_order (WiresLike): Order of wires to which counts_to_map should be ordered in
151159
152160
Returns:
153161
Dictionary where counts_to_map has been reordered according to wire_order
@@ -156,7 +164,7 @@ def _map_counts(self, counts_to_map, wire_order) -> dict:
156164
helper_counts = CountsMP(wires=self.wires, all_outcomes=False)
157165
return helper_counts.process_counts(counts_to_map, wire_order)
158166

159-
def _compute_outcome_sample(self, outcome) -> list:
167+
def _compute_outcome_sample(self, outcome: str) -> list:
160168
"""
161169
Args:
162170
outcome (str): The binary string representation of the measurement outcome.
@@ -174,11 +182,12 @@ def _compute_outcome_sample(self, outcome) -> list:
174182

175183
def sample(
176184
op: Operator | MeasurementValue | Sequence[MeasurementValue] | None = None,
177-
wires=None,
185+
wires: WiresLike = None,
186+
dtype=None,
178187
) -> SampleMP:
179188
r"""Sample from the supplied observable, with the number of shots
180189
determined from QNode,
181-
returning raw samples. If no observable is provided then basis state samples are returned
190+
returning raw samples. If no observable is provided, then basis state samples are returned
182191
directly from the device.
183192
184193
Note that the output shape of this measurement process depends on the shots
@@ -189,6 +198,7 @@ def sample(
189198
for mid-circuit measurements, ``op`` should be a ``MeasurementValue``.
190199
wires (Sequence[int] or int or None): the wires we wish to sample from; ONLY set wires if
191200
op is ``None``.
201+
dtype: The dtype of the samples returned by this measurement process.
192202
193203
Returns:
194204
SampleMP: Measurement process instance
@@ -292,8 +302,8 @@ def circuit(x):
292302
array([ 1., 1., 1., -1.])
293303
294304
If no observable is provided, then the raw basis state samples obtained
295-
from device are returned (e.g., for a qubit device, samples from the
296-
computational device are returned). In this case, ``wires`` can be specified
305+
from the device are returned (e.g., for a qubit device, samples from the
306+
computational basis are returned). In this case, ``wires`` can be specified
297307
so that sample results only include measurement results of the qubits of interest.
298308
299309
.. code-block:: python3
@@ -317,5 +327,54 @@ def circuit(x):
317327
[1, 1],
318328
[0, 0]])
319329
330+
.. details::
331+
:title: Setting the precision of the samples
332+
333+
The ``dtype`` argument can be used to set the type and precision of the samples returned by this measurement process
334+
when the ``op`` argument does not contain mid-circuit measurements. Otherwise, the ``dtype`` argument is ignored.
335+
336+
By default, the samples will be returned as floating point numbers if an observable is provided,
337+
and as integers if no observable is provided. The ``dtype`` argument can be used to specify further details,
338+
and set the precision to any valid interface-like dtype, e.g. ``'float32'``, ``'int8'``, ``'uint16'``, etc.
339+
340+
We show two examples below using the JAX and PyTorch interfaces.
341+
This argument is compatible with all interfaces currently supported by PennyLane.
342+
343+
**Example:**
344+
345+
.. code-block:: python3
346+
347+
@qml.set_shots(1000000)
348+
@qml.qnode(qml.device("default.qubit", wires=1), interface="jax")
349+
def circuit():
350+
qml.Hadamard(0)
351+
return qml.sample(dtype="int8")
352+
353+
Executing this QNode, we get:
354+
355+
>>> samples = circuit()
356+
>>> samples.dtype
357+
dtype('int8')
358+
>>> type(samples)
359+
jaxlib._jax.ArrayImpl
360+
361+
If an observable is provided, the samples will be floating point numbers:
362+
363+
.. code-block:: python3
364+
365+
@qml.set_shots(1000000)
366+
@qml.qnode(qml.device("default.qubit", wires=1), interface="torch")
367+
def circuit():
368+
qml.Hadamard(0)
369+
return qml.sample(qml.Z(0), dtype="float32")
370+
371+
Executing this QNode, we get:
372+
373+
>>> samples = circuit()
374+
>>> samples.dtype
375+
torch.float32
376+
>>> type(samples)
377+
torch.Tensor
378+
320379
"""
321-
return SampleMP(obs=op, wires=None if wires is None else Wires(wires))
380+
return SampleMP(obs=op, wires=None if wires is None else Wires(wires), dtype=dtype)

tests/measurements/test_sample.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,19 @@ def test_new_sample_with_operator_with_no_eigvals(self):
401401
"""Test that calling process with an operator that has no eigvals defined raises an error."""
402402

403403
class DummyOp(Operator): # pylint: disable=too-few-public-methods
404+
"""Dummy operator with no eigenvalues defined."""
405+
404406
num_wires = 1
405407

406408
with pytest.raises(EigvalsUndefinedError, match="Cannot compute samples of"):
407409
qml.sample(op=DummyOp(0)).process_samples(samples=np.array([[1, 0]]), wire_order=[0])
408410

411+
def test_process_samples_dtype(self):
412+
"""Test that the dtype argument changes the dtype of the returned samples."""
413+
samples = np.zeros(10, dtype="int64")
414+
processed_samples = qml.sample(dtype="int8").process_samples(samples, wire_order=[0])
415+
assert processed_samples.dtype == np.dtype("int8")
416+
409417
def test_sample_allowed_with_parameter_shift(self):
410418
"""Test that qml.sample doesn't raise an error with parameter-shift and autograd."""
411419
dev = qml.device("default.qubit")
@@ -437,9 +445,38 @@ def circuit(angle):
437445
angle = jax.numpy.array(0.1)
438446
_ = jax.jacobian(circuit)(angle)
439447

448+
@pytest.mark.all_interfaces
449+
@pytest.mark.parametrize("interface", ["autograd", "torch", "jax"])
450+
@pytest.mark.parametrize(
451+
"dtype, obs",
452+
[
453+
("int8", None),
454+
("int16", None),
455+
("int32", None),
456+
("int64", None),
457+
("float16", qml.Z(0)),
458+
("float32", qml.Z(0)),
459+
("float64", qml.Z(0)),
460+
],
461+
)
462+
def test_sample_dtype_combined(self, interface, dtype, obs):
463+
"""Test that the dtype argument changes the dtype of the returned samples,
464+
both with and without an observable."""
465+
466+
@qml.set_shots(10)
467+
@qml.qnode(device=qml.device("default.qubit", wires=1), interface=interface)
468+
def circuit():
469+
qml.Hadamard(wires=0)
470+
return qml.sample(obs, dtype=dtype) if obs is not None else qml.sample(dtype=dtype)
471+
472+
samples = circuit()
473+
assert qml.math.get_interface(samples) == interface
474+
assert qml.math.get_dtype_name(samples) == dtype
475+
440476

441477
@pytest.mark.jax
442478
class TestJAXCompatibility:
479+
"""Tests for JAX compatibility"""
443480

444481
@pytest.mark.parametrize("samples", (1, 10))
445482
def test_jitting_with_sampling_on_subset_of_wires(self, samples):
@@ -510,6 +547,32 @@ def circuit(x):
510547
assert results.dtype == jax.numpy.float64
511548
assert np.all([r in [1, -1] for r in results])
512549

550+
@pytest.mark.parametrize(
551+
"dtype, obs",
552+
[
553+
("int8", None),
554+
("int16", None),
555+
("int32", None),
556+
("int64", None),
557+
("float16", qml.Z(0)),
558+
("float32", qml.Z(0)),
559+
("float64", qml.Z(0)),
560+
],
561+
)
562+
def test_jitting_with_dtype(self, dtype, obs):
563+
"""Test that jitting works when the dtype argument is provided"""
564+
import jax
565+
566+
@qml.set_shots(10)
567+
@qml.qnode(device=qml.device("default.qubit", wires=1), interface="jax")
568+
def circuit(x):
569+
qml.RX(x, wires=0)
570+
return qml.sample(obs, dtype=dtype) if obs is not None else qml.sample(dtype=dtype)
571+
572+
samples = jax.jit(circuit)(jax.numpy.array(0.123))
573+
assert qml.math.get_interface(samples) == "jax"
574+
assert qml.math.get_dtype_name(samples) == dtype
575+
513576
def test_process_samples_with_jax_tracer(self):
514577
"""Test that qml.sample can be used when samples is a JAX Tracer"""
515578

0 commit comments

Comments
 (0)