Skip to content

qml.sample() fails with OutDBIdx shape canonicalization error in dynamic one-shot context #1949

@rniczh

Description

@rniczh

Context

When using @qjit with finite shots, qml.sample() measurements fail with a shape canonicalization error involving OutDBIdx references, while qml.expval() measurements work correctly. This occurs in the dynamic_one_shot transformation, which is applied automatically when using finite shots (the reproduction below also requests it explicitly via mcm_method='one-shot'). It happens only when `wires` is not set on the device passed to qml.qnode.

Reproduction

❌ Failing Case (qml.sample())

# Minimal reproducer: qml.sample() under @qjit with finite shots and the
# one-shot mid-circuit-measurement method, on a device created WITHOUT `wires`.
import pennylane as qml
from catalyst import qjit

backend = "lightning.qubit"

# NOTE: the device gets no `wires=` argument and the circuit addresses wire 3;
# per the description, the failure only occurs when wires are not set.
# Per the traceback below, the TypeError is raised while applying @qjit
# (QJIT.__init__ -> aot_compile), i.e. before circuit() is ever called.
@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    # Samples over all wires. The traced result shape is (1, OutDBIdx(val=0))
    # per the traceback, which jnp.zeros() inside catalyst.vmap cannot accept.
    return qml.sample()

# Not reached in the failing case: compilation already fails at the decorator.
circuit()

✅ Working Case (qml.expval())

# Same setup as the failing case (reuses `qml`, `qjit` and `backend` from the
# failing snippet above), differing only in the measurement returned.
@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    # Per "Why expval() works", no OutDBIdx appears in this result's shape,
    # so jnp.zeros() in the one-shot vmap path can build the initial results.
    return qml.expval(qml.PauliZ(0))  # This works

circuit()

Why expval() works

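  • No OutDBIdx references are created: the result shape in the expval(...) case is a ShapedType with integer dimensions, which jnp.zeros() can handle. In the qml.sample() case, the shape contains an OutDBIdx entry in place of an integer dimension (e.g. (1, OutDBIdx(val=0)) in the traceback below). jnp.zeros() rejects that shape when catalyst.vmap builds its initial results.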

Full Stack Trace

Traceback (most recent call last):
  File "/path/to/test.py", line 7, in <module>
    @qjit
     ^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 565, in __init__
    self.aot_compile()
    ~~~~~~~~~~~~~~~~^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 618, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                                                              ~~~~~~~~~~~~^
        self.user_sig or ()
        ^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 759, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
                                        ~~~~~~~~~~~~~~^
        self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 613, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 749, in closure
    return QFunc.__call__(
           ~~~~~~~~~~~~~~^
        qnode,
        ^^^^^^
        *args,
        ^^^^^^
        **dict(params, **kwargs),
        ^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 143, in __call__
    return Function(dynamic_one_shot(self, mcm_config=mcm_config))(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 181, in __call__
    jaxpr, _, out_tree = make_jaxpr2(
                         ~~~~~~~~~~~~
        self.fn,
        ~~~~~~~~
        debug_info=kwargs.pop("debug_info", jdb("Function", self.fn, args, kwargs)),
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*args, **kwargs)
    ~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 286, in one_shot_wrapper
    results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap)
  File "/path/to/catalyst/frontend/catalyst/api_extensions/function_maps.py", line 235, in __call__
    init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape, _ in shapes]
                        ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 82, in zeros
    shape = canonicalize_shape(shape)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 45, in canonicalize_shape
    return core.canonicalize_shape(shape, context)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/core.py", line 1864, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of integer scalars, got (1, OutDBIdx(val=0))

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions