@@ -495,13 +495,9 @@ def sync_env_runner_states(
495
495
config .num_env_runners or 1
496
496
)
497
497
498
- # Update the rl_module component of the EnvRunner states, if necessary:
499
- if rl_module_state :
500
- env_runner_states .update (rl_module_state )
501
-
502
498
# If we do NOT want remote EnvRunners to get their Connector states updated,
503
- # only update the local worker here (with all state components) and then remove
504
- # the connector components.
499
+ # only update the local worker here (with all state components, except the model
500
+ # weights) and then remove the connector components.
505
501
if not config .update_worker_filter_stats :
506
502
self .local_env_runner .set_state (env_runner_states )
507
503
env_runner_states .pop (COMPONENT_ENV_TO_MODULE_CONNECTOR , None )
@@ -510,18 +506,24 @@ def sync_env_runner_states(
510
506
# If there are components in the state left -> Update remote workers with these
511
507
# state components (and maybe the local worker, if it hasn't been updated yet).
512
508
if env_runner_states :
513
- # Put the state dictionary into Ray's object store to avoid having to make n
514
- # pickled copies of the state dict.
515
- ref_env_runner_states = ray .put (env_runner_states )
516
-
517
- def _update (_env_runner : EnvRunner ) -> None :
518
- _env_runner .set_state (ray .get (ref_env_runner_states ))
509
+ # Update the local EnvRunner, but NOT with the weights. If used at all for
510
+ # evaluation (through the user calling `self.evaluate`), RLlib would update
511
+ # the weights up front either way.
512
+ if config .update_worker_filter_stats :
513
+ self .local_env_runner .set_state (env_runner_states )
514
+
515
+ # Send the model weights only to remote EnvRunners.
516
+ # In case the local EnvRunner is ever needed for evaluation,
517
+ # RLlib updates its weight right before such an eval step.
518
+ if rl_module_state :
519
+ env_runner_states .update (rl_module_state )
519
520
520
521
# Broadcast updated states back to all workers.
521
522
self .foreach_env_runner (
522
- _update ,
523
+ "set_state" , # Call the `set_state()` remote method.
524
+ kwargs = dict (state = env_runner_states ),
523
525
remote_worker_ids = env_runner_indices_to_update ,
524
- local_env_runner = config . update_worker_filter_stats ,
526
+ local_env_runner = False ,
525
527
timeout_seconds = 0.0 , # This is a state update -> Fire-and-forget.
526
528
)
527
529
@@ -581,18 +583,35 @@ def sync_weights(
581
583
if policies is not None
582
584
else [COMPONENT_RL_MODULE ]
583
585
)
584
- # LearnerGroup has-a Learner has-a RLModule.
586
+ # LearnerGroup has-a Learner, which has-a RLModule.
585
587
if isinstance (weights_src , LearnerGroup ):
586
588
rl_module_state = weights_src .get_state (
587
589
components = [COMPONENT_LEARNER + "/" + m for m in modules ],
588
590
inference_only = inference_only ,
589
591
)[COMPONENT_LEARNER ]
590
- # EnvRunner has-a RLModule .
592
+ # EnvRunner (new API stack) .
591
593
elif self ._remote_config .enable_env_runner_and_connector_v2 :
592
- rl_module_state = weights_src .get_state (
593
- components = modules ,
594
- inference_only = inference_only ,
595
- )
594
+ # EnvRunner (remote) has-a RLModule.
595
+ # TODO (sven): Replace this with a new ActorManager API:
596
+ # try_remote_request_till_success("get_state") -> tuple(int,
597
+ # remoteresult)
598
+ # `weights_src` could be the ActorManager, then. Then RLlib would know
599
+ # that it has to ping the manager to try all healthy actors until the
600
+ # first returns something.
601
+ if isinstance (weights_src , ray .actor .ActorHandle ):
602
+ rl_module_state = ray .get (
603
+ weights_src .get_state .remote (
604
+ components = modules ,
605
+ inference_only = inference_only ,
606
+ )
607
+ )
608
+ # EnvRunner (local) has-a RLModule.
609
+ else :
610
+ rl_module_state = weights_src .get_state (
611
+ components = modules ,
612
+ inference_only = inference_only ,
613
+ )
614
+ # RolloutWorker (old API stack).
596
615
else :
597
616
rl_module_state = weights_src .get_weights (
598
617
policies = policies ,
@@ -613,22 +632,28 @@ def sync_weights(
613
632
# copies of the weights dict for each worker.
614
633
rl_module_state_ref = ray .put (rl_module_state )
615
634
616
- def _set_weights (env_runner ):
617
- env_runner .set_state (ray .get (rl_module_state_ref ))
635
+ # Sync to specified remote workers in this EnvRunnerGroup.
636
+ self .foreach_env_runner (
637
+ func = "set_state" ,
638
+ kwargs = dict (state = rl_module_state_ref ),
639
+ local_env_runner = False , # Do not sync back to local worker.
640
+ remote_worker_ids = to_worker_indices ,
641
+ timeout_seconds = timeout_seconds ,
642
+ )
618
643
619
644
else :
620
645
rl_module_state_ref = ray .put (rl_module_state )
621
646
622
647
def _set_weights (env_runner ):
623
648
env_runner .set_weights (ray .get (rl_module_state_ref ), global_vars )
624
649
625
- # Sync to specified remote workers in this EnvRunnerGroup.
626
- self .foreach_env_runner (
627
- func = _set_weights ,
628
- local_env_runner = False , # Do not sync back to local worker.
629
- remote_worker_ids = to_worker_indices ,
630
- timeout_seconds = timeout_seconds ,
631
- )
650
+ # Sync to specified remote workers in this EnvRunnerGroup.
651
+ self .foreach_env_runner (
652
+ func = _set_weights ,
653
+ local_env_runner = False , # Do not sync back to local worker.
654
+ remote_worker_ids = to_worker_indices ,
655
+ timeout_seconds = timeout_seconds ,
656
+ )
632
657
633
658
# If `from_worker_or_learner_group` is provided, also sync to this
634
659
# EnvRunnerGroup's local worker.
@@ -716,8 +741,11 @@ def stop(self) -> None:
716
741
717
742
def foreach_env_runner (
718
743
self ,
719
- func : Callable [[EnvRunner ], T ],
744
+ func : Union [
745
+ Callable [[EnvRunner ], T ], List [Callable [[EnvRunner ], T ]], str , List [str ]
746
+ ],
720
747
* ,
748
+ kwargs = None ,
721
749
local_env_runner : bool = True ,
722
750
healthy_only : bool = True ,
723
751
remote_worker_ids : List [int ] = None ,
@@ -759,13 +787,18 @@ def foreach_env_runner(
759
787
760
788
local_result = []
761
789
if local_env_runner and self .local_env_runner is not None :
762
- local_result = [func (self .local_env_runner )]
790
+ assert kwargs is None
791
+ if isinstance (func , str ):
792
+ local_result = [getattr (self .local_env_runner , func )]
793
+ else :
794
+ local_result = [func (self .local_env_runner )]
763
795
764
796
if not self ._worker_manager .actor_ids ():
765
797
return local_result
766
798
767
799
remote_results = self ._worker_manager .foreach_actor (
768
800
func ,
801
+ kwargs = kwargs ,
769
802
healthy_only = healthy_only ,
770
803
remote_actor_ids = remote_worker_ids ,
771
804
timeout_seconds = timeout_seconds ,
@@ -782,18 +815,24 @@ def foreach_env_runner(
782
815
783
816
return local_result + remote_results
784
817
818
+ # TODO (sven): Deprecate this API. Users can lookup the "worker index" from the
819
+ # EnvRunner object directly through `self.worker_index` (besides many other useful
820
+ # properties, like `in_evaluation`, `num_env_runners`, etc..).
785
821
def foreach_env_runner_with_id (
786
822
self ,
787
- func : Callable [[int , EnvRunner ], T ],
823
+ func : Union [
824
+ Callable [[int , EnvRunner ], T ],
825
+ List [Callable [[int , EnvRunner ], T ]],
826
+ str ,
827
+ List [str ],
828
+ ],
788
829
* ,
789
830
local_env_runner : bool = True ,
790
831
healthy_only : bool = True ,
791
832
remote_worker_ids : List [int ] = None ,
792
833
timeout_seconds : Optional [float ] = None ,
793
834
return_obj_refs : bool = False ,
794
835
mark_healthy : bool = False ,
795
- # Deprecated args.
796
- local_worker = DEPRECATED_VALUE ,
797
836
) -> List [T ]:
798
837
"""Calls the given function with each EnvRunner and its ID as its arguments.
799
838
@@ -850,7 +889,9 @@ def foreach_env_runner_with_id(
850
889
851
890
def foreach_env_runner_async (
852
891
self ,
853
- func : Union [Callable [[EnvRunner ], T ], str ],
892
+ func : Union [
893
+ Callable [[EnvRunner ], T ], List [Callable [[EnvRunner ], T ]], str , List [str ]
894
+ ],
854
895
* ,
855
896
healthy_only : bool = True ,
856
897
remote_worker_ids : List [int ] = None ,
0 commit comments