Skip to content

Commit bd07c01

Browse files
fixing 1 qubit-size SemiAdder (#7940)
**Context:** the SemiAdder was not working in the edge case of x + y defined with a single qubit. This PR covers this situation **Description of the Change:** In this scenario, the semiAdder is a CNOT. Note that in this case work_wires are not needed. So I had to change the logic that was assuming we always have work wires **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]>
1 parent e7b214d commit bd07c01

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,14 @@
321321
This allows for types to be inferred correctly when parsing.
322322
[(#7825)](https://github.com/PennyLaneAI/pennylane/pull/7825)
323323

324+
* Fixes `SemiAdder` to work when inputs are defined with a single wire.
325+
[(#7940)](https://github.com/PennyLaneAI/pennylane/pull/7940)
326+
324327
<h3>Contributors ✍️</h3>
325328

326329
This release contains contributions from (in alphabetical order):
327330

331+
Guillermo Alonso,
328332
Utkarsh Azad,
329333
Joey Carter,
330334
Yushao Chen,

pennylane/templates/subroutines/semi_adder.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,25 @@ def __init__(
140140

141141
x_wires = qml.wires.Wires(x_wires)
142142
y_wires = qml.wires.Wires(y_wires)
143-
work_wires = qml.wires.Wires(work_wires)
144-
145-
if len(work_wires) < len(y_wires) - 1:
146-
raise ValueError(f"At least {len(y_wires)-1} work_wires should be provided.")
147-
if work_wires.intersection(x_wires):
148-
raise ValueError("None of the wires in work_wires should be included in x_wires.")
149-
if work_wires.intersection(y_wires):
150-
raise ValueError("None of the wires in work_wires should be included in y_wires.")
143+
if work_wires:
144+
work_wires = qml.wires.Wires(work_wires)
145+
if len(work_wires) < len(y_wires) - 1:
146+
raise ValueError(f"At least {len(y_wires)-1} work_wires should be provided.")
147+
if work_wires.intersection(x_wires):
148+
raise ValueError("None of the wires in work_wires should be included in x_wires.")
149+
if work_wires.intersection(y_wires):
150+
raise ValueError("None of the wires in work_wires should be included in y_wires.")
151151
if x_wires.intersection(y_wires):
152152
raise ValueError("None of the wires in y_wires should be included in x_wires.")
153153

154154
self.hyperparameters["x_wires"] = x_wires
155155
self.hyperparameters["y_wires"] = y_wires
156156
self.hyperparameters["work_wires"] = work_wires
157157

158-
all_wires = qml.wires.Wires.all_wires([x_wires, y_wires, work_wires])
158+
if work_wires:
159+
all_wires = qml.wires.Wires.all_wires([x_wires, y_wires, work_wires])
160+
else:
161+
all_wires = qml.wires.Wires.all_wires([x_wires, y_wires])
159162

160163
super().__init__(wires=all_wires, id=id)
161164

@@ -243,6 +246,10 @@ def _semiadder(x_wires, y_wires, work_wires, **_):
243246
num_y_wires = len(y_wires)
244247
num_x_wires = len(x_wires)
245248

249+
if num_y_wires == 1:
250+
qml.CNOT([x_wires[-1], y_wires[0]])
251+
return
252+
246253
x_wires_pl = x_wires[::-1][:num_y_wires]
247254
y_wires_pl = y_wires[::-1]
248255
work_wires_pl = work_wires[::-1]

tests/templates/test_subroutines/test_semi_adder.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ class TestSemiAdder:
3737
@pytest.mark.parametrize(
3838
("x_wires", "y_wires", "work_wires", "x", "y"),
3939
[
40+
([0], [1], None, 1, 0),
41+
([0], [1], None, 1, 1),
42+
([0, 1], [2], None, 0, 1),
43+
([0, 1], [2], None, 1, 1),
44+
([0, 1], [2], None, 1, 0),
45+
([0, 1], [2, 3], [4], 2, 0),
46+
([0, 1], [2, 3], [4], 1, 2),
4047
([0, 1, 2], [3, 4, 5], [6, 7], 1, 2),
4148
([0, 1, 2], [3, 4, 5], [6, 7], 5, 6),
4249
([0, 1], [2, 3, 4], [5, 6], 3, 2),
@@ -63,15 +70,21 @@ def circuit(x, y):
6370

6471
output = circuit(x, y)
6572

73+
if len(y_wires) == 1:
74+
sample = [output[0]]
75+
else:
76+
sample = output[0]
77+
6678
# check that the output sample is the binary representation of x + y mod 2^len(y_wires)
6779
# pylint: disable=bad-reversed-sequence
6880
assert np.allclose(
69-
sum(bit * (2**i) for i, bit in enumerate(reversed(output[0]))),
81+
sum(bit * (2**i) for i, bit in enumerate(reversed(sample))),
7082
(x + y) % 2 ** len(y_wires),
7183
)
7284

73-
# check work_wires are in state |0>
74-
assert np.isclose(output[1][0], 1.0)
85+
if work_wires:
86+
# check work_wires are in state |0>
87+
assert np.isclose(output[1][0], 1.0)
7588

7689
@pytest.mark.parametrize(
7790
("x_wires", "y_wires", "work_wires", "msg_match"),

0 commit comments

Comments
 (0)