Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gymnasium_robotics/envs/fetch/fetch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
self.data.qacc_warmstart[:] = np.copy(self.initial_qacc_warmstart)
self.data.ctrl[:] = np.copy(self.initial_ctrl)
self.data.mocap_pos[:] = np.copy(self.initial_mocap_pos)
self.data.mocap_quat[:] = np.copy(self.initial_mocap_quat)
if self.model.na != 0:
self.data.act[:] = None

Expand Down
4 changes: 4 additions & 0 deletions gymnasium_robotics/envs/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def _initialize_simulation(self):
self.initial_time = self.data.time
self.initial_qpos = np.copy(self.data.qpos)
self.initial_qvel = np.copy(self.data.qvel)
self.initial_ctrl = np.copy(self.data.ctrl)
self.initial_qacc_warmstart = np.copy(self.data.qacc_warmstart)
self.initial_mocap_pos = np.copy(self.data.mocap_pos)
self.initial_mocap_quat = np.copy(self.data.mocap_quat)

def _reset_sim(self):
self.data.time = self.initial_time
Expand Down
93 changes: 93 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,99 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_2.close()


@pytest.mark.parametrize(
    "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
)
def test_same_env_determinism_rollout(env_spec: EnvSpec):
    """Run two rollouts with a single environment and assert equality.

    This test runs two rollouts of NUM_STEPS steps with one environment
    instance, resetting it with the same seed before each rollout, and
    asserts that:

    - observations after the reset are the same
    - the same actions are sampled by the environment
    - observations are contained in the observation space
    - obs, rew, terminated, truncated and info are equal between the two rollouts

    Re-using a single instance (instead of two fresh ones) checks that
    ``reset`` restores the complete simulation state, so nothing from the
    first rollout leaks into the second.

    Args:
        env_spec: Specification of the environment under test.
    """
    # Don't check rollout equality if it's a nondeterministic environment.
    if env_spec.nondeterministic is True:
        return

    env = env_spec.make(disable_env_checker=True)

    # Guarantee the environment is closed even when an assertion fails, so a
    # failing parametrized case does not leak simulator resources.
    try:
        rollouts = []

        # Run two rollouts of the same environment instance
        for _ in range(2):
            rollout = {
                "observations": [],
                "actions": [],
                "rewards": [],
                "terminated": [],
                "truncated": [],
                "infos": [],
            }

            # Reset the environment with the same seed for both rollouts
            obs, info = env.reset(seed=SEED)
            env.action_space.seed(SEED)
            rollout["observations"].append(obs)
            rollout["infos"].append(info)

            for _ in range(NUM_STEPS):
                action = env.action_space.sample()

                obs, rew, terminated, truncated, info = env.step(action)
                rollout["observations"].append(obs)
                rollout["actions"].append(action)
                rollout["rewards"].append(rew)
                rollout["terminated"].append(terminated)
                rollout["truncated"].append(truncated)
                rollout["infos"].append(info)
                if terminated or truncated:
                    # Start a new episode; its initial observation is
                    # intentionally not recorded in the rollout.
                    env.reset(seed=SEED)

            rollouts.append(rollout)

        rollout_1, rollout_2 = rollouts

        for time_step, (obs_1, obs_2) in enumerate(
            zip(rollout_1["observations"], rollout_2["observations"])
        ):
            # -1 because of the initial observation stored on reset
            time_step = "initial" if time_step == 0 else time_step - 1
            assert_equals(obs_1, obs_2, f"[{time_step}] ")
            assert env.observation_space.contains(
                obs_1
            )  # obs_2 verified by previous assertion
        for time_step, (rew_1, rew_2) in enumerate(
            zip(rollout_1["rewards"], rollout_2["rewards"])
        ):
            assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
        for time_step, (terminated_1, terminated_2) in enumerate(
            zip(rollout_1["terminated"], rollout_2["terminated"])
        ):
            assert (
                terminated_1 == terminated_2
            ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
        for time_step, (truncated_1, truncated_2) in enumerate(
            zip(rollout_1["truncated"], rollout_2["truncated"])
        ):
            assert (
                truncated_1 == truncated_2
            ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
        for time_step, (info_1, info_2) in enumerate(
            zip(rollout_1["infos"], rollout_2["infos"])
        ):
            # -1 because of the initial info stored on reset
            time_step = "initial" if time_step == 0 else time_step - 1
            assert_equals(info_1, info_2, f"[{time_step}] ")
    finally:
        env.close()


@pytest.mark.parametrize(
"spec", non_mujoco_py_env_specs, ids=[spec.id for spec in non_mujoco_py_env_specs]
)
Expand Down