Skip to content

Commit 26c7d88

Browse files
committed
get rid of try and except around make jaxpr
1 parent 03fac7f commit 26c7d88

File tree

1 file changed

+3
-15
lines changed

1 file changed

+3
-15
lines changed

pennylane/workflow/_capture_qnode.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@
117117

118118
import pennylane as qml
119119
from pennylane.capture import FlatFn, QmlPrimitive
120-
from pennylane.exceptions import CaptureError
121120
from pennylane.logging import debug_logger
122121
from pennylane.typing import TensorLike
123122

@@ -459,20 +458,9 @@ def _extract_qfunc_jaxpr(qnode, abstracted_axes, *args, **kwargs):
459458
qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func
460459
flat_fn = FlatFn(qfunc)
461460

462-
try:
463-
qfunc_jaxpr = jax.make_jaxpr(
464-
flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
465-
)(*args)
466-
except (
467-
jax.errors.TracerArrayConversionError,
468-
jax.errors.TracerIntegerConversionError,
469-
jax.errors.TracerBoolConversionError,
470-
) as exc:
471-
raise CaptureError(
472-
"Autograph must be used when Python control flow is dependent on a dynamic "
473-
"variable (a function input). Please ensure that autograph=True or use native control "
474-
"flow functions like for_loop, while_loop, etc."
475-
) from exc
461+
qfunc_jaxpr = jax.make_jaxpr(
462+
flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
463+
)(*args)
476464

477465
assert flat_fn.out_tree is not None, "out_tree should be set by call to flat_fn"
478466
return qfunc_jaxpr, flat_fn.out_tree

0 commit comments

Comments
 (0)