Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@

<h3>Improvements 🛠</h3>

* The `QNode` primitive in the experimental program capture now captures the unprocessed `ExecutionConfig`, instead of
one processed by the device.
[(#8258)](https://github.com/PennyLaneAI/pennylane/pull/8258)

* The function :func:`qml.clifford_t_decomposition` with `method="gridsynth"` are now compatible
with quantum just-in-time compilation via the `@qml.qjit` decorator.
[(#7711)](https://github.com/PennyLaneAI/pennylane/pull/7711)
Expand Down
5 changes: 3 additions & 2 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(
@qnode_prim.def_impl
def _(*args, qnode, device, execution_config, qfunc_jaxpr, n_consts, shots_len, batch_dims=None):

execution_config = device.setup_execution_config(execution_config)

if shots_len == 0:
shots = None
non_shots_args = args
Expand Down Expand Up @@ -354,7 +356,7 @@ def _finite_diff(args, tangents, **impl_kwargs):

@debug_logger
def _qnode_jvp(args, tangents, *, execution_config, device, qfunc_jaxpr, **impl_kwargs):

execution_config = device.setup_execution_config(execution_config)
if execution_config.use_device_gradient:
return device.jaxpr_jvp(qfunc_jaxpr, args, tangents, execution_config=execution_config)

Expand Down Expand Up @@ -567,7 +569,6 @@ def f(x):
config = construct_execution_config(
qnode, resolve=False
)() # no need for args and kwargs as not resolving
config = qnode.device.setup_execution_config(config)

if abstracted_axes:
# We unflatten the ``abstracted_axes`` here to be have the same pytree structure
Expand Down
12 changes: 6 additions & 6 deletions tests/capture/workflow/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def circuit(x):
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots_len"] == 0
expected_config = qml.devices.ExecutionConfig(
gradient_method="backprop",
use_device_gradient=True,
gradient_method="best",
use_device_gradient=None,
gradient_keyword_arguments={},
use_device_jacobian_product=False,
interface="jax",
grad_on_execution=False,
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
mcm_config=qml.devices.MCMConfig(mcm_method="deferred", postselect_mode=None),
device_options={},
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
)
assert eqn0.params["execution_config"] == expected_config

Expand Down Expand Up @@ -231,15 +231,15 @@ def circuit():
assert jaxpr.eqns[0].primitive == qnode_prim
expected_config = qml.devices.ExecutionConfig(
gradient_method="parameter-shift",
use_device_gradient=False,
use_device_gradient=None,
grad_on_execution=False,
derivative_order=2,
use_device_jacobian_product=False,
mcm_config=qml.devices.MCMConfig(
mcm_method="single-branch-statistics", postselect_mode=None
),
interface=qml.math.Interface.JAX,
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
device_options={},
)
assert jaxpr.eqns[0].params["execution_config"] == expected_config

Expand Down