Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@

<h3>Improvements 🛠</h3>

* PennyLane `autograph` supports standard python for index assignment (`arr[i] = x`) instead of jax.numpy form (`arr = arr.at[i].set(x)`).
Users can now use standard python assignment when designing circuits with experimental program capture enabled.

```python
import pennylane as qml
import jax.numpy as jnp

qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=3))
def circuit(val):
angles = jnp.zeros(3)
angles[1] = val / 2
angles[2] = val

for i, angle in enumerate(angles):
qml.RX(angle, i)

return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)), qml.expval(qml.Z(2))
```

```pycon
>>> circuit(jnp.pi)
(Array(0.99999994, dtype=float32),
Array(0., dtype=float32),
Array(-0.99999994, dtype=float32))
```

[(#8027)](https://github.com/PennyLaneAI/pennylane/pull/8027)

* Logical operations (`and`, `or` and `not`) are now supported with the `autograph` module. Users can
now use these logical operations in control flow when designing quantum circuits with experimental
program capture enabled.
Expand Down
22 changes: 21 additions & 1 deletion pennylane/capture/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,27 @@
has_jax = False


__all__ = ["if_stmt", "for_stmt", "while_stmt", "converted_call", "and_", "or_", "not_"]
__all__ = [
"if_stmt",
"for_stmt",
"while_stmt",
"converted_call",
"and_",
"or_",
"not_",
"set_item",
]


def set_item(target, i, x):
"""An implementation of the AutoGraph 'set_item' function."""

if qml.math.is_abstract(target):
target = target.at[i].set(x)
else:
target[i] = x

return target


def _assert_results(results, var_names):
Expand Down
7 changes: 5 additions & 2 deletions pennylane/capture/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,21 @@ def else_body():
)


# converter.Feature.LISTS permits overloading the 'set_item' function in 'ag_primitives.py'
OPTIONAL_FEATURES = [converter.Feature.BUILTIN_FUNCTIONS, converter.Feature.LISTS]

TOPLEVEL_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=True,
internal_convert_user_code=True,
optional_features=[converter.Feature.BUILTIN_FUNCTIONS],
optional_features=OPTIONAL_FEATURES,
)

NESTED_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=False,
internal_convert_user_code=True,
optional_features=[converter.Feature.BUILTIN_FUNCTIONS],
optional_features=OPTIONAL_FEATURES,
)

STANDARD_OPTIONS = converter.STANDARD_OPTIONS
Expand Down
266 changes: 266 additions & 0 deletions tests/capture/autograph/test_autograph_item_assignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test autograph support for standard Python item assignment with JAX Arrays."""

import pytest

pytestmark = pytest.mark.capture
jax = pytest.importorskip("jax")

# pylint: disable = wrong-import-position
import jax.numpy as jnp
from jax import make_jaxpr
from jax.core import eval_jaxpr

import pennylane as qml
from pennylane.capture.autograph import run_autograph


@pytest.mark.usefixtures("enable_disable_plxpr")
@pytest.mark.parametrize(
"array_in, index, new_value, array_out",
[
(jnp.array([1, 2, 3]), 0, 10, jnp.array([10, 2, 3])),
(jnp.array([1, 2, 3]), -1, 20, jnp.array([1, 2, 20])),
],
)
def test_single_integer_indexing(array_in, index, new_value, array_out):
"""Tests single integer indexing like `x[index] = new_value`."""

def fn(x):
x[index] = new_value
return x

ag_fn = run_autograph(fn)
args = (array_in,)
ag_fn_jaxpr = make_jaxpr(ag_fn)(*args)
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts, *args)
assert jnp.array_equal(result[0], array_out)


@pytest.mark.usefixtures("enable_disable_plxpr")
@pytest.mark.parametrize(
"array_in, index, new_value, array_out",
[
(jnp.array([1, 2, 3]), slice(0, 2), 10, jnp.array([10, 10, 3])),
(jnp.array([1, 2, 3]), slice(1, None), 20, jnp.array([1, 20, 20])),
(jnp.array([1, 2, 3]), slice(None), 5, jnp.array([5, 5, 5])),
(jnp.array([1, 2, 3, 4, 5]), slice(0, None, 2), 9, jnp.array([9, 2, 9, 4, 9])),
],
)
def test_slicing(array_in, index, new_value, array_out):
"""Tests slicing assignment like `x[slice] = new_value`."""

def fn(x):
x[index] = new_value
return x

ag_fn = run_autograph(fn)
args = (array_in,)
ag_fn_jaxpr = make_jaxpr(ag_fn)(*args)
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts, *args)
assert jnp.array_equal(result[0], array_out)


@pytest.mark.usefixtures("enable_disable_plxpr")
@pytest.mark.parametrize(
"array_in, index, new_value, array_out",
[
# Slice and set to non singleton value
(jnp.array([1, 2, 3, 4]), slice(1, 3), jnp.array([99, 88]), jnp.array([1, 99, 88, 4])),
# Use array for indexing
(jnp.array([1, 2, 3, 4]), jnp.array([0, 3]), 7, jnp.array([7, 2, 3, 7])),
# Use boolean mask
(
jnp.array([1, 5, 2, 6]),
jnp.array([False, True, False, True]),
0,
jnp.array([1, 0, 2, 0]),
),
# Index a two dimensional array
(jnp.array([[1, 2], [3, 4]]), 0, 9, jnp.array([[9, 9], [3, 4]])),
# Index with tuple
(jnp.array([[1, 2], [3, 4]]), (1, 0), 9, jnp.array([[1, 2], [9, 4]])),
# 3D array assignment
(
jnp.zeros((2, 2, 2)),
(0, 1, 0),
5.0,
jnp.array([[[0.0, 0.0], [5.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]),
),
# Ellipsis to select the last column
(
jnp.ones((3, 4)),
(..., -1),
99.0,
jnp.array([[1.0, 1.0, 1.0, 99.0], [1.0, 1.0, 1.0, 99.0], [1.0, 1.0, 1.0, 99.0]]),
),
# Complex numbers
(
jnp.array([1 + 1j, 2 + 2j]),
0,
3 - 3j,
jnp.array([3 - 3j, 2 + 2j]),
),
],
)
def test_non_trivial_indexing(array_in, index, new_value, array_out):
"""Tests non-trivial indexing like boolean masks or arrays."""

def fn(x):
x[index] = new_value
return x

ag_fn = run_autograph(fn)
args = (array_in,)
ag_fn_jaxpr = make_jaxpr(ag_fn)(*args)
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts, *args)
assert jnp.array_equal(result[0], array_out)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_non_tracing_assignment():
"""Tests item assignment if the list is not a tracer."""

def fn():
x = [0] * 5
x[2] = 1
return x

ag_fn = run_autograph(fn)
ag_fn_jaxpr = make_jaxpr(ag_fn)()
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts)
expected = jnp.array([0, 0, 1, 0, 0])
assert jnp.array_equal(result, expected)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_while_loop_integration():
"""Tests item assignment within a while loop."""

def fn():
x = jnp.zeros(5)
i = 0
while i < 5:
x[i] = i
i += 1
return x

ag_fn = run_autograph(fn)
ag_fn_jaxpr = make_jaxpr(ag_fn)()
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts)
expected = jnp.array([0, 1, 2, 3, 4])
assert jnp.array_equal(result[0], expected)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_for_loop_integration():
"""Tests item assignment within a for loop."""

def fn():
x = jnp.zeros(5)
for i in range(5):
x[i] = i
return x

ag_fn = run_autograph(fn)
ag_fn_jaxpr = make_jaxpr(ag_fn)()
result = eval_jaxpr(ag_fn_jaxpr.jaxpr, ag_fn_jaxpr.consts)
expected = jnp.array([0, 1, 2, 3, 4])
assert jnp.array_equal(result[0], expected)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_qnode_with_python_array_assignment():
"""Test a QNode where a python array argument is modified."""

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(new_val):
angles = [0.1, 0.2, 0.3]
angles[0] = new_val
qml.RX(angles[0], wires=0)
return qml.expval(qml.Z(0))

ag_circuit = run_autograph(circuit)
new_angle = jnp.pi

# Test forward pass
res = ag_circuit(new_angle)
assert jnp.allclose(res, -1.0)

# Test gradient
grad = jax.grad(ag_circuit, argnums=0)(new_angle)
# d/dx cos(x) = -sin(x), at x=pi, -sin(pi) = 0
assert jnp.allclose(grad, 0.0)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_qnode_with_jax_array_assignment():
"""Test a QNode where a JAX array argument is modified."""

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(angles, new_val):
angles[0] = new_val
qml.RX(angles[0], wires=0)
return qml.expval(qml.Z(0))

ag_circuit = run_autograph(circuit)
angles_in = jnp.array([0.1, 0.2, 0.3])
new_angle = jnp.pi

# Test forward pass
res = ag_circuit(angles_in, new_angle)
assert jnp.allclose(res, -1.0)

# Test gradient
grad = jax.grad(ag_circuit, argnums=1)(angles_in, new_angle)
# d/dx cos(x) = -sin(x), at x=pi, -sin(pi) = 0
assert jnp.allclose(grad, 0.0)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_item_assignment_is_differentiable():
"""Test that item assignment is differentiable."""

def fn(x, val):
x[0] = val
return jnp.sum(x)

ag_fn = run_autograph(fn)
array_in = jnp.ones(5)
value_in = 5.0
args = (array_in, value_in)
grad_jaxpr = make_jaxpr(jax.grad(ag_fn, argnums=1))(*args)
result = eval_jaxpr(grad_jaxpr.jaxpr, grad_jaxpr.consts, *args)

assert jnp.allclose(result[0], 1.0)


@pytest.mark.usefixtures("enable_disable_plxpr")
def test_shape_mismatch_raises_error():
"""Test that assigning an array of the wrong shape raises an error."""

def fn(x):
x[0:2] = jnp.array([1, 2, 3])
return x

ag_fn = run_autograph(fn)
array_in = jnp.zeros(5)
with pytest.raises(ValueError, match="Incompatible shapes"):
_ = make_jaxpr(ag_fn)(array_in)
2 changes: 1 addition & 1 deletion tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def true_fn_2(arg):
def false_fn_2(arg):
return qml.RY(0.1, wires=0)

[dyn_pred_2, _] = qml.cond(dyn_pred_1, true_fn_1, false_fn_1, elifs=())(arg)
dyn_pred_2, _ = qml.cond(dyn_pred_1, true_fn_1, false_fn_1, elifs=())(arg)
qml.cond(dyn_pred_2, true_fn_2, false_fn_2, elifs=())(arg)
return qml.expval(qml.Z(0))

Expand Down