Skip to content

Conversation

@andrijapau
Copy link
Contributor

@andrijapau andrijapau commented Aug 5, 2025

Context

Current autograph implementation in PL only supports converting functions that use the jax.numpy syntax arr = arr.at[i].set(x) to update array indices.

We would like to be able to do arr[i] = x instead of raising an error.

Description of the Change

Updates ag_primitives to allow the transformer to update the indices accordingly.

Benefits

We now have functional code,

import pennylane as qml

qml.capture.enable()

from pennylane.capture import make_plxpr

import jax
import jax.numpy as jnp

def f():
    
    my_list = jnp.empty(2, dtype=int)
    for i in range(2):
        my_list[i] = i
    return my_list

>>> plxpr = make_plxpr(f)()
>>> jax.core.eval_jaxpr(plxpr.jaxpr, [])
[Array([0, 1], dtype=int32)]

Possible Drawbacks

None identified.

[sc-76133]

@andrijapau andrijapau added the WIP 🚧 Work-in-progress label Aug 5, 2025
@github-actions

This comment was marked as resolved.

@andrijapau andrijapau marked this pull request as ready for review August 8, 2025 20:28
@codecov
Copy link

codecov bot commented Aug 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.68%. Comparing base (bd77b0e) to head (80393b3).
⚠️ Report is 9 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #8027   +/-   ##
=======================================
  Coverage   99.68%   99.68%           
=======================================
  Files         545      545           
  Lines       56137    56143    +6     
=======================================
+ Hits        55961    55967    +6     
  Misses        176      176           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@andrijapau andrijapau changed the title [WIP] add support for item assignment with autograph module feat: autograph: add support for item assignment for jax.numpy arrays Aug 11, 2025
@andrijapau andrijapau changed the title feat: autograph: add support for item assignment for jax.numpy arrays feat: autograph: add support for item assignment with jax.numpy arrays Aug 11, 2025
@andrijapau andrijapau removed the WIP 🚧 Work-in-progress label Aug 18, 2025
@andrijapau andrijapau changed the title feat: autograph: add support for item assignment with jax.numpy arrays feat: autograph: add support for pythonic index assignment w/ jax.numpy arrays Aug 18, 2025
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing

Copy link
Contributor

@comp-phys-marc comp-phys-marc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor unblocking comments only.

@andrijapau andrijapau added this pull request to the merge queue Aug 21, 2025
Merged via the queue into master with commit fbaca7a Aug 21, 2025
52 checks passed
@andrijapau andrijapau deleted the autograph/support-item-assignment branch August 21, 2025 14:37
github-merge-queue bot pushed a commit that referenced this pull request Aug 25, 2025
… `-=`, ...) (#8076)

## Context

This is an in-direct follow-up PR to
#8027.

Now that we have index assignment. We wish to add in-place array updates
using pythonic notation.

## Description of the Change

This PR directly mirrors the changes made to Catalyst
[here](PennyLaneAI/catalyst#1143). Basically,
the gist is that we added the
`SingleIndexArrayOperatorUpdateTransformer` class that is able to use
autograph to properly convert any detected array update in our AST.

## Benefits

This allows us to do things like,

```python
import pennylane as qml

qml.capture.enable()

from pennylane.capture import make_plxpr

import jax
import jax.numpy as jnp

def f():
    # setup my_list = [0, 1]
    my_list = jnp.empty(2, dtype=int)
    for i in range(2):
        my_list[i] = i
    # increment second element by ten
    my_list[1] += 10
    return my_list

>>> plxpr = make_plxpr(f)()
>>> jax.core.eval_jaxpr(plxpr.jaxpr, [])
[Array([0, 11], dtype=int32)]
```

## Possible Drawbacks

None identified.

[sc-76134]

---------

Co-authored-by: Pietropaolo Frisoni <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants