Skip to content

Commit 5dde237

Browse files
revert maze reward function (#158)
* revert maze reward function * Update maze_v4.py
1 parent 3ab81e2 commit 5dde237

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

gymnasium_robotics/envs/maze/maze.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,11 @@ def add_xy_position_noise(self, xy_pos: np.ndarray) -> np.ndarray:
275275
def compute_reward(
276276
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info
277277
) -> float:
278-
d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
278+
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
279279
if self.reward_type == "dense":
280-
return np.exp(-d)
280+
return np.exp(-distance)
281281
elif self.reward_type == "sparse":
282-
return -(d > 0.45).astype(np.float32)
282+
return (distance <= 0.45).astype(np.float64)
283283

284284
def compute_terminated(
285285
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info

gymnasium_robotics/envs/maze/maze_v4.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,11 @@ def add_xy_position_noise(self, xy_pos: np.ndarray) -> np.ndarray:
355355
def compute_reward(
356356
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info
357357
) -> float:
358-
d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
358+
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
359359
if self.reward_type == "dense":
360-
return np.exp(-d)
360+
return np.exp(-distance)
361361
elif self.reward_type == "sparse":
362-
return -(d > 0.45).astype(np.float32)
362+
return (distance <= 0.45).astype(np.float64)
363363

364364
def compute_terminated(
365365
self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info

0 commit comments

Comments
 (0)