Commit caf3680

[RLlib] MetricsLogger + Stats overhaul. (#51639)
1 parent 5aa5462 commit caf3680

14 files changed (+2022, -717 lines)

doc/source/rllib/metrics-logger.rst

Lines changed: 39 additions & 10 deletions
@@ -181,15 +181,19 @@ to :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`:
 .. testcode::
 
     logger.log_value("some_items", value="a", reduce=None, clear_on_reduce=True)
-    logger.log_value("some_items", value="b")
-    logger.log_value("some_items", value="c")
-    logger.log_value("some_items", value="d")
+    logger.log_value("some_items", value="b", reduce=None, clear_on_reduce=True)
+    logger.log_value("some_items", value="c", reduce=None, clear_on_reduce=True)
+    logger.log_value("some_items", value="d", reduce=None, clear_on_reduce=True)
 
     logger.peek("some_items")  # expect a list: ["a", "b", "c", "d"]
 
     logger.reduce()
     logger.peek("some_items")  # expect an empty list: []
 
+You should pass additional arguments like ``reduce=None`` and ``clear_on_reduce=True`` to the
+:py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_value` method on each call.
+Otherwise, MetricsLogger will emit warnings to ensure that its behavior is always as expected.
+
 
 Logging a set of nested scalar values
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -238,9 +242,9 @@ log three consecutive image frames from a ``CartPole`` environment, do the follo
     env.reset()
     logger.log_value("some_images", value=env.render(), reduce=None, clear_on_reduce=True)
     env.step(0)
-    logger.log_value("some_images", value=env.render())
+    logger.log_value("some_images", value=env.render(), reduce=None, clear_on_reduce=True)
     env.step(1)
-    logger.log_value("some_images", value=env.render())
+    logger.log_value("some_images", value=env.render(), reduce=None, clear_on_reduce=True)
 
 Timers
 ~~~~~~
@@ -296,7 +300,7 @@ Set ``clear_on_reduce=False``, which is the default, if you want the count to ac
     logger = MetricsLogger()
 
     logger.log_value("my_counter", 50, reduce="sum", window=None)
-    logger.log_value("my_counter", 25)
+    logger.log_value("my_counter", 25, reduce="sum", window=None)
     logger.peek("my_counter")  # expect: 75
 
     # Even if your logger gets "reduced" from time to time, the counter keeps increasing
@@ -306,7 +310,7 @@ Set ``clear_on_reduce=False``, which is the default, if you want the count to ac
 
     # To clear the sum after each "reduce" event, set `clear_on_reduce=True`:
     logger.log_value("my_temp_counter", 50, reduce="sum", window=None, clear_on_reduce=True)
-    logger.log_value("my_temp_counter", 25)
+    logger.log_value("my_temp_counter", 25, reduce="sum", window=None, clear_on_reduce=True)
     logger.peek("my_counter")  # expect: 75
     logger.reduce()
     logger.peek("my_counter")  # expect: 0 (upon reduction, all values are cleared)
@@ -323,8 +327,7 @@ on each ``reduce()`` operation.
 The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` automatically compiles an extra key for each such metric, adding the suffix ``_throughput``
 to the original key and assigning it the value for the throughput per second.
 
-You can use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method with the call argument ``throughput=True``
-to access the throughput value. For example:
+You can use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method to access the throughput value by passing the ``throughput=True`` flag.
 
 .. testcode::
 
@@ -337,13 +340,39 @@ to access the throughput value. For example:
     logger.log_value("lifetime_count", 5, reduce="sum", with_throughput=True)
 
     # RLlib triggers a new throughput computation at each `reduce()` call
-    logger.reduce()
     time.sleep(1.0)
 
     # Expect the first call to return NaN because we don't have a proper start time for the time delta.
     # From the second call on, expect a value of roughly 5/sec.
     print(logger.peek("lifetime_count", throughput=True))
 
+    logger.log_value("lifetime_count", 5, reduce="sum", with_throughput=True)
+    # Expect the throughput to be roughly 10/sec now.
+    print(logger.peek("lifetime_count", throughput=True))
+
+    # You can also get a dict of all throughputs at once:
+    print(logger.peek(throughput=True))
+
+
+Measuring throughputs with MetricsLogger.log_time()
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can also use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_time` method to measure throughputs.
+
+.. testcode::
+
+    import time
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    for _ in range(3):
+        with logger.log_time("my_block_to_be_timed", with_throughput=True):
+            time.sleep(1.0)
+
+    # Expect the throughput to be roughly 1.0/sec.
+    print(logger.peek("my_block_to_be_timed", throughput=True))
+
 
 Example 1: How to use MetricsLogger in EnvRunner callbacks
 ----------------------------------------------------------
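The doc change above makes the call pattern explicit: reduction settings such as ``reduce=None`` and ``clear_on_reduce=True`` now accompany every ``log_value`` call for a given key, rather than only the first one. A minimal sketch of that pattern, not part of the diff (the key name is illustrative):

    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    logger = MetricsLogger()

    # Repeat the same settings on every call for the same key; mixing
    # settings across calls for one key is what the new warnings flag.
    for item in ["a", "b", "c", "d"]:
        logger.log_value("some_items", value=item, reduce=None, clear_on_reduce=True)

    print(logger.peek("some_items"))  # ["a", "b", "c", "d"]
    logger.reduce()
    print(logger.peek("some_items"))  # [] (values are cleared on reduce)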

rllib/BUILD

Lines changed: 21 additions & 0 deletions
@@ -2896,6 +2896,27 @@ py_test(
     ],
 )
 
+# Test metrics (metrics logger, stats)
+py_test(
+    name = "test_metrics_logger",
+    size = "small",
+    srcs = ["utils/metrics/tests/test_metrics_logger.py"],
+    tags = [
+        "team:rllib",
+        "utils",
+    ],
+)
+
+py_test(
+    name = "test_stats",
+    size = "small",
+    srcs = ["utils/metrics/tests/test_stats.py"],
+    tags = [
+        "team:rllib",
+        "utils",
+    ],
+)
+
 # @OldAPIStack
 py_test(
     name = "test_value_predictions",

rllib/algorithms/algorithm.py

Lines changed: 13 additions & 31 deletions
@@ -159,7 +159,6 @@
 )
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
-from ray.rllib.utils.metrics.stats import Stats
 from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer
 from ray.rllib.utils.runners.runner_group import RunnerGroup
 from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE
@@ -484,7 +483,7 @@ def __init__(
         # The Algorithm's `MetricsLogger` object to collect stats from all its
         # components (including timers, counters and other stats in its own
         # `training_step()` and other methods) as well as custom callbacks.
-        self.metrics = MetricsLogger()
+        self.metrics = MetricsLogger(root=True)
 
         # Create a default logger creator if no logger_creator is specified
         if logger_creator is None:
@@ -1139,9 +1138,8 @@ def evaluate_offline(self):
         # Evaluate with fixed duration.
         self._evaluate_offline_with_fixed_duration()
         # Reduce the evaluation results.
-        eval_results = self.metrics.reduce(
-            key=(EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS),
-            return_stats_obj=False,
+        eval_results = self.metrics.peek(
+            ("EVALUATION_RESULTS", "OFFLINE_EVAL_RUNNER_RESULTS"), default={}
         )
 
         # Trigger `on_evaluate_offline_end` callback.
@@ -1292,9 +1290,11 @@ def evaluate(
         eval_results = {}
 
         if self.config.enable_env_runner_and_connector_v2:
-            eval_results = self.metrics.reduce(
-                key=EVALUATION_RESULTS, return_stats_obj=False
-            )
+            eval_results = self.metrics.peek(key=EVALUATION_RESULTS, default={})
+            if log_once("no_eval_results") and not eval_results:
+                logger.warning(
+                    "No evaluation results found for this iteration. This can happen if the evaluation worker(s) is/are not healthy."
+                )
         else:
             eval_results = {ENV_RUNNER_RESULTS: eval_results}
             eval_results[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps
@@ -3382,9 +3382,9 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
                 key=AGGREGATOR_ACTOR_RESULTS,
             )
 
-        # Only here (at the end of the iteration), reduce the results into a single
-        # result dict.
-        return self.metrics.reduce(), train_iter_ctx
+        # Only here (at the end of the iteration), compile the results into a single result dict.
+        # Calling compile here reduces the metrics into single values and adds throughputs to the results where applicable.
+        return self.metrics.compile(), train_iter_ctx
 
     def _run_one_offline_evaluation(self):
         """Runs offline evaluation step via `self.offline_evaluate()` and handling runner
@@ -3606,26 +3606,7 @@ def _compile_iteration_results(self, *, train_results, eval_results):
             ),
         }
 
-        # Compile all throughput stats.
-        throughputs = {}
-
-        def _reduce(p, s):
-            if isinstance(s, Stats):
-                ret = s.peek()
-                _throughput = s.peek(throughput=True)
-                if _throughput is not None:
-                    _curr = throughputs
-                    for k in p[:-1]:
-                        _curr = _curr.setdefault(k, {})
-                    _curr[p[-1] + "_throughput"] = _throughput
-            else:
-                ret = s
-            return ret
-
-        # Resolve all `Stats` leafs by peeking (get their reduced values).
-        all_results = tree.map_structure_with_path(_reduce, results)
-        deep_update(all_results, throughputs, new_keys_allowed=True)
-        return all_results
+        return results
 
     def __repr__(self):
         if self.config.enable_rl_module_and_learner:
@@ -4466,6 +4447,7 @@ def should_stop(self, results):
         min_t = self.algo.config.min_time_s_per_iteration
         min_sample_ts = self.algo.config.min_sample_timesteps_per_iteration
         min_train_ts = self.algo.config.min_train_timesteps_per_iteration
+
         # Repeat if not enough time has passed or if not enough
         # env|train timesteps have been processed (or these min
         # values are not provided by the user).
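Taken together, the ``algorithm.py`` changes move callers from destructive ``reduce()`` reads to non-destructive ``peek()`` reads, with a single ``compile()`` at the end of each training iteration. A rough sketch of the resulting usage, not part of the diff and with an illustrative metric key:

    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    metrics = MetricsLogger(root=True)
    metrics.log_value("num_env_steps_sampled_lifetime", 200, reduce="sum", with_throughput=True)

    # Non-destructive read at any point during the iteration; `default`
    # guards against keys that were never logged (for example, when all
    # evaluation workers are unhealthy).
    current = metrics.peek("num_env_steps_sampled_lifetime", default=0)

    # One reduction at the very end of the iteration; per the new comment,
    # compile() also folds per-key throughputs into the returned result dict.
    results = metrics.compile()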

rllib/algorithms/impala/impala.py

Lines changed: 1 addition & 1 deletion
@@ -931,7 +931,7 @@ def _training_step_old_api_stack(self):
 
         # With a training step done, try to bring any aggregators back to life
         # if necessary.
-        # Aggregation workers are stateless, so we do not need to restore any
+        # AggregatorActor are stateless, so we do not need to restore any
        # state here.
         if self._aggregator_actor_manager:
             self._aggregator_actor_manager.probe_unhealthy_actors(

rllib/algorithms/sac/torch/sac_torch_learner.py

Lines changed: 2 additions & 2 deletions
@@ -211,8 +211,8 @@ def compute_loss_for_module(
                 POLICY_LOSS_KEY: actor_loss,
                 QF_LOSS_KEY: critic_loss,
                 "alpha_loss": alpha_loss,
-                "alpha_value": alpha,
-                "log_alpha_value": torch.log(alpha),
+                "alpha_value": alpha[0],
+                "log_alpha_value": torch.log(alpha)[0],
                 "target_entropy": self.target_entropy[module_id],
                 LOGPS_KEY: torch.mean(fwd_out["logp_resampled"]),
                 QF_MEAN_KEY: torch.mean(fwd_out["q_curr"]),
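For context on the SAC change: ``alpha`` is held as a one-element tensor, so indexing with ``[0]`` logs the scalar element instead of the vector, presumably to match what the overhauled Stats expects for scalar metrics. A tiny standalone illustration, not part of the diff:

    import torch

    # SAC keeps log-alpha as a 1-element parameter tensor.
    log_alpha = torch.nn.Parameter(torch.zeros(1))
    alpha = torch.exp(log_alpha)

    print(alpha.shape)     # torch.Size([1]) -- a 1-element vector
    print(alpha[0].shape)  # torch.Size([])  -- the 0-dim scalar that gets logged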

rllib/core/learner/learner.py

Lines changed: 3 additions & 0 deletions
@@ -1645,18 +1645,21 @@ def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
             key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
             value=module_batch_size,
             reduce="sum",
+            with_throughput=True,
         )
         # Log module steps (sum of all modules).
         self.metrics.log_value(
             key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
             value=module_batch_size,
             reduce="sum",
             clear_on_reduce=True,
+            with_throughput=True,
         )
         self.metrics.log_value(
             key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
             value=module_batch_size,
             reduce="sum",
+            with_throughput=True,
         )
         # Log env steps (all modules).
         self.metrics.log_value(
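With ``with_throughput=True`` on these counters, the Learner's trained-step metrics gain the per-second throughput described in the doc changes above. A small sketch of how such a counter reads back, not part of the diff (the key name is illustrative):

    import time
    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    metrics = MetricsLogger()

    # Same pattern the Learner now uses for its trained-step counters.
    metrics.log_value("num_module_steps_trained", 4000, reduce="sum", with_throughput=True)
    time.sleep(1.0)
    metrics.log_value("num_module_steps_trained", 4000, reduce="sum", with_throughput=True)

    # Per-second throughput of the counter; NaN until enough measurements
    # exist. Algorithm surfaces it under an extra "..._throughput" result key.
    print(metrics.peek("num_module_steps_trained", throughput=True))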

rllib/examples/evaluation/custom_evaluation.py

Lines changed: 1 addition & 3 deletions
@@ -154,9 +154,7 @@ def custom_eval_function(
     algorithm.metrics.merge_and_log_n_dicts(
         env_runner_metrics, key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS)
     )
-    eval_results = algorithm.metrics.reduce(
-        key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS)
-    )
+    eval_results = algorithm.metrics.peek((EVALUATION_RESULTS, ENV_RUNNER_RESULTS))
     # Alternatively, you could manually reduce over the n returned `env_runner_metrics`
     # dicts, but this would be much harder as you might not know, which metrics
     # to sum up, which ones to average over, etc..

rllib/examples/learners/classes/vpg_torch_learner.py

Lines changed: 1 addition & 2 deletions
@@ -53,9 +53,8 @@ def compute_loss_for_module(
         self.metrics.log_value(
             key=(module_id, f"action_{act}_return_to_go_mean"),
             value=ret_to_go,
-            # Mean over the batch size.
             reduce="mean",
-            window=len(batch[Columns.RETURNS_TO_GO]),
+            clear_on_reduce=True,
         )
 
         return loss

rllib/tuned_examples/impala/heavy_cartpole_impala.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Non-learning, throughput-only benchmark used to tune and test the usage of
-# AggregationActors in IMPALA and APPO.
+# AggregatorActor in IMPALA and APPO.
 
 # With the current setup below, 27 EnvRunners (+ 2 eval EnvRunners), 0 Learners
 # 1 local A10 GPU Learner and 2 Aggregator actors, the achieved training throughput
