Skip to content

Commit 081e762

Browse files
author
EnliteAI Bot
committed
RL-2109: Make loading the model from input dir optional
(Issue RL-2109 - Load experiment params in rollout)
1 parent 4d69e5e commit 081e762

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

maze/conf/conf_rollout.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ seeding:
4242
# Optionally the environment or/and the wrappers config can be used from the input_dir.
4343
# Example for using the environment config from the input_dir:
4444
#use_input_dir_config:
45+
# use_input_dir_model: True
4546
# use_input_dir_env: True
4647
# use_input_dir_wrappers: False
4748
use_input_dir_config: ~

maze/core/rollout/rollout_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ def setup(self, cfg: DictConfig) -> None:
9090
if 'use_input_dir_wrappers' in cfg['use_input_dir_config'] and cfg['use_input_dir_config']['use_input_dir_wrappers']:
9191
cfg.wrappers = input_dir_config.wrappers
9292

93-
cfg.model = input_dir_config.model
94-
cfg.policy.model = input_dir_config.model
93+
if 'use_input_dir_model' in cfg['use_input_dir_config'] and cfg['use_input_dir_config']['use_input_dir_model']:
94+
cfg.model = input_dir_config.model
95+
cfg.policy.model = input_dir_config.model
9596

9697
else:
9798
raise FileNotFoundError(f'Config file {config_path} not found')

maze/test/core/rollout/test_rollout_from_input_dir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ def experiment_out_dir():
2929

3030
@pytest.mark.parametrize("use_input_dir_env", [True, False])
3131
@pytest.mark.parametrize("use_input_dir_wrappers", [True, False])
32-
def test_train_and_rollout(experiment_out_dir, use_input_dir_env, use_input_dir_wrappers):
32+
@pytest.mark.parametrize("use_input_dir_model", [True, False])
33+
def test_train_and_rollout(experiment_out_dir, use_input_dir_env, use_input_dir_wrappers, use_input_dir_model):
3334
"""Test loading config from the experiment output directory in the rollout run"""
3435
rollout_hydra_overrides = {
3536
"runner": "sequential",
3637
"runner.n_episodes": "2",
3738
"policy": "torch_policy",
3839
"+use_input_dir_config.use_input_dir_env": use_input_dir_env,
3940
"+use_input_dir_config.use_input_dir_wrappers": use_input_dir_wrappers,
41+
"+use_input_dir_config.use_input_dir_model": use_input_dir_model,
4042
"input_dir": experiment_out_dir,
4143
}
4244

0 commit comments

Comments
 (0)