Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7350f86
Enhance preprocess.decompose with device_wires and target_gates support
JerryChen97 Sep 9, 2025
c08ef4f
log
JerryChen97 Sep 9, 2025
d91ef83
Update doc/releases/changelog-dev.md
JerryChen97 Sep 9, 2025
79bd25e
Update doc/releases/changelog-dev.md
JerryChen97 Sep 10, 2025
3034399
rename: preprocess_decompose -> decompose
JerryChen97 Sep 10, 2025
bed5f6d
Revert changes to the capture
JerryChen97 Sep 10, 2025
3afe639
Merge branch 'master' into new-decomp-integration/preprocess-clean
JerryChen97 Sep 10, 2025
fa478e8
meaningless comment
JerryChen97 Sep 10, 2025
eede4c5
rtd fix?
JerryChen97 Sep 10, 2025
31c2600
Merge branch 'master' into new-decomp-integration/preprocess-clean
JerryChen97 Sep 10, 2025
6ed5df4
add exclusive tests
JerryChen97 Sep 10, 2025
2b394cb
Merge branch 'master' into new-decomp-integration/preprocess-clean
JerryChen97 Sep 10, 2025
3dca21c
clean up
JerryChen97 Sep 11, 2025
09ed41d
Merge branch 'master' into new-decomp-integration/preprocess-clean
JerryChen97 Sep 11, 2025
f9e576c
draft
JerryChen97 Sep 11, 2025
b063b9f
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 11, 2025
5896a1d
fix dq
JerryChen97 Sep 11, 2025
29d9219
fix2
JerryChen97 Sep 11, 2025
55218c6
fix3
JerryChen97 Sep 11, 2025
7109cd1
log
JerryChen97 Sep 11, 2025
c9720af
tests for NQ
JerryChen97 Sep 11, 2025
2dfaa7c
trim, combine tests
JerryChen97 Sep 11, 2025
ab47e16
fmt
JerryChen97 Sep 12, 2025
d409dc2
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 12, 2025
a37ba6d
disable pylint
JerryChen97 Sep 12, 2025
363ca06
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 12, 2025
c32a20f
revert nq test
JerryChen97 Sep 12, 2025
9e58505
specify dual mode tests
JerryChen97 Sep 12, 2025
d992fba
improve tensor graph tests
JerryChen97 Sep 12, 2025
a70f8ff
clean outdated testing remains
JerryChen97 Sep 17, 2025
3950394
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 17, 2025
a38a4be
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 17, 2025
fd31df6
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 17, 2025
925375b
Update tests/devices/default_tensor/test_default_tensor.py
JerryChen97 Sep 17, 2025
1a51459
Merge branch 'master' into new-decomp-integration/default-tensor
JerryChen97 Sep 17, 2025
d5c43f3
disable pylint
JerryChen97 Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,9 @@

<h3>Internal changes ⚙️</h3>

* `default.tensor` now supports graph decomposition mode during preprocessing.
[(#8253)](https://github.com/PennyLaneAI/pennylane/pull/8253)

* Remove legacy interface names from tests (e.g. `interface="jax-python"` or `interface="pytorch"`)
[(#8249)](https://github.com/PennyLaneAI/pennylane/pull/8249)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/devices/default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ def preprocess(
stopping_condition=stopping_condition,
skip_initial_state_prep=True,
name=self.name,
device_wires=self.wires,
target_gates=_operations,
)
program.add_transform(qml.transforms.broadcast_expand)

Expand Down
152 changes: 152 additions & 0 deletions tests/devices/default_tensor/test_default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from scipy.sparse import csr_matrix

import pennylane as qml
from pennylane.devices import ExecutionConfig
from pennylane.devices.default_tensor import _operations, stopping_condition
from pennylane.devices.preprocess import decompose
from pennylane.exceptions import DeviceError, WireError
from pennylane.math.decomposition import givens_decomposition
from pennylane.typing import TensorLike
Expand Down Expand Up @@ -412,6 +415,7 @@ def test_execute_and_compute_vjp(self):
dev.execute_and_compute_vjp(circuits=None, cotangents=None)


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("method", ["mps", "tn"])
@pytest.mark.jax
class TestJaxSupport:
Expand Down Expand Up @@ -450,6 +454,7 @@ def circuit():
assert np.allclose(circuit(), 0.0)


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("method", ["mps", "tn"])
@pytest.mark.parametrize(
"operation, expected_output, par",
Expand Down Expand Up @@ -491,6 +496,7 @@ def circuit():


# At this stage, this test is especially relevant for the MPS method, but we test both methods for consistency.
@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("num_orbitals", [2, 4])
@pytest.mark.parametrize("method", ["mps", "tn"])
def test_wire_order_dense_vector(method, num_orbitals):
Expand Down Expand Up @@ -528,6 +534,7 @@ def circuit():
assert len(state) == 2 ** (2 * num_orbitals + 1)


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
class TestMCMs:
"""Test that default.tensor can handle mid circuit measurements."""

Expand Down Expand Up @@ -566,3 +573,148 @@ def circuit(x):

res = circuit(0.5)
assert qml.math.allclose(res, np.cos(0.5))


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
class TestPreprocessingTransforms:
"""Tests for the preprocessing transform pipeline."""

def test_preprocess_transforms_structure(self):
"""Test that the preprocessing transforms are set up correctly."""
dev = qml.device("default.tensor", wires=3)
config = ExecutionConfig()

program, _ = dev.preprocess(config)

# Check that we have the expected transforms
transform_names = [
transform_container.transform.__name__ for transform_container in program
]
expected_transforms = [
"validate_measurements",
"validate_observables",
"validate_device_wires",
"defer_measurements",
"decompose",
"broadcast_expand",
]

for expected_transform in expected_transforms:
assert expected_transform in transform_names

def test_decompose_transform_has_device_wires_and_target_gates(self):
"""Test that the decompose transform is configured with device_wires and target_gates."""
dev = qml.device("default.tensor", wires=[0, 1, 2])
config = ExecutionConfig()

program, _ = dev.preprocess(config)

# Find the decompose transform
decompose_transform = None
for transform_container in program:
if transform_container.transform.__name__ == "decompose":
decompose_transform = transform_container
break

assert decompose_transform is not None

# Check that device_wires and target_gates are passed correctly
assert "device_wires" in decompose_transform.kwargs
assert "target_gates" in decompose_transform.kwargs
assert decompose_transform.kwargs["device_wires"] == dev.wires
assert decompose_transform.kwargs["target_gates"] == _operations

def test_decompose_with_stopping_condition(self):
"""Test that decompose transform uses the correct stopping condition."""
dev = qml.device("default.tensor", wires=3)
config = ExecutionConfig()

program, _ = dev.preprocess(config)

# Find the decompose transform
decompose_transform = None
for transform_container in program:
if transform_container.transform.__name__ == "decompose":
decompose_transform = transform_container
break

assert decompose_transform is not None
assert "stopping_condition" in decompose_transform.kwargs
assert decompose_transform.kwargs["stopping_condition"] == stopping_condition

@pytest.mark.integration
def test_integration_with_qnode(self):
"""Test integration with QNode to ensure the device works end-to-end."""
dev = qml.device("default.tensor", wires=3)

@qml.qnode(dev)
def circuit():
# Use an operation that needs decomposition
qml.QFT(wires=[0, 1])
return qml.expval(qml.Z(0))

# This should work without errors
result = circuit()
assert isinstance(result, (float, np.floating))

@pytest.mark.integration
def test_integration_with_multiple_decomposition_layers(self):
"""Test that operations requiring multiple layers of decomposition work."""
dev = qml.device("default.tensor", wires=4)

@qml.qnode(dev)
def circuit():
# Operations that may require multiple decomposition steps
qml.QFT(wires=[0, 1, 2])
qml.GroverOperator(wires=[0, 1, 2, 3])
return qml.expval(qml.Z(0))

# This should work without errors
result = circuit()
assert isinstance(result, (float, np.floating))


class TestGraphModeExclusiveFeatures:
"""Tests that only work when graph mode is enabled."""

@pytest.fixture(autouse=True)
def enable_graph_mode_only(self):
"""Auto-enable graph mode for all tests in this class."""
qml.decomposition.enable_graph()
yield
qml.decomposition.disable_graph()

def test_work_wire_constraint_respected(self):
"""Test that decompositions requiring more work wires than available are discarded."""

# Create a mock operation with different decomposition options
class MyOp(qml.operation.Operator): # pylint: disable=too-few-public-methods
num_wires = 1

# Fallback decomposition (no work wires needed)
@qml.register_resources({qml.Hadamard: 2})
def decomp_fallback(wires):
qml.Hadamard(wires)
qml.Hadamard(wires)

# Work wire decomposition (needs more wires than available)
@qml.register_resources({qml.PauliX: 1}, work_wires={"burnable": 3})
def decomp_with_work_wire(wires):
qml.PauliX(wires)

qml.add_decomps(MyOp, decomp_fallback, decomp_with_work_wire)

tape = qml.tape.QuantumScript([MyOp(0)], [qml.expval(qml.Z(0))])
device_wires = qml.wires.Wires([0, 1]) # Only 2 wires, insufficient for 3 burnable
target_gates = {"Hadamard", "PauliX"}

(out_tape,), _ = decompose(
tape,
lambda obj: obj.name in target_gates,
device_wires=device_wires,
target_gates=target_gates,
)

# Should use fallback decomposition (2 Hadamards) due to work wire constraint
assert len(out_tape.operations) == 2
assert all(op.name == "Hadamard" for op in out_tape.operations)
49 changes: 49 additions & 0 deletions tests/devices/test_null_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def check_outputs():
check_outputs()


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("shots", (None, 10))
def test_supports_operator_without_decomp(shots):
"""Test that null.qubit automatically supports any operation without a decomposition."""
Expand Down Expand Up @@ -831,6 +832,7 @@ def test_tf_backprop(self, style):
assert g1 == 0


@pytest.mark.usefixtures("enable_and_disable_graph_decomp")
@pytest.mark.parametrize("config", [None, ExecutionConfig(gradient_method="device")])
class TestDeviceDifferentiation:
"""Tests device differentiation integration with NullQubit."""
Expand Down Expand Up @@ -1397,3 +1399,50 @@ def f(x):
assert qml.math.allclose(res[0], 0)
assert qml.math.allclose(res[1], 0)
assert qml.math.allclose(res[2], jax.numpy.zeros((50, 2)))


class TestNullQubitGraphModeExclusive:
"""Tests for NullQubit features that require graph mode enabled.
The legacy decomposition mode should not be able to run these tests.

NOTE: All tests in this suite will auto-enable graph mode via fixture.
"""

@pytest.fixture(autouse=True)
def enable_graph_mode_only(self):
"""Auto-enable graph mode for all tests in this class."""
qml.decomposition.enable_graph()
yield
qml.decomposition.disable_graph()

def test_insufficient_work_wires_causes_fallback(self):
"""Test that if a decomposition requires more work wires than available on null.qubit,
that decomposition is discarded and fallback is used."""

class MyNullQubitOp(qml.operation.Operator): # pylint: disable=too-few-public-methods
num_wires = 1

def decomposition(
self,
): # !Note: This is crucial since otherwise it will be stopped by NQ
return NotImplemented

@qml.register_resources({qml.H: 2})
def decomp_fallback(wires):
qml.H(wires)
qml.H(wires)

@qml.register_resources({qml.X: 1}, work_wires={"burnable": 5})
def decomp_with_work_wire(wires):
qml.X(wires)

qml.add_decomps(MyNullQubitOp, decomp_fallback, decomp_with_work_wire)

tape = qml.tape.QuantumScript([MyNullQubitOp(0)])
dev = qml.device("null.qubit", wires=1) # Only 1 wire, but decomp needs 5 burnable
program = dev.preprocess_transforms()
(out_tape,), _ = program([tape])

assert len(out_tape.operations) == 2
assert out_tape.operations[0].name == "Hadamard"
assert out_tape.operations[1].name == "Hadamard"
Loading