Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions keras/src/backend/jax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,7 @@ def fit(

# Override with model metrics instead of last step logs if
# needed.
# The jax spmd_mode is need for multi-process context, since the
# metrics values are replicated, and we don't want to do a all
# gather, and only need the local copy of the value.
with jax.spmd_mode("allow_all"):
epoch_logs = dict(self._get_metrics_result_or_logs(logs))
epoch_logs = dict(self._get_metrics_result_or_logs(logs))

# Run validation.
if validation_data is not None and self._should_eval(
Expand Down Expand Up @@ -605,11 +601,7 @@ def evaluate(
# Reattach state back to model (if not already done by a callback).
self.jax_state_sync()

# The jax spmd_mode is need for multi-process context, since the
# metrics values are replicated, and we don't want to do a all
# gather, and only need the local copy of the value.
with jax.spmd_mode("allow_all"):
logs = self._get_metrics_result_or_logs(logs)
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
if not use_cached_eval_dataset:
Expand Down