[RLlib] MetricsLogger: Fix get/set_state to handle tensors in self.values. #53514
Signed-off-by: sven1977 <[email protected]>
Pull Request Overview
This PR updates Stats.get_state and Stats.from_state to correctly handle tensor values in self.values, ensuring state serialization avoids returning raw tensors.
- Introduces convert_to_numpy in get_state to serialize any tensors to NumPy.
- Adds a _could_be_tensor flag to track potential tensor insertions and skips repopulating values in from_state when tensors were present.
- Refactors check_value to detect zero-dimensional tensors and enforce scalar reduction behavior.
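The behavior summarized above can be sketched with a toy class. This is only an illustration under stated assumptions, not RLlib's Stats implementation: `MiniStats`, `FakeTensor`, and `to_serializable` are hypothetical names, plain Python floats stand in for the NumPy conversion done by convert_to_numpy, and restoring the flag in from_state is a guess about the intended behavior.

```python
from collections import deque
from typing import Any, Dict


class FakeTensor:
    """Hypothetical stand-in for a framework tensor (e.g. a 0-dim loss value)."""

    def __init__(self, value: float):
        self._value = float(value)

    def item(self) -> float:
        return self._value


def to_serializable(value: Any) -> Any:
    # Stand-in for convert_to_numpy: unwrap tensors, pass everything else through.
    return value.item() if isinstance(value, FakeTensor) else value


class MiniStats:
    """Toy model of the described get_state/from_state behavior (assumption)."""

    def __init__(self, window: int = 3):
        self.values = deque(maxlen=window)
        self._could_be_tensor = False

    def push(self, value: Any) -> None:
        # Remember that a tensor has been logged at least once.
        if isinstance(value, FakeTensor):
            self._could_be_tensor = True
        self.values.append(value)

    def get_state(self) -> Dict[str, Any]:
        # The serialized state never contains raw tensors.
        return {
            "values": [to_serializable(v) for v in self.values],
            "window": self.values.maxlen,
            "_could_be_tensor": self._could_be_tensor,
        }

    @staticmethod
    def from_state(state: Dict[str, Any]) -> "MiniStats":
        stats = MiniStats(window=state["window"])
        # Start with empty values if tensors may have been present.
        values = [] if state.get("_could_be_tensor") else state["values"]
        stats.values.extend(values)
        # Assumption: the flag itself is carried over to the restored object.
        stats._could_be_tensor = bool(state.get("_could_be_tensor", False))
        return stats
```

With this sketch, a `MiniStats` that only ever saw plain floats round-trips its values, while one that saw a `FakeTensor` comes back empty after `from_state`.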
Comments suppressed due to low confidence (3)
rllib/utils/metrics/stats.py:187
- [nitpick] The private flag _could_be_tensor may be clearer as _may_have_tensors or _has_tensors to reflect its boolean purpose more directly.
self._could_be_tensor = False
rllib/utils/metrics/stats.py:639
- There are no explicit unit tests covering tensor serialization in get_state and recovery in from_state. Adding tests for pushing Torch/TF tensors and ensuring correct NumPy output will prevent regressions.
def get_state(self) -> Dict[str, Any]:
rllib/utils/metrics/stats.py:642
- convert_to_numpy may not handle deque inputs uniformly. It could be safer to convert self.values to a list first, e.g., convert_to_numpy(list(self.values)).
"values": convert_to_numpy(self.values),
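The concern about deque inputs can be illustrated with a hypothetical recursive converter. `convert_leaves` and `Scalar` below are made-up names for this sketch; whether RLlib's convert_to_numpy actually traverses a deque is not stated here, so the code only shows what would happen if a converter recursed into list, tuple, and dict but treated everything else as a leaf.

```python
from collections import deque
from typing import Any


class Scalar:
    """Hypothetical tensor-like object exposing .item()."""

    def __init__(self, v: float):
        self.v = v

    def item(self) -> float:
        return self.v


def convert_leaves(obj: Any) -> Any:
    """Hypothetical converter: recurses into dict/list/tuple only."""
    if isinstance(obj, dict):
        return {k: convert_leaves(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(convert_leaves(v) for v in obj)
    # Leaves: unwrap anything tensor-like, pass other objects through unchanged.
    return obj.item() if hasattr(obj, "item") else obj


window = deque([Scalar(1.0), Scalar(2.0)], maxlen=2)
untouched = convert_leaves(window)        # the deque is treated as a single leaf
converted = convert_leaves(list(window))  # the list is converted element-wise
```

Under that assumption, `untouched` is still a deque holding `Scalar` objects, whereas `converted` is a plain list of floats, which is why casting to a list first is the defensive choice.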
rllib/utils/metrics/stats.py
Outdated
@@ -350,17 +358,20 @@ class for details on the reduction logic applied to the values list, based on
"""
len_before_reduce = len(self)
if self._has_new_values:
- # Only calculate and update history if there were new values pushed since last reduce
+ # Only calculate and update history if there were new values pushed since'
There's a stray apostrophe after since. Removing it will clean up the comment.
- # Only calculate and update history if there were new values pushed since'
+ # Only calculate and update history if there were new values pushed since
Copilot uses AI. Check for mistakes.
}
if self._throughput_stats is not None:
    state["throughput_stats"] = self._throughput_stats.get_state()
return state

@staticmethod
- def from_state(state: Dict[str, Any], throughputs=False) -> "Stats":
+ def from_state(state: Dict[str, Any]) -> "Stats":
The throughputs parameter was removed from from_state, which may break existing callers relying on that signature. Consider retaining a backward-compatible overload or documenting the change.
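One way to keep old call sites working, sketched here as an assumption rather than as what the PR does, is to accept the removed argument, ignore it, and emit a deprecation warning. `CompatStats` is a hypothetical placeholder class, and the warning text is illustrative only.

```python
import warnings
from typing import Any, Dict, List, Optional


class CompatStats:
    """Hypothetical placeholder class; not RLlib's Stats."""

    def __init__(self, values: List[Any]):
        self.values = list(values)

    @staticmethod
    def from_state(
        state: Dict[str, Any], throughputs: Optional[bool] = None
    ) -> "CompatStats":
        # Accept the removed argument so existing callers do not break,
        # but tell them it no longer has any effect.
        if throughputs is not None:
            warnings.warn(
                "The `throughputs` argument is ignored and may be removed.",
                DeprecationWarning,
                stacklevel=2,
            )
        return CompatStats(state["values"])
```

Callers that pass `throughputs=False` positionally or by keyword still get a valid object, while new callers can drop the argument entirely.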
# whether we are on a supported device).
values = state["values"]
if "_could_be_tensor" in state and state["_could_be_tensor"]:
    values = []
- If we do this, we should state it as a known limitation in metrics-logger.rst.
- Alternatively, can we not check whether we are on a supported device and keep track of GPU tensors?
- That's true. This is a limitation for now, albeit a very small one, since tensor logging is normally only done for loss metrics (loss, entropy, etc.), and these are very often ephemeral values where it's not a problem at all to just start fresh after loading a checkpoint (window=1 anyway?).
- Yeah, but then we would have to store the original device as well, which quickly becomes super messy (when you transfer a checkpoint from one cluster type (GPUs) to another (no GPUs?)).
- What if window!=1? It's a user-facing class.
- Roger! Yeah, there is some ugly complexity here.
rllib/utils/metrics/stats.py
Outdated
@@ -183,29 +184,33 @@ def __init__(
self.values: Union[List, deque.Deque] = None
self._set_values(force_list(init_values))

+ self._could_be_tensor = False
From reading the code, could self._could_be_tensor be renamed to self._is_tensor?
Left some comments but - no blockers.
…ics_logger_get_set_state_handling_tensors
….values`. (#53514) Signed-off-by: elliot-barn <[email protected]>
Why are these changes needed?
MetricsLogger: Fix get/set_state to handle tensors in self.values.
- … self._may_have_tensors flag to be True.
- … self._may_have_tensors flag and does NOT populate the self.values field, if it's True.
Related issue number
Closes #53467
Checks
- … git commit -s) in this PR.
- … scripts/format.sh to lint the changes in this PR.
- … method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.