File tree Expand file tree Collapse file tree 3 files changed +9
-0
lines changed Expand file tree Collapse file tree 3 files changed +9
-0
lines changed Original file line number Diff line number Diff line change 1
1
from pyvirtualdisplay import Display
2
+ import torch
2
3
import glob
3
4
import io
4
5
import base64
@@ -23,3 +24,7 @@ def show_video():
23
24
def display_start ():
24
25
display = Display (visible = 0 , size = (1400 , 900 ))
25
26
display .start ()
27
+
28
+
29
+ def save_model (model , path ):
30
+ torch .save (model .state_dict (), path )
Original file line number Diff line number Diff line change
1
+
Original file line number Diff line number Diff line change 2
2
import numpy as np
3
3
from policy import Policy
4
4
from actions import Actions
5
+ from helpers import save_model
5
6
6
7
7
8
class Trainer :
@@ -80,6 +81,8 @@ def train(self):
80
81
if i_episode % self .config ['log_interval' ] == 0 :
81
82
print ('Episode {}\t Last reward: {:.2f}\t Average reward: {:.2f}' .format (
82
83
i_episode , ep_reward , running_reward ))
84
+ save_model (self .policy , './params/policy-params.dl' )
85
+
83
86
if running_reward > self .env .spec ().reward_threshold :
84
87
print ("Solved!" )
85
88
break
You can’t perform that action at this time.
0 commit comments