Skip to content

Commit fecd8cd

Browse files
committed
Adding checkpoints: saving the model every n episodes
1 parent 73148e8 commit fecd8cd

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pyvirtualdisplay import Display
2+
import torch
23
import glob
34
import io
45
import base64
@@ -23,3 +24,7 @@ def show_video():
2324
def display_start():
2425
display = Display(visible=0, size=(1400, 900))
2526
display.start()
27+
28+
29+
def save_model(model, path):
30+
torch.save(model.state_dict(), path)

params/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from policy import Policy
44
from actions import Actions
5+
from helpers import save_model
56

67

78
class Trainer:
@@ -80,6 +81,8 @@ def train(self):
8081
if i_episode % self.config['log_interval'] == 0:
8182
print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
8283
i_episode, ep_reward, running_reward))
84+
save_model(self.policy, './params/policy-params.dl')
85+
8386
if running_reward > self.env.spec().reward_threshold:
8487
print("Solved!")
8588
break

0 commit comments

Comments
 (0)