Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@
when a gateset is provided. `default.qubit` and `null.qubit` can now use
graph decomposition mode.
[(#8225)](https://github.com/PennyLaneAI/pennylane/pull/8225)
[(#8260)](https://github.com/PennyLaneAI/pennylane/pull/8260)

* `DefaultQubit` now determines the `mcm_method` in `Device.setup_execution_config`,
making it easier to tell which mcm method will be used. This also allows `defer_measurements` and `dynamic_one_shot` to be applied at different
Expand Down
17 changes: 16 additions & 1 deletion pennylane/devices/null_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np

from pennylane import math
from pennylane.decomposition import enabled_graph, has_decomp
from pennylane.devices.modifiers import simulator_tracking, single_tape_support
from pennylane.measurements import (
ClassicalShadowMP,
Expand Down Expand Up @@ -354,7 +355,7 @@ def preprocess(
original_stopping_condition = t.kwargs["stopping_condition"]

def new_stopping_condition(op):
return (not op.has_decomposition) or original_stopping_condition(op)
return not _op_has_decomp(op) or original_stopping_condition(op)

t.kwargs["stopping_condition"] = new_stopping_condition

Expand Down Expand Up @@ -507,3 +508,17 @@ def zeros_like(var, shots):
return math.zeros(var.aval.shape, dtype=var.aval.dtype, like="jax")

return [zeros_like(var, Shots(shots).total_shots) for var in jaxpr.outvars]


def _op_has_decomp(op):
"""Check if an operator has a decomposition, taking into account the graph-based decomposition system.

Args:
op (Operator): The operator to check.

Returns:
bool: True if the operator has a decomposition, False otherwise.
"""
if enabled_graph():
return has_decomp(type(op))
return op.has_decomposition
44 changes: 44 additions & 0 deletions tests/devices/test_null_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_tf_backprop(self, style):
assert g1 == 0


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("config", [None, ExecutionConfig(gradient_method="device")])
class TestDeviceDifferentiation:
"""Tests device differentiation integration with NullQubit."""
Expand Down Expand Up @@ -1397,3 +1398,46 @@ def f(x):
assert qml.math.allclose(res[0], 0)
assert qml.math.allclose(res[1], 0)
assert qml.math.allclose(res[2], jax.numpy.zeros((50, 2)))


class TestNullQubitGraphModeExclusive:
"""Tests for NullQubit features that require graph mode enabled.
The legacy decomposition mode should not be able to run these tests.
NOTE: All tests in this suite will auto-enable graph mode via fixture.
"""

@pytest.fixture(autouse=True)
def enable_graph_mode_only(self):
"""Auto-enable graph mode for all tests in this class."""
try:
qml.decomposition.enable_graph()
yield
finally:
qml.decomposition.disable_graph()

def test_insufficient_work_wires_causes_fallback(self):
"""Test that if a decomposition requires more work wires than available on null.qubit,
that decomposition is discarded and fallback is used."""

class MyNullQubitOp(qml.operation.Operator): # pylint: disable=too-few-public-methods
num_wires = 1

@qml.register_resources({qml.H: 2})
def decomp_fallback(wires):
qml.H(wires)
qml.H(wires)

@qml.register_resources({qml.X: 1}, work_wires={"burnable": 5})
def decomp_with_work_wire(wires):
qml.X(wires)

qml.add_decomps(MyNullQubitOp, decomp_fallback, decomp_with_work_wire)

tape = qml.tape.QuantumScript([MyNullQubitOp(0)])
dev = qml.device("null.qubit", wires=1) # Only 1 wire, but decomp needs 5 burnable
program = dev.preprocess_transforms()
(out_tape,), _ = program([tape])

assert len(out_tape.operations) == 2
assert out_tape.operations[0].name == "Hadamard"
assert out_tape.operations[1].name == "Hadamard"
Loading