Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pennylane/devices/qubit_mixed/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def apply_diagonal_unitary(op, state, is_state_batched: bool = False, debugger=N

eigvals = op.eigvals()
eigvals = math.stack(eigvals)
eigvals = math.reshape(eigvals, [2] * len(channel_wires))
eigvals = math.reshape(eigvals, [-1, 2 * len(channel_wires)])
eigvals = math.cast_like(eigvals, state)

state_indices = alphabet[: 2 * num_wires + is_state_batched]
Expand All @@ -654,7 +654,8 @@ def apply_diagonal_unitary(op, state, is_state_batched: bool = False, debugger=N
col_indices = "".join(alphabet_array[col_wires_list].tolist())

# Basically, we want to do, lambda_a rho_ab lambda_b
einsum_indices = f"{row_indices},{state_indices},{col_indices}->{state_indices}"
# Use ellipsis to represent the batch dimensions as eigvals.shape = (batch, 2^k)
einsum_indices = f"...{row_indices},...{state_indices},...{col_indices}->...{state_indices}"

return math.einsum(einsum_indices, eigvals, state, math.conj(eigvals))

Expand Down
Loading