
Commit d8645b2

Author: Aleksei Petrenko

Small changes for the DexPBT paper:
1) Load env state from the checkpoint in the player (i.e. if we train with a curriculum, we want to load the same state).
2) Minor speedup from disabling validate_args in distributions.
3) Added another way to stop training: max_env_steps (#204).
1 parent f3e9c7f commit d8645b2

File tree

4 files changed: +24 -13 lines


rl_games/algos_torch/models.py

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,7 @@ def forward(self, input_dict):
         prev_actions = input_dict.get('prev_actions', None)
         input_dict['obs'] = self.norm_obs(input_dict['obs'])
         mu, sigma, value, states = self.a2c_network(input_dict)
-        distr = torch.distributions.Normal(mu, sigma)
+        distr = torch.distributions.Normal(mu, sigma, validate_args=False)

         if is_train:
             entropy = distr.entropy().sum(dim=-1)
@@ -246,7 +246,7 @@ def forward(self, input_dict):
         input_dict['obs'] = self.norm_obs(input_dict['obs'])
         mu, logstd, value, states = self.a2c_network(input_dict)
         sigma = torch.exp(logstd)
-        distr = torch.distributions.Normal(mu, sigma)
+        distr = torch.distributions.Normal(mu, sigma, validate_args=False)
         if is_train:
             entropy = distr.entropy().sum(dim=-1)
             prev_neglogp = self.neglogp(prev_actions, mu, sigma, logstd)
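
The only functional change in this file is passing validate_args=False when constructing the action distribution, which skips torch.distributions' per-construction argument checks (e.g. that sigma is positive). Below is a standalone sketch of that effect; the tensor shapes and iteration count are illustrative assumptions, not taken from rl_games.

import time

import torch

# Illustrative batch of Gaussian policy parameters.
mu = torch.zeros(4096, 32)
sigma = torch.ones(4096, 32)

def time_normal_construction(validate_args, iters=1000):
    # With validate_args=True, torch re-checks the constraint sigma > 0 on
    # every construction; validate_args=False skips that host-side work.
    start = time.perf_counter()
    for _ in range(iters):
        torch.distributions.Normal(mu, sigma, validate_args=validate_args)
    return time.perf_counter() - start

print('with validation   :', time_normal_construction(True))
print('without validation:', time_normal_construction(False))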

rl_games/algos_torch/players.py

Lines changed: 13 additions & 1 deletion
@@ -74,6 +74,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()

@@ -172,6 +176,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()

@@ -210,6 +218,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
@@ -221,4 +233,4 @@ def get_action(self, obs, is_deterministic=False):
         return actions

     def reset(self):
-        pass
+        pass
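
All three restore() variants now forward an optional 'env_state' entry from the checkpoint to the environment via set_env_state(). A minimal sketch of an environment that supports this round trip is shown below; the class name, the curriculum field, and the get_env_state() getter are illustrative assumptions, only set_env_state() and the 'env_state' checkpoint key appear in this diff.

class CurriculumEnvSketch:
    """Hypothetical environment whose curriculum progress is part of its state."""

    def __init__(self):
        self.curriculum_level = 0

    def set_env_state(self, env_state):
        # Called from the players' restore() above when the loaded
        # checkpoint contains an 'env_state' entry.
        self.curriculum_level = env_state.get('curriculum_level', 0)

    def get_env_state(self):
        # A matching getter the training side could use when writing the
        # checkpoint; the name is an assumption, not part of this diff.
        return {'curriculum_level': self.curriculum_level}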

rl_games/algos_torch/torch_ext.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def load_checkpoint(filename):
     return state

 def parameterized_truncated_normal(uniform, mu, sigma, a, b):
-    normal = torch.distributions.normal.Normal(0, 1)
+    normal = torch.distributions.normal.Normal(0, 1, validate_args=False)

     alpha = (a - mu) / sigma
     beta = (b - mu) / sigma

rl_games/common/a2c_common.py

Lines changed: 8 additions & 9 deletions
@@ -134,6 +134,7 @@ def __init__(self, base_name, params):

         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', 1e6)
+        self.max_frames = self.config.get('max_frames', 1e10)

         self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
         self.linear_lr = config['lr_schedule'] == 'linear'
@@ -932,7 +933,7 @@ def train(self):
             fps_step = curr_frames / step_time
             fps_step_inference = curr_frames / scaled_play_time
             fps_total = curr_frames / scaled_time
-            print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+            print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

             self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)

@@ -974,14 +975,12 @@ def train(self):
                     self.save(os.path.join(self.nn_dir, checkpoint_name))
                     should_exit = True

-            if epoch_num >= self.max_epochs:
+            if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                 if self.game_rewards.current_size == 0:
                     print('WARNING: Max epochs reached before any env terminated at least once')
                     mean_rewards = -np.inf

-                self.save(os.path.join(self.nn_dir,
-                          'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(
-                              mean_rewards)))
+                self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
                 print('MAX EPOCHS NUM!')
                 should_exit = True
             update_time = 0
@@ -1191,7 +1190,7 @@ def train(self):
             fps_step = curr_frames / step_time
             fps_step_inference = curr_frames / scaled_play_time
             fps_total = curr_frames / scaled_time
-            print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+            print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

             self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
             if len(b_losses) > 0:
@@ -1235,11 +1234,11 @@ def train(self):
                     self.save(os.path.join(self.nn_dir, checkpoint_name))
                     should_exit = True

-            if epoch_num >= self.max_epochs:
+            if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                 if self.game_rewards.current_size == 0:
                     print('WARNING: Max epochs reached before any env terminated at least once')
                     mean_rewards = -np.inf
-                self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
+                self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')))
                 print('MAX EPOCHS NUM!')
                 should_exit = True

@@ -1253,4 +1252,4 @@ def train(self):
                     return self.last_mean_rewards, epoch_num

         if should_exit:
-            return self.last_mean_rewards, epoch_num
+            return self.last_mean_rewards, epoch_num
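
Training now exits when either bound is reached: epoch_num >= max_epochs or self.frame >= max_frames. A hedged sketch of the relevant config entries follows; only the 'name', 'max_epochs', and 'max_frames' keys appear in this diff, and the values are illustrative assumptions.

# Illustrative config fragment (values are assumptions):
config = {
    'name': 'dexpbt_run',          # used in the 'last_...' checkpoint filename
    'max_epochs': 100000,          # existing stop condition: number of training epochs
    'max_frames': 2_000_000_000,   # new stop condition: total environment frames collected
}

# The trainer stops once either limit is hit, mirroring the check above:
#   if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
#       ...
#       should_exit = True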
