Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@
technically a callable that returns an `X` operator.
[(#8060)](https://github.com/PennyLaneAI/pennylane/pull/8060)

* With program capture, an error is now raised if the conditional predicate is not a scalar.
[(#8066)](https://github.com/PennyLaneAI/pennylane/pull/8066)

<h4>OpenQASM-PennyLane interoperability</h4>

* The :func:`qml.from_qasm3` function can now convert OpenQASM 3.0 circuits that contain
Expand Down
2 changes: 2 additions & 0 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def __call_capture_enabled(self, *args, **kwargs):
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args)

for pred, fn in branches:
if (pred_shape := qml.math.shape(pred)) != ():
raise ValueError(f"Condition predicate must be a scalar. Got {pred_shape}.")
conditions.append(pred)
if fn is None:
jaxpr_branches.append(None)
Expand Down
10 changes: 10 additions & 0 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def false_fn(arg):
return true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4


def test_bad_predicate_shape():
"""Test that an error is raised if the predicate is not a scalar."""

def f():
qml.cond(np.array([0, 0]), qml.X, qml.Z)(0)

with pytest.raises(ValueError, match="predicate must be a scalar"):
jax.make_jaxpr(f)()


@pytest.mark.parametrize("decorator", [True, False])
class TestCond:
"""Tests for conditional functions using qml.cond."""
Expand Down