
Commit 02d4a3a

[RLlib] APPO accelerate (vol 18): EnvRunner sync enhancements. (#50918)
1 parent 6692500 commit 02d4a3a

File tree: 4 files changed (+108, −51 lines)


rllib/BUILD

Lines changed: 9 additions & 9 deletions
@@ -2298,16 +2298,16 @@ py_test(
 )

 py_test(
-    name = "examples/evaluation/custom_evaluation_parallel_to_training",
+    name = "examples/evaluation/custom_evaluation_parallel_to_training_10_episodes",
     main = "examples/evaluation/custom_evaluation.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
     srcs = ["examples/evaluation/custom_evaluation.py"],
-    args = ["--enable-new-api-stack", "--as-test", "--framework=torch", "--stop-reward=0.75", "--evaluation-parallel-to-training", "--num-cpus=5"]
+    args = ["--enable-new-api-stack", "--as-test", "--stop-reward=0.75", "--evaluation-parallel-to-training", "--num-cpus=5", "--evaluation-duration=10", "--evaluation-duration-unit=episodes"]
 )

 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_torch",
+    name = "examples/evaluation/evaluation_parallel_to_training_duration_auto",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
@@ -2316,7 +2316,7 @@ py_test(
 )

 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_duration_auto_torch",
+    name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_duration_auto",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"],
     size = "large",
@@ -2343,7 +2343,7 @@ py_test(
 )

 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_13_episodes_torch",
+    name = "examples/evaluation/evaluation_parallel_to_training_13_episodes",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
@@ -2352,7 +2352,7 @@ py_test(
 )

 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_10_episodes_torch",
+    name = "examples/evaluation/evaluation_parallel_to_training_multi_agent_10_episodes",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
@@ -2362,17 +2362,17 @@ py_test(

 # @OldAPIStack
 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_torch_old_api_stack",
+    name = "examples/evaluation/evaluation_parallel_to_training_duration_auto_old_api_stack",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
     srcs = ["examples/evaluation/evaluation_parallel_to_training.py"],
-    args = ["--as-test", "--evaluation-parallel-to-training", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto"]
+    args = ["--as-test", "--evaluation-parallel-to-training", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto", "--evaluation-duration-unit=timesteps"]
 )

 # @OldAPIStack
 py_test(
-    name = "examples/evaluation/evaluation_parallel_to_training_211_ts_torch_old_api_stack",
+    name = "examples/evaluation/evaluation_parallel_to_training_211_ts_old_api_stack",
     main = "examples/evaluation/evaluation_parallel_to_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",

rllib/algorithms/algorithm.py

Lines changed: 23 additions & 6 deletions
@@ -1041,10 +1041,23 @@ def evaluate(

         # Sync weights to the evaluation EnvRunners.
         if self.eval_env_runner_group is not None:
+            if self.config.enable_env_runner_and_connector_v2:
+                if (
+                    self.env_runner_group is not None
+                    and self.env_runner_group.healthy_env_runner_ids()
+                ):
+                    # TODO (sven): Replace this with a new ActorManager API:
+                    #  try_remote_request_till_success("get_state") -> tuple(int,
+                    #  remoteresult)
+                    weights_src = self.env_runner_group._worker_manager._actors[
+                        self.env_runner_group.healthy_env_runner_ids()[0]
+                    ]
+                else:
+                    weights_src = self.learner_group
+            else:
+                weights_src = self.env_runner_group.local_env_runner
             self.eval_env_runner_group.sync_weights(
-                from_worker_or_learner_group=self.learner_group
-                if self.config.enable_env_runner_and_connector_v2
-                else self.env_runner_group.local_env_runner,
+                from_worker_or_learner_group=weights_src,
                 inference_only=True,
             )
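
The hunk above makes evaluation prefer an already-synced, healthy remote EnvRunner over the LearnerGroup as the weights source. A self-contained sketch of that pattern, with a hypothetical `Worker` actor standing in for RLlib's EnvRunner:

import ray

@ray.remote
class Worker:
    """Hypothetical stand-in for an EnvRunner that holds current weights."""

    def get_state(self):
        return {"rl_module": {"w": [0.1, 0.2]}}

workers = [Worker.remote() for _ in range(3)]
healthy_ids = [0, 1, 2]  # assume a prior health check produced these
# Ask the first healthy actor for its state instead of the learner group.
weights_src = workers[healthy_ids[0]]
state = ray.get(weights_src.get_state.remote())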

@@ -1444,7 +1457,7 @@ def _evaluate_with_fixed_duration(self):

         # Remote function used on healthy EnvRunners to sample, get metrics, and
         # step counts.
-        def _env_runner_remote(worker, num, round, iter):
+        def _env_runner_remote(worker, num, round, iter, _force_reset):
             # Sample AND get_metrics, but only return metrics (and steps actually taken)
             # to save time. Also return the iteration to check whether we should
             # discard an outdated result (from a slow worker).
@@ -1453,7 +1466,7 @@ def _env_runner_remote(worker, num, round, iter):
                     num[worker.worker_index] if unit == "timesteps" else None
                 ),
                 num_episodes=(num[worker.worker_index] if unit == "episodes" else None),
-                force_reset=force_reset and round == 0,
+                force_reset=_force_reset and round == 0,
             )
             metrics = worker.get_metrics()
             env_steps = sum(e.env_steps() for e in episodes)
@@ -1491,7 +1504,11 @@ def _env_runner_remote(worker, num, round, iter):
         ]
         self.eval_env_runner_group.foreach_env_runner_async(
             func=functools.partial(
-                _env_runner_remote, num=_num, round=_round, iter=algo_iteration
+                _env_runner_remote,
+                num=_num,
+                round=_round,
+                iter=algo_iteration,
+                _force_reset=force_reset,
             ),
         )
         results = self.eval_env_runner_group.fetch_ready_async_reqs(
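
Passing `_force_reset` through `functools.partial` removes `_env_runner_remote`'s dependency on the enclosing scope's `force_reset` variable: everything the remote call needs now travels inside the partial itself. A minimal pure-Python sketch of the pattern (hypothetical names):

import functools

def _runner_sketch(worker_index, *, num, round, iter, _force_reset):
    # All inputs arrive as arguments; nothing is captured from an outer scope.
    return _force_reset and round == 0

fn = functools.partial(
    _runner_sketch, num=[4, 4], round=0, iter=7, _force_reset=True
)
assert fn(1) is True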

rllib/env/env_runner_group.py

Lines changed: 76 additions & 35 deletions
@@ -495,13 +495,9 @@ def sync_env_runner_states(
                 config.num_env_runners or 1
             )

-        # Update the rl_module component of the EnvRunner states, if necessary:
-        if rl_module_state:
-            env_runner_states.update(rl_module_state)
-
         # If we do NOT want remote EnvRunners to get their Connector states updated,
-        # only update the local worker here (with all state components) and then remove
-        # the connector components.
+        # only update the local worker here (with all state components, except the model
+        # weights) and then remove the connector components.
         if not config.update_worker_filter_stats:
             self.local_env_runner.set_state(env_runner_states)
             env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
@@ -510,18 +506,24 @@ def sync_env_runner_states(
         # If there are components in the state left -> Update remote workers with these
         # state components (and maybe the local worker, if it hasn't been updated yet).
         if env_runner_states:
-            # Put the state dictionary into Ray's object store to avoid having to make n
-            # pickled copies of the state dict.
-            ref_env_runner_states = ray.put(env_runner_states)
-
-            def _update(_env_runner: EnvRunner) -> None:
-                _env_runner.set_state(ray.get(ref_env_runner_states))
+            # Update the local EnvRunner, but NOT with the weights. If used at all for
+            # evaluation (through the user calling `self.evaluate`), RLlib would update
+            # the weights up front either way.
+            if config.update_worker_filter_stats:
+                self.local_env_runner.set_state(env_runner_states)
+
+            # Send the model weights only to remote EnvRunners.
+            # In case the local EnvRunner is ever needed for evaluation,
+            # RLlib updates its weight right before such an eval step.
+            if rl_module_state:
+                env_runner_states.update(rl_module_state)

             # Broadcast updated states back to all workers.
             self.foreach_env_runner(
-                _update,
+                "set_state",  # Call the `set_state()` remote method.
+                kwargs=dict(state=env_runner_states),
                 remote_worker_ids=env_runner_indices_to_update,
-                local_env_runner=config.update_worker_filter_stats,
+                local_env_runner=False,
                 timeout_seconds=0.0,  # This is a state update -> Fire-and-forget.
             )
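
The rewritten broadcast dispatches by method name ("set_state") with explicit kwargs instead of shipping a pickled closure to every actor. A rough sketch of string-based dispatch, using a hypothetical `EnvRunnerStub` rather than RLlib's real API:

import ray

@ray.remote
class EnvRunnerStub:
    def __init__(self):
        self.state = {}

    def set_state(self, state):
        self.state.update(state)

runners = [EnvRunnerStub.remote() for _ in range(2)]
states = {"env_to_module_connector": {"obs_mean": 0.0}}
# Look the method up by name on each actor and forward the kwargs.
refs = [getattr(r, "set_state").remote(state=states) for r in runners]
ray.get(refs)  # sync_env_runner_states instead fires and forgets (timeout 0.0)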

@@ -581,18 +583,35 @@ def sync_weights(
             if policies is not None
             else [COMPONENT_RL_MODULE]
         )
-        # LearnerGroup has-a Learner has-a RLModule.
+        # LearnerGroup has-a Learner, which has-a RLModule.
         if isinstance(weights_src, LearnerGroup):
             rl_module_state = weights_src.get_state(
                 components=[COMPONENT_LEARNER + "/" + m for m in modules],
                 inference_only=inference_only,
             )[COMPONENT_LEARNER]
-        # EnvRunner has-a RLModule.
+        # EnvRunner (new API stack).
         elif self._remote_config.enable_env_runner_and_connector_v2:
-            rl_module_state = weights_src.get_state(
-                components=modules,
-                inference_only=inference_only,
-            )
+            # EnvRunner (remote) has-a RLModule.
+            # TODO (sven): Replace this with a new ActorManager API:
+            #  try_remote_request_till_success("get_state") -> tuple(int,
+            #  remoteresult)
+            #  `weights_src` could be the ActorManager, then. Then RLlib would know
+            #  that it has to ping the manager to try all healthy actors until the
+            #  first returns something.
+            if isinstance(weights_src, ray.actor.ActorHandle):
+                rl_module_state = ray.get(
+                    weights_src.get_state.remote(
+                        components=modules,
+                        inference_only=inference_only,
+                    )
+                )
+            # EnvRunner (local) has-a RLModule.
+            else:
+                rl_module_state = weights_src.get_state(
+                    components=modules,
+                    inference_only=inference_only,
+                )
+        # RolloutWorker (old API stack).
         else:
             rl_module_state = weights_src.get_weights(
                 policies=policies,
@@ -613,22 +632,28 @@ def sync_weights(
             # copies of the weights dict for each worker.
             rl_module_state_ref = ray.put(rl_module_state)

-            def _set_weights(env_runner):
-                env_runner.set_state(ray.get(rl_module_state_ref))
+            # Sync to specified remote workers in this EnvRunnerGroup.
+            self.foreach_env_runner(
+                func="set_state",
+                kwargs=dict(state=rl_module_state_ref),
+                local_env_runner=False,  # Do not sync back to local worker.
+                remote_worker_ids=to_worker_indices,
+                timeout_seconds=timeout_seconds,
+            )

         else:
             rl_module_state_ref = ray.put(rl_module_state)

             def _set_weights(env_runner):
                 env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)

-        # Sync to specified remote workers in this EnvRunnerGroup.
-        self.foreach_env_runner(
-            func=_set_weights,
-            local_env_runner=False,  # Do not sync back to local worker.
-            remote_worker_ids=to_worker_indices,
-            timeout_seconds=timeout_seconds,
-        )
+            # Sync to specified remote workers in this EnvRunnerGroup.
+            self.foreach_env_runner(
+                func=_set_weights,
+                local_env_runner=False,  # Do not sync back to local worker.
+                remote_worker_ids=to_worker_indices,
+                timeout_seconds=timeout_seconds,
+            )

         # If `from_worker_or_learner_group` is provided, also sync to this
         # EnvRunnerGroup's local worker.
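
Note that the new-API-stack branch forwards the `ray.put()` ObjectRef itself as the `state` kwarg. Ray resolves top-level ObjectRef arguments on the receiving actor, so the weights dict enters the object store once rather than being re-pickled per worker. A sketch with a hypothetical stub actor:

import ray

@ray.remote
class RunnerStub:
    def set_state(self, state):
        return sum(len(v) for v in state.values())

weights = {"default_policy": list(range(10_000))}
weights_ref = ray.put(weights)  # a single copy in the object store
runners = [RunnerStub.remote() for _ in range(3)]
# Each actor receives the dereferenced dict, not its own pickled copy.
print(ray.get([r.set_state.remote(weights_ref) for r in runners]))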
@@ -716,8 +741,11 @@ def stop(self) -> None:

     def foreach_env_runner(
         self,
-        func: Callable[[EnvRunner], T],
+        func: Union[
+            Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
+        ],
         *,
+        kwargs=None,
         local_env_runner: bool = True,
         healthy_only: bool = True,
         remote_worker_ids: List[int] = None,
@@ -759,13 +787,18 @@ def foreach_env_runner(

         local_result = []
         if local_env_runner and self.local_env_runner is not None:
-            local_result = [func(self.local_env_runner)]
+            assert kwargs is None
+            if isinstance(func, str):
+                local_result = [getattr(self.local_env_runner, func)]
+            else:
+                local_result = [func(self.local_env_runner)]

         if not self._worker_manager.actor_ids():
             return local_result

         remote_results = self._worker_manager.foreach_actor(
             func,
+            kwargs=kwargs,
             healthy_only=healthy_only,
             remote_actor_ids=remote_worker_ids,
             timeout_seconds=timeout_seconds,
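
With `func` now accepting a callable or a method name (plus optional `kwargs`), per-runner dispatch reduces to roughly the following sketch (not the actual ActorManager internals):

from typing import Any, Callable, Optional, Union

def call_on_runner(
    runner: Any,
    func: Union[Callable[[Any], Any], str],
    kwargs: Optional[dict] = None,
):
    # A string selects a method by name and forwards kwargs; a callable
    # is applied to the runner directly (kwargs only fit the string form).
    if isinstance(func, str):
        return getattr(runner, func)(**(kwargs or {}))
    assert kwargs is None
    return func(runner)

class DummyRunner:
    def set_state(self, state):
        self.state = dict(state)
        return "ok"

assert call_on_runner(DummyRunner(), "set_state", {"state": {}}) == "ok"
assert call_on_runner(DummyRunner(), lambda r: r.set_state({})) == "ok"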
@@ -782,18 +815,24 @@ def foreach_env_runner(

         return local_result + remote_results

+    # TODO (sven): Deprecate this API. Users can lookup the "worker index" from the
+    #  EnvRunner object directly through `self.worker_index` (besides many other useful
+    #  properties, like `in_evaluation`, `num_env_runners`, etc..).
     def foreach_env_runner_with_id(
         self,
-        func: Callable[[int, EnvRunner], T],
+        func: Union[
+            Callable[[int, EnvRunner], T],
+            List[Callable[[int, EnvRunner], T]],
+            str,
+            List[str],
+        ],
         *,
         local_env_runner: bool = True,
         healthy_only: bool = True,
         remote_worker_ids: List[int] = None,
         timeout_seconds: Optional[float] = None,
         return_obj_refs: bool = False,
         mark_healthy: bool = False,
-        # Deprecated args.
-        local_worker=DEPRECATED_VALUE,
     ) -> List[T]:
         """Calls the given function with each EnvRunner and its ID as its arguments.
@@ -850,7 +889,9 @@ def foreach_env_runner_with_id(

     def foreach_env_runner_async(
         self,
-        func: Union[Callable[[EnvRunner], T], str],
+        func: Union[
+            Callable[[EnvRunner], T], List[Callable[[EnvRunner], T]], str, List[str]
+        ],
         *,
         healthy_only: bool = True,
         remote_worker_ids: List[int] = None,

rllib/examples/evaluation/evaluation_parallel_to_training.py

Lines changed: 0 additions & 1 deletion
@@ -92,7 +92,6 @@
 parser.set_defaults(
     evaluation_num_env_runners=2,
     evaluation_interval=1,
-    evaluation_duration_unit="timesteps",
 )
