Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion keras/src/ops/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def __new__(cls, *args, **kwargs):
if backend.backend() == "jax" and is_nnx_enabled():
from flax import nnx

vars(instance)["_object__state"] = nnx.object.ObjectState()
if "_object__state" in vars(instance):
vars(instance)["_object__state"] = nnx.object.ObjectState()
else:
vars(instance)["_pytree__state"] = nnx.pytreelib.PytreeState()
# Generate a config to be returned by default by `get_config()`.
arg_names = inspect.getfullargspec(cls.__init__).args
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
Expand Down
Loading