Skip to content

Commit 2bc8aa4

Browse files
albi3roandrijapau
andauthored
Qnode primitive captures unprocessed ExecutionConfig (#8258)
**Context:** Blocks PennyLaneAI/catalyst#2041 **Description of the Change:** The capture `Device.setup_execution_config` was about the capabilities of the capture execution pipeline, which is no longer maintained or actively developed. While I'm not sure where will be the best place for this logic in catalyst, we can leave that up to catalyst and no determine it in program capture. **Benefits:** Unblocks catalyst unification work. **Possible Drawbacks:** **Related GitHub Issues:** [sc-99193] --------- Co-authored-by: Andrija Paurevic <[email protected]>
1 parent c63b94c commit 2bc8aa4

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@
121121

122122
<h3>Improvements 🛠</h3>
123123

124+
* The `QNode` primitive in the experimental program capture now captures the unprocessed `ExecutionConfig`, instead of
125+
one processed by the device.
126+
[(#8258)](https://github.com/PennyLaneAI/pennylane/pull/8258)
127+
124128
* The function :func:`qml.clifford_t_decomposition` with `method="gridsynth"` are now compatible
125129
with quantum just-in-time compilation via the `@qml.qjit` decorator.
126130
[(#7711)](https://github.com/PennyLaneAI/pennylane/pull/7711)

pennylane/workflow/_capture_qnode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(
163163
@qnode_prim.def_impl
164164
def _(*args, qnode, device, execution_config, qfunc_jaxpr, n_consts, shots_len, batch_dims=None):
165165

166+
execution_config = device.setup_execution_config(execution_config)
167+
166168
if shots_len == 0:
167169
shots = None
168170
non_shots_args = args
@@ -354,7 +356,7 @@ def _finite_diff(args, tangents, **impl_kwargs):
354356

355357
@debug_logger
356358
def _qnode_jvp(args, tangents, *, execution_config, device, qfunc_jaxpr, **impl_kwargs):
357-
359+
execution_config = device.setup_execution_config(execution_config)
358360
if execution_config.use_device_gradient:
359361
return device.jaxpr_jvp(qfunc_jaxpr, args, tangents, execution_config=execution_config)
360362

@@ -567,7 +569,6 @@ def f(x):
567569
config = construct_execution_config(
568570
qnode, resolve=False
569571
)() # no need for args and kwargs as not resolving
570-
config = qnode.device.setup_execution_config(config)
571572

572573
if abstracted_axes:
573574
# We unflatten the ``abstracted_axes`` here to be have the same pytree structure

tests/capture/workflow/test_capture_qnode.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ def circuit(x):
101101
assert eqn0.params["qnode"] == circuit
102102
assert eqn0.params["shots_len"] == 0
103103
expected_config = qml.devices.ExecutionConfig(
104-
gradient_method="backprop",
105-
use_device_gradient=True,
104+
gradient_method="best",
105+
use_device_gradient=None,
106106
gradient_keyword_arguments={},
107107
use_device_jacobian_product=False,
108108
interface="jax",
109109
grad_on_execution=False,
110-
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
111-
mcm_config=qml.devices.MCMConfig(mcm_method="deferred", postselect_mode=None),
110+
device_options={},
111+
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
112112
)
113113
assert eqn0.params["execution_config"] == expected_config
114114

@@ -231,15 +231,15 @@ def circuit():
231231
assert jaxpr.eqns[0].primitive == qnode_prim
232232
expected_config = qml.devices.ExecutionConfig(
233233
gradient_method="parameter-shift",
234-
use_device_gradient=False,
234+
use_device_gradient=None,
235235
grad_on_execution=False,
236236
derivative_order=2,
237237
use_device_jacobian_product=False,
238238
mcm_config=qml.devices.MCMConfig(
239239
mcm_method="single-branch-statistics", postselect_mode=None
240240
),
241241
interface=qml.math.Interface.JAX,
242-
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
242+
device_options={},
243243
)
244244
assert jaxpr.eqns[0].params["execution_config"] == expected_config
245245

0 commit comments

Comments
 (0)