
[BUG] qml.RZ doesn’t work with torch.vmap but qml.RX and qml.RY do work #8759

@CatalinaAlbornoz

Description

Expected behavior

qml.RZ works with torch.vmap, as qml.RX and qml.RY already do.
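Concretely, RZ applied to |0⟩ only adds a phase, so ⟨Z⟩ should be 1 for every angle. The NumPy sketch below is independent of PennyLane; rz and expval_z_after_rz are hypothetical helpers. It computes that reference value for the input used in the reproduction script. Assuming nested vmap composes with PennyLane's parameter broadcasting as intended, the script should print a (2, 2, 2) tensor of ones.

```python
import numpy as np

def rz(theta):
    # RZ(theta) = diag(exp(-i*theta/2), exp(+i*theta/2)), broadcast over theta's shape
    theta = np.asarray(theta, dtype=float)
    phases = np.exp(np.multiply.outer(theta, np.array([-0.5j, 0.5j])))  # (..., 2)
    return phases[..., :, None] * np.eye(2)                             # (..., 2, 2)

def expval_z_after_rz(theta):
    state0 = np.array([1.0, 0.0], dtype=complex)  # |0>
    psi = rz(theta) @ state0                       # (..., 2)
    z_diag = np.array([1.0, -1.0])                 # diagonal of PauliZ
    return np.sum(np.abs(psi) ** 2 * z_diag, axis=-1)

# Same input as the reproduction script
x = np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.1, 0.2], [0.3, 0.4]]])
expected = expval_z_after_rz(x)
print(expected.shape)
```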

Actual behavior

As first reported in thread 9133 of the Discussion Forum, qml.RZ fails under torch.vmap with "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage", whereas qml.RX and qml.RY work.

Additional information

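The problem is likely the cast_like call on line 507. In the traceback below, the failing call is qml.math.cast_like(qml.math.eye(2, like=diags), diags) inside RZ's compute_matrix (line 465 of parametric_ops_single_qubit.py in 0.43.1). cast_like looks up the target dtype with ar.to_numpy(tensor2), which ends in x.detach().cpu().numpy(). Under torch.vmap, diags is a batched tensor without storage, so the NumPy conversion raises.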
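A minimal pure-PyTorch sketch, assuming the NumPy conversion is the culprit. rz_matrix and numpy_round_trip are hypothetical helpers, not PennyLane API, and this is not a proposed patch to cast_like. It builds the RZ matrix under the same nested torch.vmap as the reproduction script. It takes the target dtype from diags.dtype (tensor metadata) rather than from a NumPy copy. It also checks that the NumPy round trip the traceback ends in raises RuntimeError under vmap.

```python
import torch

def rz_matrix(theta):
    # RZ(theta) = diag(exp(-i*theta/2), exp(+i*theta/2)) for any leading shape
    signs = torch.tensor([-0.5j, 0.5j])            # complex64
    diags = torch.exp(theta.unsqueeze(-1) * signs) # (..., 2)
    # Storage-free cast: read dtype from tensor metadata, no .numpy() call
    eye = torch.eye(2).to(diags.dtype)
    return diags.unsqueeze(-1) * eye               # (..., 2, 2)

def numpy_round_trip(theta):
    # Mirrors the failing frame in the traceback (_to_numpy_torch)
    theta.detach().cpu().numpy()
    return theta

x = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.1, 0.2], [0.3, 0.4]]])

# Same nesting as the reproduction script
mats = torch.vmap(lambda x_i: torch.vmap(rz_matrix, in_dims=0)(x_i), in_dims=0)(x)

try:
    torch.vmap(lambda x_i: torch.vmap(numpy_round_trip, in_dims=0)(x_i), in_dims=0)(x)
    round_trip_failed = False
except RuntimeError:
    round_trip_failed = True

print(mats.shape, round_trip_failed)
```

The design point is only that dtype information is available on a batched tensor without touching its storage. Whether cast_like can take that route for every interface it supports is a question for the maintainers.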

Source code

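import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch")
def ansatz(x):
    qml.RZ(x, wires=0)  # per this report, qml.RX or qml.RY here works
    return qml.expval(qml.PauliZ(0))

# x has shape (2, 2, 2): the two nested vmaps strip the first two dimensions,
# so each call to ansatz receives a length-2 tensor (parameter broadcasting).
x = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.1, 0.2], [0.3, 0.4]]])
res = torch.vmap(
    lambda x_i: torch.vmap(lambda x_j: ansatz(x_j), in_dims=0)(x_i),
    in_dims=0,
)(x)

print(res)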
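To see what the QNode receives, the pure-PyTorch probe below replaces ansatz with a hypothetical stand_in function that records the per-call shape and returns torch.cos of its input.

```python
import torch

x = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.1, 0.2], [0.3, 0.4]]])
seen_shapes = []

def stand_in(x_j):
    # Hypothetical stand-in for the QNode: record the logical shape vmap exposes
    seen_shapes.append(tuple(x_j.shape))
    return torch.cos(x_j)

res = torch.vmap(
    lambda x_i: torch.vmap(stand_in, in_dims=0)(x_i),
    in_dims=0,
)(x)

print(seen_shapes, res.shape)
```

vmap traces the function once with batched tensors rather than once per element, so the list holds a single entry. The (2,) shape means the QNode gets a length-2 parameter, which PennyLane treats as a broadcast batch. This is consistent with the op.batch_size check and the diags[:, :, np.newaxis] line in the traceback.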

Tracebacks

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-1739476951.py in <cell line: 0>()
     10 
     11 x = torch.tensor([[[0.1, 0.2], [0.3, 0.4]],[[0.1, 0.2], [0.3, 0.4]]])
---> 12 res = torch.vmap(
     13     lambda x_i: torch.vmap(
     14         lambda x_j: ansatz(x_j), in_dims=0)(x_i),

/usr/local/lib/python3.12/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
    206 
    207     def wrapped(*args, **kwargs):
--> 208         return vmap_impl(
    209             func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    210         )

/usr/local/lib/python3.12/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    280 
    281     # If chunk_size is not specified.
--> 282     return _flat_vmap(
    283         func,
    284         batch_size,

/usr/local/lib/python3.12/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    430             flat_in_dims, flat_args, vmap_level, args_spec
    431         )
--> 432         batched_outputs = func(*batched_inputs, **kwargs)
    433         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    434 

/tmp/ipython-input-1739476951.py in <lambda>(x_i)
     11 x = torch.tensor([[[0.1, 0.2], [0.3, 0.4]],[[0.1, 0.2], [0.3, 0.4]]])
     12 res = torch.vmap(
---> 13     lambda x_i: torch.vmap(
     14         lambda x_j: ansatz(x_j), in_dims=0)(x_i),
     15         in_dims=0

/usr/local/lib/python3.12/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
    206 
    207     def wrapped(*args, **kwargs):
--> 208         return vmap_impl(
    209             func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    210         )

/usr/local/lib/python3.12/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    280 
    281     # If chunk_size is not specified.
--> 282     return _flat_vmap(
    283         func,
    284         batch_size,

/usr/local/lib/python3.12/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    430             flat_in_dims, flat_args, vmap_level, args_spec
    431         )
--> 432         batched_outputs = func(*batched_inputs, **kwargs)
    433         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    434 

/tmp/ipython-input-1739476951.py in <lambda>(x_j)
     12 res = torch.vmap(
     13     lambda x_i: torch.vmap(
---> 14         lambda x_j: ansatz(x_j), in_dims=0)(x_i),
     15         in_dims=0
     16     )(x)

/usr/local/lib/python3.12/dist-packages/pennylane/workflow/qnode.py in __call__(self, *args, **kwargs)
    893 
    894             return capture_qnode(self, *args, **kwargs)
--> 895         return self._impl_call(*args, **kwargs)
    896 
    897 

/usr/local/lib/python3.12/dist-packages/pennylane/workflow/qnode.py in _impl_call(self, *args, **kwargs)
    866         self._transform_program.set_classical_component(self, args, kwargs)
    867 
--> 868         res = execute(
    869             (tape,),
    870             device=self.device,

/usr/local/lib/python3.12/dist-packages/pennylane/workflow/execution.py in execute(tapes, device, diff_method, interface, grad_on_execution, cache, cachesize, max_diff, device_vjp, postselect_mode, mcm_method, gradient_kwargs, transform_program, executor_backend)
    236     assert not outer_transform.is_informative, "should only contain device preprocessing"
    237 
--> 238     results = run(tapes, device, config, inner_transform)
    239     return user_post_processing(outer_post_processing(results))

/usr/local/lib/python3.12/dist-packages/pennylane/workflow/run.py in run(tapes, device, config, inner_transform_program)
    296     )
    297     if no_interface_boundary_required:
--> 298         results = inner_execute(tapes)
    299         return results
    300 

/usr/local/lib/python3.12/dist-packages/pennylane/workflow/run.py in inner_execute(tapes)
    261 
    262         if transformed_tapes:
--> 263             results = device.execute(transformed_tapes, execution_config=execution_config)
    264         else:
    265             results = ()

/usr/local/lib/python3.12/dist-packages/pennylane/devices/modifiers/simulator_tracking.py in execute(self, circuits, execution_config)
     26     def execute(self, circuits, execution_config: ExecutionConfig | None = None):
     27 
---> 28         results = untracked_execute(self, circuits, execution_config)
     29         if isinstance(circuits, QuantumScript):
     30             batch = (circuits,)

/usr/local/lib/python3.12/dist-packages/pennylane/devices/modifiers/single_tape_support.py in execute(self, circuits, execution_config)
     28             is_single_circuit = True
     29             circuits = (circuits,)
---> 30         results = batch_execute(self, circuits, execution_config)
     31         return results[0] if is_single_circuit else results
     32 

/usr/local/lib/python3.12/dist-packages/pennylane/logging/decorators.py in wrapper_entry(*args, **kwargs)
     59                 **_debug_log_kwargs,
     60             )
---> 61         return func(*args, **kwargs)
     62 
     63     @wraps(func)

/usr/local/lib/python3.12/dist-packages/pennylane/devices/default_qubit.py in execute(self, circuits, execution_config)
    821 
    822         if max_workers is None:
--> 823             return tuple(
    824                 _simulate_wrapper(
    825                     c,

/usr/local/lib/python3.12/dist-packages/pennylane/devices/default_qubit.py in <genexpr>(.0)
    822         if max_workers is None:
    823             return tuple(
--> 824                 _simulate_wrapper(
    825                     c,
    826                     {

/usr/local/lib/python3.12/dist-packages/pennylane/devices/default_qubit.py in _simulate_wrapper(circuit, kwargs)
   1187 
   1188 def _simulate_wrapper(circuit, kwargs):
-> 1189     return simulate(circuit, **kwargs)
   1190 
   1191 

/usr/local/lib/python3.12/dist-packages/pennylane/logging/decorators.py in wrapper_entry(*args, **kwargs)
     59                 **_debug_log_kwargs,
     60             )
---> 61         return func(*args, **kwargs)
     62 
     63     @wraps(func)

/usr/local/lib/python3.12/dist-packages/pennylane/devices/qubit/simulate.py in simulate(circuit, debugger, state_cache, **execution_kwargs)
    368 
    369     ops_key, meas_key = jax_random_split(prng_key)
--> 370     state, is_state_batched = get_final_state(
    371         circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs
    372     )

/usr/local/lib/python3.12/dist-packages/pennylane/logging/decorators.py in wrapper_entry(*args, **kwargs)
     59                 **_debug_log_kwargs,
     60             )
---> 61         return func(*args, **kwargs)
     62 
     63     @wraps(func)

/usr/local/lib/python3.12/dist-packages/pennylane/devices/qubit/simulate.py in get_final_state(circuit, debugger, **execution_kwargs)
    199         if isinstance(op, MidMeasureMP):
    200             prng_key, key = jax_random_split(prng_key)
--> 201         state = apply_operation(
    202             op,
    203             state,

/usr/lib/python3.12/functools.py in wrapper(*args, **kw)
    910                             '1 positional argument')
    911 
--> 912         return dispatch(args[0].__class__)(*args, **kw)
    913 
    914     funcname = getattr(func, '__name__', 'singledispatch function')

/usr/local/lib/python3.12/dist-packages/pennylane/devices/qubit/apply_operation.py in apply_operation(op, state, is_state_batched, debugger, **_)
    235 
    236     """
--> 237     return _apply_operation_default(op, state, is_state_batched, debugger)
    238 
    239 

/usr/local/lib/python3.12/dist-packages/pennylane/devices/qubit/apply_operation.py in _apply_operation_default(op, state, is_state_batched, debugger)
    261         and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
    262     ) or (op.batch_size and is_state_batched):
--> 263         return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
    264     return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)
    265 

/usr/local/lib/python3.12/dist-packages/pennylane/devices/qubit/apply_operation.py in apply_operation_einsum(op, state, is_state_batched)
     82         mat = math.cast_like(op.matrix(), state)
     83     else:
---> 84         mat = op.matrix() + 0j
     85 
     86     total_indices = len(state.shape) - is_state_batched

/usr/local/lib/python3.12/dist-packages/pennylane/operation.py in matrix(self, wire_order)
    829             tensor_like: matrix representation
    830         """
--> 831         canonical_matrix = self.compute_matrix(*self.parameters, **self.hyperparameters)
    832 
    833         if (

/usr/local/lib/python3.12/dist-packages/pennylane/ops/qubit/parametric_ops_single_qubit.py in compute_matrix(theta)
    463 
    464         diags = qml.math.exp(qml.math.outer(arg, signs))
--> 465         return diags[:, :, np.newaxis] * qml.math.cast_like(qml.math.eye(2, like=diags), diags)
    466 
    467     @staticmethod

/usr/local/lib/python3.12/dist-packages/pennylane/math/utils.py in cast_like(tensor1, tensor2)
    289         dtype = ar.to_numpy(tensor2._value).dtype.type  # pylint: disable=protected-access
    290     elif not is_abstract(tensor2):
--> 291         dtype = ar.to_numpy(tensor2).dtype.type
    292     else:
    293         dtype = tensor2.dtype

/usr/local/lib/python3.12/dist-packages/autoray/autoray.py in to_numpy(x)
   1130 def to_numpy(x):
   1131     """Get a numpy version of array ``x``."""
-> 1132     return do("to_numpy", x)
   1133 
   1134 

/usr/local/lib/python3.12/dist-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
     79     backend = _choose_backend(fn, args, kwargs, like=like)
     80     func = get_lib_fn(backend, fn)
---> 81     return func(*args, **kwargs)
     82 
     83 

/usr/local/lib/python3.12/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_torch(x)
    639         x = x.resolve_conj()
    640 
--> 641     return x.detach().cpu().numpy()
    642 
    643 

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

System information

Name: pennylane
Version: 0.43.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: 
Author: 
Author-email: 
License: 
Location: /usr/local/lib/python3.12/dist-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: pennylane_lightning

Platform info:           Linux-6.6.105+-x86_64-with-glibc2.35
Python version:          3.12.12
Numpy version:           2.0.2
Scipy version:           1.16.3
JAX version:             0.7.2
Installed devices:
- lightning.qubit (pennylane_lightning-0.43.0)
- default.clifford (pennylane-0.43.1)
- default.gaussian (pennylane-0.43.1)
- default.mixed (pennylane-0.43.1)
- default.qubit (pennylane-0.43.1)
- default.qutrit (pennylane-0.43.1)
- default.qutrit.mixed (pennylane-0.43.1)
- default.tensor (pennylane-0.43.1)
- null.qubit (pennylane-0.43.1)
- reference.qubit (pennylane-0.43.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Metadata

Assignees

No one assigned

    Labels

    bug 🐛Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests