Skip to content

Commit 62b73ce

Browse files
matrix and cond now dequeues operators in arguments to the qfunc (#8119)
**Context:** Applying the same thing in #8094 to `matrix` and `cond` **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-98196] --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]>
1 parent 10747ae commit 62b73ce

File tree

5 files changed

+58
-3
lines changed

5 files changed

+58
-3
lines changed

doc/releases/changelog-dev.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,8 +871,9 @@
871871
* Fixes `SemiAdder` to work when inputs are defined with a single wire.
872872
[(#7940)](https://github.com/PennyLaneAI/pennylane/pull/7940)
873873

874-
* Fixes a bug where `qml.prod` applied on a quantum function does not dequeue operators passed as arguments to the function.
874+
* Fixes a bug where `qml.prod`, `qml.matrix`, and `qml.cond` applied on a quantum function does not dequeue operators passed as arguments to the function.
875875
[(#8094)](https://github.com/PennyLaneAI/pennylane/pull/8094)
876+
[(#8119)](https://github.com/PennyLaneAI/pennylane/pull/8119)
876877

877878
<h3>Contributors ✍️</h3>
878879

pennylane/ops/functions/matrix.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def circuit():
230230
f"Wires in circuit {list(op.wires)} are inconsistent with "
231231
f"those in wire_order {list(wire_order)}"
232232
)
233+
QueuingManager.remove(op)
233234
if op.has_matrix:
234235
return op.matrix(wire_order=wire_order)
235236
if op.has_sparse_matrix:
@@ -263,6 +264,9 @@ def processing_fn(res):
263264
params = res[0].get_parameters(trainable_only=False)
264265
interface = qml.math.get_interface(*params)
265266

267+
for op in res[0].operations:
268+
QueuingManager.remove(op)
269+
266270
# initialize the unitary matrix
267271
if len(res[0].operations) == 0:
268272
result = qml.math.eye(2 ** len(wire_order), like=interface)

pennylane/ops/op_math/condition.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,18 @@ def elifs(self):
253253
return list(zip(self.preds[1:], self.branch_fns[1:]))
254254

255255
def __call_capture_disabled(self, *args, **kwargs):
256+
257+
# dequeue operators passed to args
258+
leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
259+
for l in leaves:
260+
if isinstance(l, Operator):
261+
qml.QueuingManager.remove(l)
262+
256263
# python fallback
257264
for pred, branch_fn in zip(self.preds, self.branch_fns):
258265
if pred:
259266
return branch_fn(*args, **kwargs)
260-
# TODO: Remove when PL supports pylint==3.3.6 (it is considered a useless-suppression) [sc-91362]
261-
# pylint: disable=not-callable
267+
262268
return self.false_fn(*args, **kwargs)
263269

264270
def __call_capture_enabled(self, *args, **kwargs):

tests/ops/functions/test_matrix.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ def node():
238238
]
239239
assert all(np.allclose(mat, np.eye(4)) for mat in mats)
240240

241+
def test_matrix_dequeues_operation(self):
242+
"""Tests that the operator is dequeued."""
243+
244+
with qml.queuing.AnnotatedQueue() as q:
245+
mat = qml.matrix(qml.X(0))
246+
qml.QubitUnitary(mat, wires=[0])
247+
248+
assert len(q.queue) == 1
249+
assert isinstance(q.queue[0], qml.QubitUnitary)
250+
241251

242252
class TestMultipleOperations:
243253
def test_multiple_operations_tape(self):
@@ -274,6 +284,21 @@ def testcircuit():
274284
expected_matrix = I_CNOT @ X_S_H
275285
assert np.allclose(matrix, expected_matrix)
276286

287+
def test_qfunc_arguments_dequeued(self):
288+
"""Tests that operators passed as arguments to the qfunc are dequeued"""
289+
290+
def func(op, op1=None):
291+
qml.apply(op)
292+
if op1:
293+
qml.apply(op1)
294+
295+
with qml.queuing.AnnotatedQueue() as q:
296+
mat = qml.matrix(func, wire_order=[0])(qml.X(0), op1=qml.Z(0))
297+
qml.QubitUnitary(mat, wires=[0])
298+
299+
assert len(q.queue) == 1
300+
assert isinstance(q.queue[0], qml.QubitUnitary)
301+
277302
def test_multiple_operations_qnode(self):
278303
"""Check the total matrix for a QNode containing multiple gates"""
279304
dev = qml.device("default.qubit", wires=["a", "b", "c"])

tests/ops/op_math/test_condition.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,25 @@ def test_data_set_correctly(self, op):
251251
assert cond_op.num_params == op.num_params
252252
assert cond_op.ndim_params == op.ndim_params
253253

254+
def test_qfunc_arg_dequeued(self):
255+
"""Tests that the operators in the quantum function arguments are dequeued."""
256+
257+
def true_fn(op):
258+
qml.apply(op)
259+
260+
def false_fn(op):
261+
qml.apply(op)
262+
263+
def circuit(x):
264+
qml.cond(x > 0, true_fn, false_fn)(qml.X(0))
265+
266+
with qml.queuing.AnnotatedQueue() as q:
267+
circuit(1)
268+
circuit(-1)
269+
270+
assert len(q.queue) == 2
271+
assert q.queue == [qml.X(0), qml.X(0)]
272+
254273

255274
@pytest.mark.parametrize("op_class", [qml.PauliY, qml.Toffoli, qml.Hadamard, qml.CZ])
256275
def test_conditional_label(op_class):

0 commit comments

Comments
 (0)