Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@

<h3>Improvements 🛠</h3>

* The qnode primitive is plxpr program capture now captures the unprocessed `ExecutionConfig`, instead of
one processed by the device.

* The function :func:`qml.clifford_t_decomposition` with `method="gridsynth"` are now compatible
with quantum just-in-time compilation via the `@qml.qjit` decorator.
[(#7711)](https://github.com/PennyLaneAI/pennylane/pull/7711)
Expand Down
5 changes: 3 additions & 2 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(
@qnode_prim.def_impl
def _(*args, qnode, device, execution_config, qfunc_jaxpr, n_consts, shots_len, batch_dims=None):

execution_config = device.setup_execution_config(execution_config)

if shots_len == 0:
shots = None
non_shots_args = args
Expand Down Expand Up @@ -354,7 +356,7 @@ def _finite_diff(args, tangents, **impl_kwargs):

@debug_logger
def _qnode_jvp(args, tangents, *, execution_config, device, qfunc_jaxpr, **impl_kwargs):

execution_config = device.setup_execution_config(execution_config)
if execution_config.use_device_gradient:
return device.jaxpr_jvp(qfunc_jaxpr, args, tangents, execution_config=execution_config)

Expand Down Expand Up @@ -567,7 +569,6 @@ def f(x):
config = construct_execution_config(
qnode, resolve=False
)() # no need for args and kwargs as not resolving
config = qnode.device.setup_execution_config(config)

if abstracted_axes:
# We unflatten the ``abstracted_axes`` here to be have the same pytree structure
Expand Down
12 changes: 6 additions & 6 deletions tests/capture/workflow/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def circuit(x):
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots_len"] == 0
expected_config = qml.devices.ExecutionConfig(
gradient_method="backprop",
use_device_gradient=True,
gradient_method="best",
use_device_gradient=None,
gradient_keyword_arguments={},
use_device_jacobian_product=False,
interface="jax",
grad_on_execution=False,
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
mcm_config=qml.devices.MCMConfig(mcm_method="deferred", postselect_mode=None),
device_options={},
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
)
assert eqn0.params["execution_config"] == expected_config

Expand Down Expand Up @@ -231,15 +231,15 @@ def circuit():
assert jaxpr.eqns[0].primitive == qnode_prim
expected_config = qml.devices.ExecutionConfig(
gradient_method="parameter-shift",
use_device_gradient=False,
use_device_gradient=None,
grad_on_execution=False,
derivative_order=2,
use_device_jacobian_product=False,
mcm_config=qml.devices.MCMConfig(
mcm_method="single-branch-statistics", postselect_mode=None
),
interface=qml.math.Interface.JAX,
device_options={"max_workers": None, "rng": dev._rng, "prng_key": None},
device_options={},
)
assert jaxpr.eqns[0].params["execution_config"] == expected_config

Expand Down
Loading