Skip to content

Conversation

divyashreepathihalli
Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli commented May 5, 2025

The PR integrates NNX into JAX backend!

The following snippet shows how you would enable the nnx backend

import os
os.environ["KERAS_BACKEND"]="jax"
os.environ["KERAS_NNX_ENABLED"]="true"
import keras

Demo colab here : https://colab.sandbox.google.com/drive/1mK-4qbce2HGRIkcb4v5n4niWGDezL_6n#scrollTo=m-ZH9Mpnphfz
Added a github workflow action for nnx backend. Note this will fail - because this needs a new release of flax to work.

@divyashreepathihalli divyashreepathihalli marked this pull request as draft May 5, 2025 23:05
@codecov-commenter
Copy link

codecov-commenter commented May 5, 2025

Codecov Report

Attention: Patch coverage is 29.13907% with 107 lines in your changes missing coverage. Please review.

Project coverage is 82.72%. Comparing base (3554825) to head (0426eee).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/core.py 6.49% 71 Missing and 1 partial ⚠️
keras/src/backend/config.py 52.38% 8 Missing and 2 partials ⚠️
keras/src/ops/function.py 50.00% 6 Missing and 2 partials ⚠️
keras/src/backend/jax/trainer.py 66.66% 2 Missing and 2 partials ⚠️
keras/src/layers/layer.py 50.00% 2 Missing and 2 partials ⚠️
keras/src/backend/jax/layer.py 50.00% 2 Missing and 1 partial ⚠️
keras/src/ops/operation.py 25.00% 2 Missing and 1 partial ⚠️
keras/src/backend/common/variables.py 0.00% 1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/config/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21252      +/-   ##
==========================================
- Coverage   82.87%   82.72%   -0.16%     
==========================================
  Files         567      567              
  Lines       56073    56214     +141     
  Branches     8756     8786      +30     
==========================================
+ Hits        46470    46501      +31     
- Misses       7459     7556      +97     
- Partials     2144     2157      +13     
Flag Coverage Δ
keras 82.52% <29.13%> (-0.16%) ⬇️
keras-jax 63.92% <29.13%> (-0.11%) ⬇️
keras-numpy 58.41% <21.19%> (-0.11%) ⬇️
keras-openvino 34.56% <19.86%> (-0.05%) ⬇️
keras-tensorflow 64.34% <21.85%> (-0.12%) ⬇️
keras-torch 63.97% <21.85%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

x = ops.ones(3)

@jax.jit
@nnx.jit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the integration prevent the use of jax.jit with Keras layers?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! it would only work with nnx.jit for now ( They might be working on adding support for jax.jit)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added nnx as a opt in with this flag - os.environ["KERAS_NNX_ENABLED"]

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some final comments on the config and configuration part. All good on my end once fixed.


if "KERAS_NNX_ENABLED" in os.environ:
env_val = os.environ["KERAS_NNX_ENABLED"].lower()
if env_val:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems bad, don't we want to check if val.lower() = "true" or "1" here? Otherwise KERAS_NNX_ENABLED=False will enable nnx.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jul 24, 2025
if self._value is not None:
# If NNX is enabled, it's possible the variable was already
# initialized by a concrete call. In this case,
# _deferred_initialize becomes a no-op for this variable.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this comment is incorrect. It raises an error in this case, so it's not a no-op.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the comment and also if nnx is enabled, I added a return instead of retuning an error.

@google-ml-butler google-ml-butler bot added kokoro:force-run and removed ready to pull Ready to be merged into the codebase labels Jul 24, 2025

# Apply basic configs that don't cause circular import
set_floatx(_floatx)
_NNX_ENABLED = _nnx_enabled_config
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we literally call set_nnx_enabled(_nnx_enabled_config) (just going for consistency)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this leads to circular import error

    from keras.src.utils import dataset_utils
keras/src/utils/dataset_utils.py:9: in <module>
    from keras.src import tree
keras/src/tree/__init__.py:1: in <module>
    from keras.src.tree.tree_api import assert_same_paths
keras/src/tree/tree_api.py:8: in <module>
    from keras.src.tree import optree_impl as tree_impl
keras/src/tree/optree_impl.py:4: in <module>
    from keras.src.backend.config import backend
E   ImportError: cannot import name 'backend' from partially initialized module 'keras.src.backend.config' (most likely due to a circular import) (/workspaces/keras/keras/src/backend/config.py)

So decided to revert it. 2 lines below backend is also set in the same way _BACKEND = _backend - in a way it is still consistent?

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jul 25, 2025
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Jul 25, 2025
@divyashreepathihalli divyashreepathihalli merged commit c9383e2 into keras-team:master Jul 25, 2025
8 of 11 checks passed
@hertschuh hertschuh mentioned this pull request Aug 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants