Skip to content

Commit caf0744

Browse files
xeviknal and ziritrion authored
Add visualization skills: evaluation mode (#12)
Co-authored-by: ziritrion <[email protected]>
1 parent 472e1f7 commit caf0744

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

environment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import gym
22
from wrappers.frame_skipper import FrameSkipper
3-
from gym.wrappers import FrameStack, GrayScaleObservation
3+
from gym.wrappers import FrameStack, GrayScaleObservation, Monitor
44

55

66
class CarRacingEnv:
77

8-
def __init__(self, device, stack_frames=4):
8+
def __init__(self, device, stack_frames=4, train=False):
99
super().__init__()
1010
self.total_rew = 0
1111
self.state = None
1212
self.done = False
1313
self.device = device
14+
self.train = train
1415

1516
self.env = gym.make("CarRacing-v0")
17+
if not train:
18+
self.env = Monitor(self.env, './video', force=True)
1619
self.env = GrayScaleObservation(self.env)
1720
self.env = FrameStack(self.env, stack_frames)
1821
self.env = FrameSkipper(self.env, 4)
@@ -31,6 +34,9 @@ def spec(self):
3134
return self.env.spec
3235

3336
def close(self):
34-
self.close()
37+
self.env.close()
38+
39+
def render(self):
40+
self.env.render()
3541

3642

main.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
import torch
22

3+
import helpers
34
from environment import CarRacingEnv
45
from trainer import Trainer
6+
from runner import Runner
57

68
from pyvirtualdisplay import Display
7-
display = Display(visible=0, size=(1400, 900))
8-
display.start()
99

1010
# if gpu is to be used
1111
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1212

1313
if __name__ == "__main__":
1414
hyperparams = {
15-
'num_episodes': 20000, # Number of training episodes
15+
'num_episodes': 40000, # Number of training episodes
1616
'lr': 1e-2, # Learning rate
1717
'gamma': 0.99, # Discount rate
1818
'log_interval': 5, # controls how often we log progress
1919
'stack_frames': 4,
2020
'device': device,
21-
'params_path': './params/policy-params.dl'
21+
'params_path': './params/policy-params.dl',
22+
'train': True
2223
}
2324

24-
env = CarRacingEnv(device, hyperparams['stack_frames'])
25-
trainer = Trainer(env, hyperparams)
26-
trainer.train()
25+
env = CarRacingEnv(device, hyperparams['stack_frames'], hyperparams['train'])
26+
helpers.display_start()
27+
if(hyperparams['train']):
28+
trainer = Trainer(env, hyperparams)
29+
trainer.train()
30+
else:
31+
runner = Runner(env, hyperparams)
32+
runner.run()
2733

runner.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
import numpy as np
3+
4+
from policy import Policy
5+
from actions import available_actions
6+
7+
class Runner:
    """Play a single episode with a previously saved policy (evaluation mode).

    The constructor builds a ``Policy`` sized by ``config['stack_frames']``
    and ``len(available_actions)`` and restores its parameters from
    ``config['params_path']``.
    """

    def __init__(self, env, config):
        """Store the environment and config, then load the policy checkpoint.

        Args:
            env: Environment wrapper. ``run`` calls its ``reset``, ``render``,
                ``step`` and ``close`` methods.
            config: Mapping that provides ``'stack_frames'`` and
                ``'params_path'``.
        """
        super().__init__()
        self.env = env
        self.config = config
        self.input_channels = config['stack_frames']
        # NOTE(review): config['device'] is not used by this class; the state
        # tensor in select_action is built with torch.from_numpy and never
        # moved to another device. Confirm whether the policy is expected to
        # run on that device.
        self.policy = Policy(self.input_channels, len(available_actions))
        self.policy.load_checkpoint(config['params_path'])

    def select_action(self, state):
        """Sample an action for ``state`` from the policy output.

        Args:
            state: Current observation, or ``None``. ``None`` is replaced by
                an all-zero array of shape ``(input_channels, 96, 96)``.

        Returns:
            The entry of ``available_actions`` at the sampled index.
        """
        if state is None:
            # Starting signal: the original code notes that the first state
            # is None and substitutes a tensor of zeros.
            state = np.zeros((self.input_channels, 96, 96))
        else:
            state = np.asarray(state)
        state = torch.from_numpy(state).float().unsqueeze(0)

        # Evaluation only: no gradients are needed for the forward pass.
        with torch.no_grad():
            probs = self.policy(state)

        # We pick the action from a sample of the probabilities.
        # It prevents the model from always picking the same action.
        action = torch.distributions.Categorical(probs).sample()
        return available_actions[action.item()]

    def run(self):
        """Run one episode, rendering every step, and report the reward.

        The environment is closed even if the episode fails part-way, so
        that whatever ``env.close()`` releases is not leaked on error.

        Returns:
            The cumulative reward collected over the episode.
        """
        total_rew = 0
        try:
            state, done = self.env.reset(), False
            while not done:
                self.env.render()
                action = self.select_action(state)
                state, rew, done, _ = self.env.step(action)
                total_rew += rew
            print('Cumulative reward:', total_rew)
        finally:
            self.env.close()
        return total_rew

0 commit comments

Comments
 (0)