-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Keras <> NNX integration #21252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Keras <> NNX integration #21252
Changes from 110 commits
8e1c008
7159709
e378cfb
8701fc7
b599727
4f7b3b8
0234e27
b87c4f9
91b9a73
aee1789
141487f
5ccc31e
4e35416
dd9c77d
48983c6
b22f9ef
4dbffa6
627e581
f58ef60
c2b73b7
a662f5e
396f973
30e971d
968d804
b99571a
ed0bc00
460e0e2
34f27e9
427ff82
c4ee191
44414dc
0953d99
6d54a7e
64adbaf
6454800
001f112
5f26958
782c653
561f70a
e7caa03
8e3f460
d70d51c
38dbd4b
1c60c5e
74835fd
6f11c0c
c05166e
8582c7e
9471e4e
297775a
f01cc0d
6810848
dc79329
f280dd0
75f9cc8
68261d4
1e09246
8a142a1
c7b2347
99d4307
d544a0b
3b8d90b
0e0fcd1
bd66ec8
8637c18
97d7371
b20321e
46818b2
9064df0
c127c2b
b81030b
e9dbca4
94ea1a2
f0b10ef
d18dd33
51aa455
a02f410
adca8da
d8ca752
95e67e0
a5741be
02d607b
5500382
68d0b68
12eb2a0
a260cb4
23cdfd7
c79a57f
57b42cb
567f120
46db09b
c9b87b1
d4b5afa
34fbeed
3683dc8
53240f9
896ffa0
587bae7
3b4713c
772929c
f84cc1e
05d0119
7cceaae
3af3c30
7494af6
403681b
7a2ddd8
534b975
96f65af
8f2798e
3635148
eccd843
67f90cc
90f3823
a674c03
389b19b
0426eee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
"floatx": "float32", | ||
"epsilon": 1e-07, | ||
"backend": "jax", | ||
"image_data_format": "channels_last", | ||
"nnx_enabled": true | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,6 +211,9 @@ def __init__( | |
|
||
def _deferred_initialize(self): | ||
if self._value is not None: | ||
# If NNX is enabled, it's possible the variable was already | ||
# initialized by a concrete call. In this case, | ||
# _deferred_initialize becomes a no-op for this variable. | ||
|
||
raise ValueError(f"Variable {self.path} is already initialized.") | ||
|
||
if in_stateless_scope(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,9 @@ | |
# Default backend: TensorFlow. | ||
_BACKEND = "tensorflow" | ||
|
||
# Whether NNX is enabled. | ||
_NNX_ENABLED = False | ||
|
||
# Cap run duration for debugging. | ||
_MAX_EPOCHS = None | ||
_MAX_STEPS_PER_EPOCH = None | ||
|
@@ -230,6 +233,32 @@ def is_flash_attention_enabled(): | |
return global_state.get_global_attribute("flash_attention", default=None) | ||
|
||
|
||
@keras_export("keras.config.is_nnx_enabled") | ||
def is_nnx_enabled(): | ||
"""Checks whether NNX specific features are enabled for the JAX backend. | ||
|
||
Returns: | ||
bool: `True` if NNX backend features are enabled, `False` otherwise. | ||
Defaults to `False`. | ||
""" | ||
return _NNX_ENABLED | ||
|
||
|
||
def set_nnx_enabled(value): | ||
global _NNX_ENABLED | ||
from keras.src.backend.common import global_state | ||
|
||
_NNX_ENABLED = bool(value) | ||
if _NNX_ENABLED: | ||
try: | ||
from flax import nnx # noqa F401 | ||
except ImportError: | ||
raise ImportError( | ||
"To use NNX with the JAX backend, you must install `flax`." | ||
) | ||
global_state.set_global_attribute("nnx_enabled", bool(value)) | ||
|
||
|
||
def standardize_data_format(data_format): | ||
if data_format is None: | ||
return image_data_format() | ||
|
@@ -261,6 +290,30 @@ def keras_home(): | |
|
||
# Attempt to read Keras config file. | ||
_config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json")) | ||
|
||
# Save config file, if possible. | ||
if not os.path.exists(_KERAS_DIR): | ||
try: | ||
os.makedirs(_KERAS_DIR) | ||
except OSError: | ||
# Except permission denied. | ||
pass | ||
|
||
if not os.path.exists(_config_path): | ||
_config_to_save = { | ||
"floatx": _FLOATX, | ||
"epsilon": _EPSILON, | ||
"backend": _BACKEND, | ||
"image_data_format": _IMAGE_DATA_FORMAT, | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
"nnx_enabled": is_nnx_enabled(), | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
try: | ||
with open(_config_path, "w") as f: | ||
f.write(json.dumps(_config_to_save, indent=4)) | ||
except IOError: | ||
# Except permission denied. | ||
pass | ||
|
||
if os.path.exists(_config_path): | ||
try: | ||
with open(_config_path) as f: | ||
|
@@ -274,36 +327,17 @@ def keras_home(): | |
_backend = _config.get("backend", _BACKEND) | ||
_image_data_format = _config.get("image_data_format", image_data_format()) | ||
assert _image_data_format in {"channels_last", "channels_first"} | ||
_nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED) | ||
if isinstance(_nnx_enabled_config, bool): | ||
_NNX_ENABLED = _nnx_enabled_config | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# else: ignore non-bool values for nnx_enabled | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# Apply basic configs that don't cause circular import | ||
set_floatx(_floatx) | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
set_epsilon(_epsilon) | ||
set_image_data_format(_image_data_format) | ||
_BACKEND = _backend | ||
|
||
# Save config file, if possible. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you need to move this block? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to resolve some circular import errors |
||
if not os.path.exists(_KERAS_DIR): | ||
try: | ||
os.makedirs(_KERAS_DIR) | ||
except OSError: | ||
# Except permission denied and potential race conditions | ||
# in multi-threaded environments. | ||
pass | ||
|
||
if not os.path.exists(_config_path): | ||
_config = { | ||
"floatx": floatx(), | ||
"epsilon": epsilon(), | ||
"backend": _BACKEND, | ||
"image_data_format": image_data_format(), | ||
} | ||
try: | ||
with open(_config_path, "w") as f: | ||
f.write(json.dumps(_config, indent=4)) | ||
except IOError: | ||
# Except permission denied. | ||
pass | ||
|
||
# Set backend based on KERAS_BACKEND flag, if applicable. | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if "KERAS_BACKEND" in os.environ: | ||
_backend = os.environ["KERAS_BACKEND"] | ||
if _backend: | ||
|
@@ -313,6 +347,7 @@ def keras_home(): | |
if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ: | ||
_MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"]) | ||
|
||
|
||
if _BACKEND != "tensorflow": | ||
# If we are not running on the tensorflow backend, we should stop tensorflow | ||
# from using all available GPU memory. See | ||
|
@@ -403,3 +438,13 @@ def max_steps_per_epoch(): | |
`None`, no limit is applied. | ||
""" | ||
return _MAX_STEPS_PER_EPOCH | ||
|
||
|
||
if "KERAS_NNX_ENABLED" in os.environ: | ||
env_val = os.environ["KERAS_NNX_ENABLED"].lower() | ||
if env_val: | ||
|
||
_NNX_ENABLED = True | ||
else: | ||
_NNX_ENABLED = False | ||
|
||
set_nnx_enabled(_NNX_ENABLED) |
Uh oh!
There was an error while loading. Please reload this page.