Skip to content

Commit 457169e

Browse files
EnliteAI Botenliteai
authored andcommitted
Fix: Make actor ID optional for single-network torch policies
1 parent 8b26b32 commit 457169e

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

maze/core/agent/torch_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,17 @@ def compute_logits_dict(self, observation: Any, actor_id: ActorIDType = None) ->
134134
obs_t = convert_to_torch(observation, device=self._device, cast=None, in_place=True)
135135
return self.network_for(actor_id)(obs_t)
136136

137-
def network_for(self, actor_id: ActorIDType) -> nn.Module:
137+
def network_for(self, actor_id: Optional[ActorIDType]) -> nn.Module:
138138
"""Helper function for returning a network for the given policy ID (using either just the sub-step ID
139139
or the full Actor ID as key, depending on the separated agent networks mode.
140140
141141
:param actor_id: Actor ID to get a network for
142142
:return: Network corresponding to the given policy ID.
143143
"""
144+
if actor_id is None:
145+
assert len(self.networks) == 1, "multiple networks are available, please specify the actor ID explicitly"
146+
return list(self.networks.values())[0]
147+
144148
network_key = actor_id if actor_id[0] in self.substeps_with_separate_agent_nets else actor_id[0]
145149
return self.networks[network_key]
146150

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Torch policy mechanics tests."""
2+
3+
from maze.test.shared_test_utils.helper_functions import build_dummy_maze_env, \
4+
flatten_concat_probabilistic_policy_for_env
5+
6+
7+
def test_actor_id_is_optional_for_single_network_policies():
8+
env = build_dummy_maze_env()
9+
policy = flatten_concat_probabilistic_policy_for_env(env)
10+
11+
obs = env.reset()
12+
action = policy.compute_action(obs) # No actor ID provided
13+
assert action in env.action_space

0 commit comments

Comments
 (0)