Commit 1bffb1f

[RLlib] Enhance ConnectorV2 ObservationPreprocessor APIs (add multi-agent support; add episode arg). (#54209)
1 parent c07376b commit 1bffb1f

File tree: 9 files changed (+467 -305 lines)

doc/source/rllib/rllib-examples.rst (13 additions, 3 deletions)

@@ -134,13 +134,21 @@ Connectors
   This type of filtering can improve learning stability in environments with highly variable state magnitudes
   by scaling observations to a normalized range.
 
-- `Multi-agent connector mapping global observations to different per-agent/policy observations <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/multi_agent_with_different_observation_spaces.py>`__:
-  A connector example showing how to map from a global, multi-agent observation space to n individual, per-agent, per-module observation spaces.
+- `Multi-agent observation preprocessor enhancing non-Markovian observations to Markovian ones <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/multi_agent_observation_preprocessor.py>`__:
+  A multi-agent preprocessor enhances the per-agent observations of a multi-agent env, which by themselves are non-Markovian,
+  partial observations, and converts them into Markovian observations by adding information from
+  the respective other agent. A policy can only be trained optimally with this additional information.
 
 - `Prev-actions, prev-rewards connector <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/prev_actions_prev_rewards.py>`__:
   Augments observations with previous actions and rewards, giving the agent a short-term memory of past events, which can improve
   decision-making in partially observable or sequentially dependent tasks.
 
+- `Single-agent observation preprocessor <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/single_agent_observation_preprocessor.py>`__:
+  A connector alters the CartPole-v1 environment observations from the Markovian 4-tuple (x-pos,
+  angular-pos, x-velocity, angular-velocity) to a simpler, non-Markovian 2-tuple (only
+  x-pos and angular-pos). The resulting problem can only be solved through a
+  memory/stateful model, for example an LSTM.
+
 
 Curiosity
 +++++++++
@@ -308,8 +316,10 @@ Multi-agent RL
   a hand-coded random policy while another agent trains with PPO. This example highlights integrating static and dynamic policies,
   suitable for environments with a mix of fixed-strategy and adaptive agents.
 
-- `Different spaces for agents <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent/different_spaces_for_agents.py>`__:
+- `Different observation- and action spaces for different agents <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent/different_spaces_for_agents.py>`__:
   Configures agents with differing observation and action spaces within the same environment, showcasing RLlib's support for heterogeneous agents with varying space requirements in a single multi-agent environment.
+  Another example that also uses connectors and covers the same topic (agents with different spaces) can be found
+  `here <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/multi_agent_observation_preprocessor.py>`__.
 
 - `Grouped agents, two-step game <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py>`__:
   Implements a multi-agent, grouped setup within a two-step game environment from the `QMIX paper <https://arxiv.org/pdf/1803.11485.pdf>`__.
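For illustration, a minimal sketch of such a single-agent observation preprocessor, written against the `SingleAgentObservationPreprocessor` API introduced in this commit; the class name `DropVelocityComponents` is hypothetical and the code is not necessarily identical to the linked example script:

import gymnasium as gym
import numpy as np

from ray.rllib.connectors.env_to_module.observation_preprocessor import (
    SingleAgentObservationPreprocessor,
)
from ray.rllib.utils.annotations import override


class DropVelocityComponents(SingleAgentObservationPreprocessor):
    """Keeps only x-pos and angular-pos of CartPole-v1's 4D observation (sketch)."""

    @override(SingleAgentObservationPreprocessor)
    def recompute_output_observation_space(
        self, input_observation_space, input_action_space
    ):
        # Keep indices 0 (x-pos) and 2 (angular-pos) of the original Box(4,).
        return gym.spaces.Box(
            low=input_observation_space.low[[0, 2]],
            high=input_observation_space.high[[0, 2]],
            dtype=np.float32,
        )

    @override(SingleAgentObservationPreprocessor)
    def preprocess(self, observation, episode):
        # `episode` is the SingleAgentEpisode the observation was taken from
        # (unused here, but available for context-dependent preprocessing).
        return observation[[0, 2]].astype(np.float32)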

rllib/BUILD (20 additions, 3 deletions)

@@ -3857,14 +3857,31 @@ py_test(
 )
 
 py_test(
-    name = "examples/connectors/multi_agent_with_different_observation_spaces",
+    name = "examples/connectors/multi_agent_observation_preprocessor",
     size = "medium",
-    srcs = ["examples/connectors/multi_agent_with_different_observation_spaces.py"],
+    srcs = ["examples/connectors/multi_agent_observation_preprocessor.py"],
     args = [
         "--enable-new-api-stack",
         "--num-agents=2",
+        "--algo=PPO",
+    ],
+    main = "examples/connectors/multi_agent_observation_preprocessor.py",
+    tags = [
+        "examples",
+        "exclusive",
+        "team:rllib",
+    ],
+)
+
+py_test(
+    name = "examples/connectors/single_agent_observation_preprocessor",
+    size = "medium",
+    srcs = ["examples/connectors/single_agent_observation_preprocessor.py"],
+    args = [
+        "--enable-new-api-stack",
+        "--algo=PPO",
     ],
-    main = "examples/connectors/multi_agent_with_different_observation_spaces.py",
+    main = "examples/connectors/single_agent_observation_preprocessor.py",
     tags = [
         "examples",
         "exclusive",
rllib/connectors/env_to_module/observation_preprocessor.py (113 additions, 12 deletions)

@@ -5,18 +5,24 @@
 
 from ray.rllib.connectors.connector_v2 import ConnectorV2
 from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
+from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import EpisodeType
 from ray.util.annotations import PublicAPI
 
 
 @PublicAPI(stability="alpha")
-class ObservationPreprocessor(ConnectorV2, abc.ABC):
-    """Env-to-module connector performing one preprocessor step on the last observation.
+class SingleAgentObservationPreprocessor(ConnectorV2, abc.ABC):
+    """Env-to-module connector preprocessing the most recent single-agent observation.
 
     This is a convenience class that simplifies the writing of few-step preprocessor
     connectors.
 
+    Note that this class also works in a multi-agent setup, in which case RLlib
+    separately calls this connector piece with each agent's observation and
+    `SingleAgentEpisode` object.
+
     Users must implement the `preprocess()` method, which simplifies the usual procedure
     of extracting some data from a list of episodes and adding it to the batch to a mere
     "old-observation --transform--> return new-observation" step.
@@ -28,23 +34,27 @@ def recompute_output_observation_space(
         input_observation_space: gym.Space,
         input_action_space: gym.Space,
     ) -> gym.Space:
-        # Users should override this method only in case the `ObservationPreprocessor`
-        # changes the observation space of the pipeline. In this case, return the new
-        # observation space based on the incoming one (`input_observation_space`).
+        # Users should override this method only in case the
+        # `SingleAgentObservationPreprocessor` changes the observation space of the
+        # pipeline. In this case, return the new observation space based on the
+        # incoming one (`input_observation_space`).
         return super().recompute_output_observation_space(
             input_observation_space, input_action_space
         )
 
     @abc.abstractmethod
-    def preprocess(self, observation):
+    def preprocess(self, observation, episode: SingleAgentEpisode):
         """Override to implement the preprocessing logic.
 
         Args:
             observation: A single (non-batched) observation item for a single agent to
-                be processed by this connector.
+                be preprocessed by this connector.
+            episode: The `SingleAgentEpisode` instance, from which `observation` was
+                taken. You can extract information on the particular AgentID and the
+                ModuleID through `episode.agent_id` and `episode.module_id`.
 
         Returns:
-            The new observation after `observation` has been preprocessed.
+            The new observation for the agent after `observation` has been preprocessed.
         """
 
     @override(ConnectorV2)
@@ -67,14 +77,105 @@ def __call__(
 
             # Process the observation and write the new observation back into the
            # episode.
-            new_observation = self.preprocess(observation=observation)
+            new_observation = self.preprocess(
+                observation=observation,
+                episode=sa_episode,
+            )
             sa_episode.set_observations(at_indices=-1, new_data=new_observation)
             # We set the Episode's observation space to ours so that we can safely
             # set the last obs to the new value (without causing a space mismatch
             # error).
             sa_episode.observation_space = self.observation_space
 
-        # Leave `batch` as is. RLlib's default connector will automatically
-        # populate the OBS column therein from the episodes' now transformed
-        # observations.
+        # Leave `batch` as is. RLlib's default connector automatically populates
+        # the OBS column therein from the episodes' now transformed observations.
         return batch
+
+
+@PublicAPI(stability="alpha")
+class MultiAgentObservationPreprocessor(ConnectorV2, abc.ABC):
+    """Env-to-module connector preprocessing the most recent multi-agent observation.
+
+    The observation is always a dict of individual agents' observations.
+
+    This is a convenience class that simplifies the writing of few-step preprocessor
+    connectors.
+
+    Users must implement the `preprocess()` method, which simplifies the usual procedure
+    of extracting some data from a list of episodes and adding it to the batch to a mere
+    "old-observation --transform--> return new-observation" step.
+    """
+
+    @override(ConnectorV2)
+    def recompute_output_observation_space(
+        self,
+        input_observation_space: gym.Space,
+        input_action_space: gym.Space,
+    ) -> gym.Space:
+        # Users should override this method only in case the
+        # `MultiAgentObservationPreprocessor` changes the observation space of the
+        # pipeline. In this case, return the new observation space based on the
+        # incoming one (`input_observation_space`).
+        return super().recompute_output_observation_space(
+            input_observation_space, input_action_space
+        )
+
+    @abc.abstractmethod
+    def preprocess(self, observations, episode: MultiAgentEpisode):
+        """Override to implement the preprocessing logic.
+
+        Args:
+            observations: An observation dict containing each stepping agent's
+                (non-batched) observation to be preprocessed by this connector.
+            episode: The MultiAgentEpisode instance, from which the `observations`
+                dict originated.
+
+        Returns:
+            The new multi-agent observation dict after `observations` has been
+            preprocessed.
+        """
+
+    @override(ConnectorV2)
+    def __call__(
+        self,
+        *,
+        rl_module: RLModule,
+        batch: Dict[str, Any],
+        episodes: List[EpisodeType],
+        explore: Optional[bool] = None,
+        persistent_data: Optional[dict] = None,
+        **kwargs,
+    ) -> Any:
+        # We process and then replace observations inside the episodes directly.
+        # Thus, all following connectors will only see and operate on the already
+        # processed observation (w/o having access anymore to the original
+        # observations).
+        for ma_episode in episodes:
+            observations = ma_episode.get_observations(-1)
+
+            # Process the observation and write the new observation back into the
+            # episode.
+            new_observation = self.preprocess(
+                observations=observations,
+                episode=ma_episode,
+            )
+            # TODO (sven): Implement set_observations API for multi-agent episodes.
+            #  For now, we'll hack it through the single agent APIs.
+            # ma_episode.set_observations(at_indices=-1, new_data=new_observation)
+            for agent_id, obs in new_observation.items():
+                ma_episode.agent_episodes[agent_id].set_observations(
+                    at_indices=-1,
+                    new_data=obs,
+                )
+            # We set the Episode's observation space to ours so that we can safely
+            # set the last obs to the new value (without causing a space mismatch
+            # error).
+            ma_episode.observation_space = self.observation_space
+
+        # Leave `batch` as is. RLlib's default connector automatically populates
+        # the OBS column therein from the episodes' now transformed observations.
+        return batch
+
+
+# Backward compatibility
+ObservationPreprocessor = SingleAgentObservationPreprocessor
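A minimal sketch of a custom `MultiAgentObservationPreprocessor` subclass, illustrating the dict-in/dict-out contract of `preprocess()`; the class name `AppendGlobalMeanObs` and the assumption that all agents share identical 1D Box observation spaces are hypothetical:

import gymnasium as gym
import numpy as np

from ray.rllib.connectors.env_to_module.observation_preprocessor import (
    MultiAgentObservationPreprocessor,
)
from ray.rllib.utils.annotations import override


class AppendGlobalMeanObs(MultiAgentObservationPreprocessor):
    """Appends the mean over all stepping agents' obs to each agent's own obs (sketch)."""

    @override(MultiAgentObservationPreprocessor)
    def recompute_output_observation_space(
        self, input_observation_space, input_action_space
    ):
        # Each agent's Box doubles in size (own obs + global mean obs).
        new_spaces = {}
        for agent_id, space in input_observation_space.spaces.items():
            new_spaces[agent_id] = gym.spaces.Box(
                low=np.concatenate([space.low, space.low]),
                high=np.concatenate([space.high, space.high]),
                dtype=np.float32,
            )
        return gym.spaces.Dict(new_spaces)

    @override(MultiAgentObservationPreprocessor)
    def preprocess(self, observations, episode):
        # `observations` maps each currently stepping AgentID to its (non-batched) obs;
        # the mean is taken over the agents present in this step only.
        global_mean = np.mean(list(observations.values()), axis=0)
        return {
            agent_id: np.concatenate([obs, global_mean]).astype(np.float32)
            for agent_id, obs in observations.items()
        }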

rllib/env/single_agent_episode.py (0 additions, 6 deletions)

@@ -373,12 +373,6 @@ def add_env_reset(
 
         infos = infos or {}
 
-        if self.observation_space is not None:
-            assert self.observation_space.contains(observation), (
-                f"`observation` {observation} does NOT fit SingleAgentEpisode's "
-                f"observation_space: {self.observation_space}!"
-            )
-
         self.observations.append(observation)
         self.infos.append(infos)
 
New file (114 additions, 0 deletions)

@@ -0,0 +1,114 @@
+from typing import Any
+
+import gymnasium as gym
+import numpy as np
+
+from ray.rllib.connectors.env_to_module.observation_preprocessor import (
+    MultiAgentObservationPreprocessor,
+)
+from ray.rllib.utils.annotations import override
+
+
+class AddOtherAgentsRowIndexToXYPos(MultiAgentObservationPreprocessor):
+    """Adds the other agent's row index to an agent's x/y-observation.
+
+    Run this connector with this env:
+    :py:class:`~ray.rllib.examples.env.classes.multi_agent.double_row_corridor_env.DoubleRowCorridorEnv`  # noqa
+
+    In this env, 2 agents walk around in a grid-world and must, each separately, reach
+    their individual goal position to receive a final reward. However, if they collide
+    while searching for these goal positions, another, larger reward is given to both
+    agents. Thus, optimal policies aim at seeking the other agent first, and only then
+    proceeding to their own goal position.
+
+    Each agent's observation space is a 2-tuple encoding the x/y position
+    (x=row, y=column).
+    This connector converts these observations to:
+    A dict for `agent_0` of structure:
+    {
+        "agent": Discrete index encoding the position of the agent,
+        "other_agent_row": Discrete(2), indicating whether the other agent is in row 0
+        or row 1,
+    }
+    And a 3-tuple for `agent_1`, encoding the x/y position of `agent_1` plus the row
+    index (0 or 1) of `agent_0`.
+
+    Note that the row information for the respective other agent, which this connector
+    provides, is needed for learning an optimal policy for either agent, because
+    the env rewards the first collision between the two agents. Hence, an agent needs
+    to have information on which row the respective other agent is currently in, so it
+    can change to this row and try to collide with this other agent.
+    """
+
+    @override(MultiAgentObservationPreprocessor)
+    def recompute_output_observation_space(
+        self,
+        input_observation_space,
+        input_action_space,
+    ) -> gym.Space:
+        """Maps the original (input) observation space to the new one.
+
+        Original observation space is `Dict({agent_n: Box(2,), ...})`.
+        Converts the space for each agent into information specific to that agent,
+        plus the current row of the respective other agent.
+        Output observation space is then
+        `Dict({agent_0: Dict(Discrete, Discrete), agent_1: Box(3,)})`, where the 1st
+        Discrete is the position index of the agent and the 2nd Discrete encodes the
+        current row of the other agent (0 or 1). If the other agent is already done
+        with the episode (has reached its goal state), a special value of 2 is used.
+        """
+        agent_0_space = input_observation_space.spaces["agent_0"]
+        self._env_corridor_len = agent_0_space.high[1] + 1  # Box.high is inclusive.
+        # Env always has 2 rows (and `self._env_corridor_len` columns).
+        num_discrete = int(2 * self._env_corridor_len)
+        spaces = {
+            "agent_0": gym.spaces.Dict(
+                {
+                    # Exact position of this agent (as an int index).
+                    "agent": gym.spaces.Discrete(num_discrete),
+                    # Row (0 or 1) of other agent. Or 2, if other agent is already done.
+                    "other_agent_row": gym.spaces.Discrete(3),
+                }
+            ),
+            "agent_1": gym.spaces.Box(
+                0,
+                agent_0_space.high[1],  # 1=column
+                shape=(3,),
+                dtype=np.float32,
+            ),
+        }
+        return gym.spaces.Dict(spaces)
+
+    @override(MultiAgentObservationPreprocessor)
+    def preprocess(self, observations, episode) -> Any:
+        # Observations: dict of keys "agent_0" and "agent_1", mapping to the respective
+        # x/y positions of these agents (x=row, y=col).
+        # For example: [1.0, 4.0] means the agent is in row 1 and column 4.
+
+        new_obs = {}
+        # 2=agent is already done.
+        row_agent_0 = observations.get("agent_0", [2])[0]
+        row_agent_1 = observations.get("agent_1", [2])[0]
+
+        if "agent_0" in observations:
+            # Compute `agent_0`'s enhanced observation.
+            index_obs_agent_0 = (
+                observations["agent_0"][0] * self._env_corridor_len
+                + observations["agent_0"][1]
+            )
+            new_obs["agent_0"] = {
+                "agent": index_obs_agent_0,
+                "other_agent_row": row_agent_1,
+            }
+
+        if "agent_1" in observations:
+            new_obs["agent_1"] = np.array(
+                [
+                    observations["agent_1"][0],
+                    observations["agent_1"][1],
+                    row_agent_0,
+                ],
+                dtype=np.float32,
+            )
+
+        return new_obs
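A rough sketch of how this connector could be wired into an algorithm config. The `DoubleRowCorridorEnv` import path is taken from the docstring above and may differ between Ray versions, and the exact signature of the `env_to_module_connector` callable also varies across RLlib versions (newer ones may pass additional arguments such as `spaces` and `device`):

from ray.rllib.algorithms.ppo import PPOConfig
# Import path per the class docstring above; adjust to your installed Ray version.
from ray.rllib.examples.env.classes.multi_agent.double_row_corridor_env import (
    DoubleRowCorridorEnv,
)

config = (
    PPOConfig()
    .environment(DoubleRowCorridorEnv)
    .env_runners(
        # Build the preprocessor as the first piece of the env-to-module pipeline
        # on each EnvRunner.
        env_to_module_connector=lambda env: AddOtherAgentsRowIndexToXYPos(),
    )
    .multi_agent(
        policies={"agent_0", "agent_1"},
        policy_mapping_fn=lambda agent_id, episode, **kwargs: agent_id,
    )
)
algo = config.build()
algo.train()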
