Skip to content
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c3b4d1f
PLDW in kwargs
JerryChen97 Jul 17, 2025
10ab7d5
log?
JerryChen97 Jul 17, 2025
34a8fea
rst
JerryChen97 Jul 17, 2025
b61fe37
mf
JerryChen97 Jul 18, 2025
434353a
rm the PLDW in init to allow user qfunc definition using shots=
JerryChen97 Jul 18, 2025
9fd697b
adjust tests
JerryChen97 Jul 18, 2025
d4a71e0
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 18, 2025
70a01cd
Update pennylane/workflow/qnode.py
JerryChen97 Jul 21, 2025
2cd7ce6
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 21, 2025
ee96313
correct the deprecation logic
JerryChen97 Jul 21, 2025
cdde1eb
fix test_default_qubit.py
JerryChen97 Jul 21, 2025
5c5506e
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 21, 2025
d0c159d
debug
JerryChen97 Jul 21, 2025
55fbdd0
dynamic shots available now
JerryChen97 Jul 21, 2025
813af2d
debug three files
JerryChen97 Jul 21, 2025
95668ff
wish no tensorflow in paradise
JerryChen97 Jul 21, 2025
4290b81
fix jax qnode
JerryChen97 Jul 21, 2025
f86037e
some leftover
JerryChen97 Jul 21, 2025
2a9a3f2
fix qcut
JerryChen97 Jul 21, 2025
0100ba4
debugging test debugging
JerryChen97 Jul 21, 2025
b352122
qcut leftover
JerryChen97 Jul 21, 2025
0cd3186
legacy
JerryChen97 Jul 21, 2025
f92c0d2
Fix optimize issue
JerryChen97 Jul 21, 2025
812feb7
fix defer measurement
JerryChen97 Jul 21, 2025
b79d8ac
rm unused partial
JerryChen97 Jul 21, 2025
049f4a1
fix two more
JerryChen97 Jul 21, 2025
8c4d7ce
try import from workflow?
JerryChen97 Jul 21, 2025
dd000f4
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 22, 2025
2f7c206
Update pennylane/workflow/qnode.py
JerryChen97 Jul 22, 2025
5a1c6b9
Apply suggestions from code review
JerryChen97 Jul 22, 2025
841212e
remove a test that's essentially legacy test
JerryChen97 Jul 22, 2025
ee4b276
update msg tsrings
JerryChen97 Jul 22, 2025
bc27fd9
dont' xfail
JerryChen97 Jul 22, 2025
2a164a6
no xfail 2
JerryChen97 Jul 22, 2025
0856455
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 22, 2025
b6a4119
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 22, 2025
10c0aa2
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 23, 2025
25aadaa
Merge branch 'master' into shots-decoupling/deprecate_qnode_call_shots
JerryChen97 Jul 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ deprecations are listed below.
Pending deprecations
--------------------

* ``shots=`` in ``QNode`` calls is deprecated and will be removed in v0.44.
Instead, please use the ``qml.workflow.set_shots`` transform to set the number of shots for a ``QNode``.

- Deprecated in v0.43
- Will be removed in v0.44

* The ``QuantumScript.to_openqasm`` method is deprecated and will be removed in version v0.44.
Instead, the ``qml.to_openqasm`` function should be used.

Expand Down
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@

<h3>Deprecations 👋</h3>

* `shots=` in `QNode` calls is deprecated and will be removed in v0.44.
Instead, please use the `qml.workflow.set_shots` transform to set the number of shots for a QNode.
[(#7906)](https://github.com/PennyLaneAI/pennylane/pull/7906)

* The `QuantumScript.to_openqasm` method is deprecated and will be removed in version v0.44.
Instead, the `qml.to_openqasm` function should be used.
[(#7909)](https://github.com/PennyLaneAI/pennylane/pull/7909)
Expand Down
13 changes: 7 additions & 6 deletions pennylane/optimize/shot_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pennylane.ops import LinearCombination
from pennylane.queuing import apply
from pennylane.tape import make_qscript
from pennylane.workflow import QNode, construct_tape
from pennylane.workflow import QNode, construct_tape, set_shots

from .gradient_descent import GradientDescentOptimizer

Expand Down Expand Up @@ -259,10 +259,11 @@ def func(*qnode_args, **qnode_kwargs):
apply(op)
return measurements.expval(o) # pylint:disable=cell-var-from-loop

qnode.func = func

new_shots = 1 if s == 1 else [(1, int(s))]

qnode.func = func
qnode = set_shots(qnode, shots=new_shots)

if s > 1:

def cost(*args, **kwargs):
Expand All @@ -272,7 +273,7 @@ def cost(*args, **kwargs):
else:
cost = qnode

jacs = jacobian(cost, argnum=argnums)(*args, **kwargs, shots=new_shots)
jacs = jacobian(cost, argnum=argnums)(*args, **kwargs)

if s == 1:
jacs = [np.expand_dims(j, 0) for j in jacs]
Expand Down Expand Up @@ -344,7 +345,7 @@ def _single_shot_qnode_gradients(self, qnode, args, kwargs):
new_shots = [(1, int(self.max_shots))]

def cost(*args, **kwargs):
return math.stack(qnode(*args, **kwargs, shots=new_shots))
return math.stack(set_shots(qnode, shots=new_shots)(*args, **kwargs))

grads = [jacobian(cost, argnum=i)(*args, **kwargs) for i in self.trainable_args]

Expand Down Expand Up @@ -499,5 +500,5 @@ def step_and_cost(self, objective_fn, *args, **kwargs):
If single arg is provided, list [array] is replaced by array.
"""
new_args = self.step(objective_fn, *args, **kwargs)
forward = objective_fn(*args, **kwargs, shots=int(self.max_shots))
forward = set_shots(objective_fn, shots=int(self.max_shots))(*args, **kwargs)
return new_args, forward
17 changes: 12 additions & 5 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,14 +787,21 @@ def _set_shots(self, shots: int | Shots) -> None:
def construct(self, args, kwargs) -> qml.tape.QuantumScript:
"""Call the quantum function with a tape context, ensuring the operations get queued."""
kwargs = copy.copy(kwargs)
if "shots" in kwargs and self._shots_override_device:
_kwargs_shots = kwargs.pop("shots")
if "shots" in kwargs:
# NOTE: at removal, remember to remove the userwarning below as well
warnings.warn(
"Both 'shots=' parameter and 'set_shots' transform are specified. "
f"The transform will take precedence over 'shots={_kwargs_shots}.'",
UserWarning,
"'shots' specified on call to a QNode is deprecated and will be removed in v0.44. Use qml.set_shots instead.",
PennyLaneDeprecationWarning,
stacklevel=2,
)
if self._shots_override_device:
_kwargs_shots = kwargs.pop("shots")
warnings.warn(
"Both 'shots=' parameter and 'set_shots' transform are specified. "
f"The transform will take precedence over 'shots={_kwargs_shots}.'",
UserWarning,
stacklevel=2,
)

if self._qfunc_uses_shots_arg or self._shots_override_device: # QNode._shots precedency:
shots = self._shots
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/set_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ def decorator(qnode_func):
return qnode.update_shots(shots)

# If qnode is not a QNode (including explicit None), raise error
raise ValueError("set_shots can only be applied to QNodes")
raise ValueError(f"set_shots can only be applied to QNodes, not {type(qnode)} provided.")
12 changes: 8 additions & 4 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,7 @@ def test_postselection_valid_finite_shots(self, param, mp, shots, interface, use
dev = qml.device("default.qubit", seed=seed)
param = qml.math.asarray(param, like=interface)

@qml.set_shots(shots=shots)
@qml.defer_measurements
@qml.qnode(dev, interface=interface)
def circ_postselect(theta):
Expand All @@ -1894,6 +1895,7 @@ def circ_postselect(theta):
qml.measure(0, postselect=1)
return qml.apply(mp)

@qml.set_shots(shots=shots)
@qml.defer_measurements
@qml.qnode(dev, interface=interface)
def circ_expected():
Expand All @@ -1907,8 +1909,8 @@ def circ_expected():
pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"])

res = circ_postselect(param, shots=shots)
expected = circ_expected(shots=shots)
res = circ_postselect(param)
expected = circ_expected()

if not isinstance(shots, tuple):
assert qml.math.allclose(res, expected, atol=0.1, rtol=0)
Expand Down Expand Up @@ -1949,6 +1951,7 @@ def test_postselection_valid_finite_shots_varied_shape(

with mock.patch("numpy.random.binomial", lambda *args, **kwargs: 5):

@qml.set_shots(shots=shots)
@qml.defer_measurements
@qml.qnode(dev, interface=interface)
def circ_postselect(theta):
Expand All @@ -1957,7 +1960,7 @@ def circ_postselect(theta):
qml.measure(0, postselect=1)
return qml.apply(mp)

res = circ_postselect(param, shots=shots)
res = circ_postselect(param)

if not isinstance(shots, tuple):
assert qml.math.get_interface(res) == interface if interface != "autograd" else "numpy"
Expand Down Expand Up @@ -2081,6 +2084,7 @@ def test_postselection_invalid_finite_shots(

dev = qml.device("default.qubit")

@qml.set_shots(shots=shots)
@qml.defer_measurements
@qml.qnode(dev, interface=interface)
def circ():
Expand All @@ -2095,7 +2099,7 @@ def circ():
pytest.xfail(reason="'shots' cannot be a static_argname for 'jit' in JAX 0.4.28")
circ = jax.jit(circ, static_argnames=["shots"])

res = circ(shots=shots)
res = circ()

if not isinstance(shots, tuple):
assert qml.math.shape(res) == expected_shape
Expand Down
8 changes: 5 additions & 3 deletions tests/gradients/core/test_gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,9 @@ def test_setting_shots(self):
"""Test that setting the number of shots works correctly for
a gradient transform"""

dev = qml.device("default.qubit", wires=1, shots=1000)
dev = qml.device("default.qubit", wires=1)

@qml.set_shots(shots=1000)
@qml.qnode(dev)
def circuit(x):
qml.RX(x, wires=0)
Expand All @@ -649,11 +650,12 @@ def circuit(x):
# the gradient function can be called with different shot values
grad_fn = qml.gradients.param_shift(circuit)
assert grad_fn(x).shape == ()
assert len(grad_fn(x, shots=[(1, 1000)])) == 1000

assert len(qml.set_shots(shots=[(1, 1000)])(grad_fn)(x)) == 1000

# the original QNode is unaffected
assert circuit(x).shape == tuple()
assert circuit(x, shots=1000).shape == tuple()
assert qml.set_shots(shots=1000)(circuit)(x).shape == tuple()

@pytest.mark.parametrize(
"interface",
Expand Down
3 changes: 2 additions & 1 deletion tests/measurements/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_observable_is_composite_measurement_value(
are correct for a composite measurement value."""
dev = qml.device("default.qubit", seed=seed)

@qml.set_shots(shots=shots)
@qml.qnode(dev)
def circuit(phi):
qml.RX(phi, 0)
Expand All @@ -171,7 +172,7 @@ def expected_circuit(phi):

atol = tol if shots is None else tol_stochastic
for func in [circuit, qml.defer_measurements(circuit)]:
res = func(phi, shots=shots)
res = func(phi)
assert np.allclose(np.array(res), expected, atol=atol, rtol=0)

def test_eigvals_instead_of_observable(self, seed):
Expand Down
3 changes: 2 additions & 1 deletion tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def test_observable_is_measurement_value_list(

dev = qml.device("default.qubit", seed=seed)

@qml.set_shots(shots=shots)
@qml.qnode(dev)
def circuit(phi):
qml.RX(phi, 0)
Expand All @@ -427,7 +428,7 @@ def circuit(phi):
m2 = qml.measure(2)
return qml.probs(op=[m0, m1, m2])

res = circuit(phi, shots=shots)
res = circuit(phi)

@qml.qnode(dev)
def expected_circuit(phi):
Expand Down
3 changes: 2 additions & 1 deletion tests/measurements/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_observable_is_composite_measurement_value(
are correct for a composite measurement value."""
dev = qml.device("default.qubit", seed=seed)

@qml.set_shots(shots=shots)
@qml.qnode(dev)
def circuit(phi):
qml.RX(phi, 0)
Expand Down Expand Up @@ -105,7 +106,7 @@ def expected_circuit(phi):

atol = tol if shots is None else tol_stochastic
for func in [circuit, qml.defer_measurements(circuit)]:
res = func(phi, shots=shots)
res = func(phi)
assert np.allclose(np.array(res), expected, atol=atol, rtol=0)

def test_eigvals_instead_of_observable(self, seed):
Expand Down
30 changes: 23 additions & 7 deletions tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pennylane as qml
from pennylane import numpy as qnp
from pennylane.debugging import PLDB, pldb_device_manager
from pennylane.exceptions import DeviceError, QuantumFunctionError
from pennylane.exceptions import DeviceError, PennyLaneDeprecationWarning, QuantumFunctionError
from pennylane.ops.functions.equal import assert_equal


Expand Down Expand Up @@ -156,6 +156,7 @@ def circuit():
qml.snapshots(circuit)()

def test_non_StateMP_state_measurements_with_finite_shot_device_fails(self, dev):
@qml.set_shots(shots=200)
@qml.qnode(dev)
def circuit():
qml.Hadamard(0)
Expand All @@ -165,12 +166,13 @@ def circuit():
# Expect a DeviceError to be raised here since no shots has
# been provided to the snapshot due to the finite-shot device
with pytest.raises(DeviceError):
qml.snapshots(circuit)(shots=200)
qml.snapshots(circuit)()

def test_StateMP_with_finite_shot_device_passes(self, dev):
if "lightning" in dev.name or "mixed" in dev.name:
pytest.skip()

@qml.set_shots(shots=200)
@qml.qnode(dev)
def circuit():
qml.Snapshot(measurement=qml.state())
Expand All @@ -181,7 +183,7 @@ def circuit():

return qml.expval(qml.PauliZ(0))

_ = qml.snapshots(circuit)(shots=200)
_ = qml.snapshots(circuit)()

@pytest.mark.parametrize("diff_method", [None, "parameter-shift"])
def test_all_state_measurement_snapshot_pure_qubit_dev(self, dev, diff_method):
Expand Down Expand Up @@ -437,7 +439,7 @@ def circuit(add_bad_snapshot: bool):
assert result["execution_results"] == expected["execution_results"]

# Make sure shots are overridden correctly
result = qml.snapshots(circuit)(add_bad_snapshot=False, shots=200)
result = qml.snapshots(qml.set_shots(shots=200)(circuit))(add_bad_snapshot=False)
assert result[0] == {"00": 74, "10": 58, "20": 68}

@pytest.mark.parametrize(
Expand Down Expand Up @@ -591,7 +593,11 @@ def circuit():
_compare_numpy_dicts(result, expected)

# Make sure shots are overridden correctly
result = qml.snapshots(circuit)(shots=200)
with pytest.warns(
PennyLaneDeprecationWarning,
match="'shots' as an argument to the quantum function is deprecated",
):
result = qml.snapshots(circuit)(shots=200)
assert result[3] == {"0": 98, "1": 102}
assert np.allclose(result[5], expected[5])

Expand Down Expand Up @@ -648,7 +654,13 @@ def circuit():
assert ttest_ind(expvals, 0.0).pvalue >= 0.75

# Make sure shots are overridden correctly
counts, _ = tuple(zip(*(qml.snapshots(circuit)(shots=1000).values() for _ in range(50))))
with pytest.warns(
PennyLaneDeprecationWarning,
match="'shots' as an argument to the quantum function is deprecated",
):
counts, _ = tuple(
zip(*(qml.snapshots(circuit)(shots=1000).values() for _ in range(50)))
)
assert ttest_ind([count["0"] for count in counts], 500).pvalue >= 0.75

@pytest.mark.parametrize("diff_method", ["backprop", "adjoint"])
Expand Down Expand Up @@ -717,7 +729,11 @@ def circuit():
_compare_numpy_dicts(result, expected)

# Make sure shots are overridden correctly
result = circuit(shots=200)
with pytest.warns(
PennyLaneDeprecationWarning,
match="'shots' as an argument to the quantum function is deprecated",
):
result = circuit(shots=200)
finite_shot_result = result[0]
assert not np.allclose( # Since 200 does not have a factor of 3, we assert that there's no chance for finite-shot tape to reach 1/3 exactly here.
finite_shot_result,
Expand Down
12 changes: 7 additions & 5 deletions tests/test_qcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,7 +2620,7 @@ def cut_circuit(x):
v = 0.319

temp_shots = 333
cut_res = cut_circuit(v, shots=temp_shots) # pylint: disable=unexpected-keyword-arg
cut_res = qml.set_shots(shots=temp_shots)(cut_circuit)(v)

assert cut_res.shape == (temp_shots, 2)

Expand Down Expand Up @@ -4538,8 +4538,9 @@ def circuit():
spy = mocker.spy(qcut.cutcircuit, "_qcut_expand_fn")
spy_mc = mocker.spy(qcut.cutcircuit_mc, "_qcut_expand_fn")

kwargs = {"shots": 10} if isinstance(measurement, qml.measurements.SampleMP) else {}
cut_transform(circuit, device_wires=[0])(**kwargs)
if isinstance(measurement, qml.measurements.SampleMP):
circuit = qml.set_shots(circuit, shots=10)
cut_transform(circuit, device_wires=[0])()

assert spy.call_count == 1 or spy_mc.call_count == 1

Expand All @@ -4555,8 +4556,9 @@ def circuit():
return qml.apply(measurement)

with pytest.raises(ValueError, match="No WireCut operations found in the circuit."):
kwargs = {"shots": 10} if isinstance(measurement, qml.measurements.SampleMP) else {}
cut_transform(circuit, device_wires=[0])(**kwargs)
if isinstance(measurement, qml.measurements.SampleMP):
circuit = qml.set_shots(circuit, shots=10)
cut_transform(circuit, device_wires=[0])()

def test_expansion_ttn(self, mocker):
"""Test if wire cutting is compatible with the tree tensor network operation"""
Expand Down
Loading
Loading