-
Couldn't load subscription status.
- Fork 706
feat: autograph: add support for pythonic index assignment w/ jax.numpy arrays
#8027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This comment was marked as resolved.
This comment was marked as resolved.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #8027 +/- ##
=======================================
Coverage 99.68% 99.68%
=======================================
Files 545 545
Lines 56137 56143 +6
=======================================
+ Hits 55961 55967 +6
Misses 176 176 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
autograph modulejax.numpy arrays
jax.numpy arraysjax.numpy arrays
jax.numpy arraysjax.numpy arrays
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor unblocking comments only.
… `-=`, ...) (#8076) ## Context This is an in-direct follow-up PR to #8027. Now that we have index assignment. We wish to add in-place array updates using pythonic notation. ## Description of the Change This PR directly mirrors the changes made to Catalyst [here](PennyLaneAI/catalyst#1143). Basically, the gist is that we added the `SingleIndexArrayOperatorUpdateTransformer` class that is able to use autograph to properly convert any detected array update in our AST. ## Benefits This allows us to do things like, ```python import pennylane as qml qml.capture.enable() from pennylane.capture import make_plxpr import jax import jax.numpy as jnp def f(): # setup my_list = [0, 1] my_list = jnp.empty(2, dtype=int) for i in range(2): my_list[i] = i # increment second element by ten my_list[1] += 10 return my_list >>> plxpr = make_plxpr(f)() >>> jax.core.eval_jaxpr(plxpr.jaxpr, []) [Array([0, 11], dtype=int32)] ``` ## Possible Drawbacks None identified. [sc-76134] --------- Co-authored-by: Pietropaolo Frisoni <[email protected]>
Context
Current autograph implementation in PL only supports converting functions that use the
jax.numpysyntaxarr = arr.at[i].set(x)to update array indices.We would like to be able to do
arr[i] = xinstead of raising an error.Description of the Change
Updates
ag_primitivesto allow the transformer to update the indices accordingly.Benefits
We now have functional code,
Possible Drawbacks
None identified.
[sc-76133]