@@ -134,6 +134,7 @@ def __init__(self, base_name, params):

         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', 1e6)
+        self.max_frames = self.config.get('max_frames', 1e10)

         self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
         self.linear_lr = config['lr_schedule'] == 'linear'
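For context: max_frames defaults to 1e10, so configs that do not set it keep their current behaviour. A config fragment that opts into the frame budget might look like the sketch below (only the max_frames key comes from this diff; the other entries are illustrative placeholders, not a complete rl_games config):

# Illustrative config fragment. Only 'max_frames' is new in this diff; it
# defaults to 1e10 when absent, so existing configs are unaffected.
config = {
    'max_epochs': 10000,        # existing epoch budget
    'max_frames': 50_000_000,   # new: stop once this many environment frames are collected
}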
@@ -932,7 +933,7 @@ def train(self):
         fps_step = curr_frames / step_time
         fps_step_inference = curr_frames / scaled_play_time
         fps_total = curr_frames / scaled_time
-        print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+        print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

         self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)

@@ -974,14 +975,12 @@ def train(self):
974
975
self .save (os .path .join (self .nn_dir , checkpoint_name ))
975
976
should_exit = True
976
977
977
- if epoch_num >= self .max_epochs :
978
+ if epoch_num >= self .max_epochs or self . frame >= self . max_frames :
978
979
if self .game_rewards .current_size == 0 :
979
980
print ('WARNING: Max epochs reached before any env terminated at least once' )
980
981
mean_rewards = - np .inf
981
982
982
- self .save (os .path .join (self .nn_dir ,
983
- 'last_' + self .config ['name' ] + 'ep' + str (epoch_num ) + 'rew' + str (
984
- mean_rewards )))
983
+ self .save (os .path .join (self .nn_dir , 'last_' + self .config ['name' ] + 'ep' + str (epoch_num ) + 'rew' + str (mean_rewards )))
985
984
print ('MAX EPOCHS NUM!' )
986
985
should_exit = True
987
986
update_time = 0
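Together with the new config key read in __init__, this hunk makes training stop on whichever budget is exhausted first: epochs or collected environment frames. A minimal, self-contained sketch of that stop condition follows (the loop, the frames_per_epoch value, and the budget values are illustrative assumptions; only the key names and the epoch-or-frame test mirror the diff):

# Minimal sketch of the epoch-or-frame stop condition added in this diff.
# Everything except the key names and the test itself is illustrative.
config = {'max_epochs': 5, 'max_frames': 4096}
max_epochs = config.get('max_epochs', 1e6)
max_frames = config.get('max_frames', 1e10)

epoch_num = 0
frame = 0
frames_per_epoch = 1024  # assumed, e.g. horizon_length * num_actors

while True:
    epoch_num += 1
    frame += frames_per_epoch
    print(f'epoch: {epoch_num}/{max_epochs} frame: {frame}/{max_frames}')
    if epoch_num >= max_epochs or frame >= max_frames:
        break  # either budget being exhausted ends training

With these numbers the frame budget is hit at epoch 4, before max_epochs, so the loop exits early.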
@@ -1191,7 +1190,7 @@ def train(self):
         fps_step = curr_frames / step_time
         fps_step_inference = curr_frames / scaled_play_time
         fps_total = curr_frames / scaled_time
-        print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+        print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

         self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
         if len(b_losses) > 0:
@@ -1235,11 +1234,11 @@ def train(self):
1235
1234
self .save (os .path .join (self .nn_dir , checkpoint_name ))
1236
1235
should_exit = True
1237
1236
1238
- if epoch_num >= self .max_epochs :
1237
+ if epoch_num >= self .max_epochs or self . frame >= self . max_frames :
1239
1238
if self .game_rewards .current_size == 0 :
1240
1239
print ('WARNING: Max epochs reached before any env terminated at least once' )
1241
1240
mean_rewards = - np .inf
1242
- self .save (os .path .join (self .nn_dir , 'last_' + self .config ['name' ] + 'ep' + str (epoch_num ) + 'rew' + str (mean_rewards )))
1241
+ self .save (os .path .join (self .nn_dir , 'last_' + self .config ['name' ] + 'ep' + str (epoch_num ) + 'rew' + str (mean_rewards ). replace ( '[' , '_' ). replace ( ']' , '_' ) ))
1243
1242
print ('MAX EPOCHS NUM!' )
1244
1243
should_exit = True
1245
1244
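This hunk mirrors the one above and additionally strips the square brackets that str() produces when mean_rewards is an array-like value, before that string is used in the checkpoint file name. A small illustration, where the run name and reward value are made up:

import numpy as np

# When mean_rewards is a small array, str() yields something like '[123.45]';
# the .replace() calls in the diff turn the brackets into underscores so the
# resulting checkpoint file name has no bracket characters.
mean_rewards = np.array([123.45])
name = 'last_' + 'MyRun' + 'ep' + str(42) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')
print(name)  # last_MyRunep42rew_123.45_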
@@ -1253,4 +1252,4 @@ def train(self):
             return self.last_mean_rewards, epoch_num

         if should_exit:
-            return self.last_mean_rewards, epoch_num
+            return self.last_mean_rewards, epoch_num