
Conversation

xinchen-yang (Contributor)

No description provided.

@RyanNavillus RyanNavillus (Owner) left a comment

Great work! Left some suggestions

idx = info.index(item)
# Extract the MultiProcessingSyncWrapper envs (which expose the task id) from the VecNormalize wrapper.
multiprocessing_sync_wrapper_envs = envs.venv.venv.envs
episode_task = multiprocessing_sync_wrapper_envs[idx]._latest_task
curriculum.update_on_episode(item["episode"]["r"], item["episode"]["l"], episode_task, args.env_id)
RyanNavillus (Owner):

This isn't necessary; the curriculum sync wrapper will call this automatically with the correct data.

def __init__(self, task_space: TaskSpace):
"""Initialize the StatRecorder"""

self.write_path = '/Users/allisonyang/Downloads'
RyanNavillus (Owner):

Make this an initializer argument and let people configure it when they create their curriculum
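
A rough sketch of what I mean (the argument name and default path here are just placeholders, not tested):

# Hypothetical sketch: write_path becomes a constructor argument instead of a hard-coded personal path.
class StatRecorder:
    def __init__(self, task_space, write_path="./stat_logs"):
        """Initialize the StatRecorder."""
        self.task_space = task_space
        self.write_path = write_path  # users can point this anywhere when they create their curriculum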

self.num_tasks = self.task_space.num_tasks

self.records = {task: [] for task in self.tasks}
self.stats = {task: {} for task in self.tasks}
RyanNavillus (Owner):

Instead of tracking the full list, track these efficiently. If we train for 10M episodes then it would be impossible to take the average of these lists. You can look up the running-mean formulas; I think it's running_mean = ((running_mean * N) + new_value) / (N + 1), or something like that. There should be a similar update for the variance too.
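
A sketch of what that update could look like (Welford's online algorithm, which covers the variance as well; not tested against this class):

# Running mean/variance in O(1) memory; count, mean, and m2 all start at 0 for each task.
def running_update(count, mean, m2, new_value):
    count += 1
    delta = new_value - mean
    mean += delta / count
    m2 += delta * (new_value - mean)  # sum of squared deviations seen so far
    variance = m2 / count
    return count, mean, m2, variance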

RyanNavillus (Owner):

Alternatively, it would be good to provide an option for only saving the past N episodes, so that rather than taking the average over all of training, it's an average over the past N episodes. I think some normalization schemes prefer that method because returns change during training. It would be good to provide both options (let the user choose and configure each).
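
A rough sketch of how both options could fit together (keep_last_n is a hypothetical argument name; not tested):

from collections import deque

# keep_last_n=None -> running mean over all of training (O(1) memory);
# keep_last_n=N    -> mean over only the most recent N episodes.
class EpisodeReturns:
    def __init__(self, keep_last_n=None):
        self.keep_last_n = keep_last_n
        self.recent = deque(maxlen=keep_last_n) if keep_last_n else None
        self.count = 0
        self.running_mean = 0.0

    def add(self, episode_return):
        if self.recent is not None:
            self.recent.append(episode_return)
        self.count += 1
        self.running_mean += (episode_return - self.running_mean) / self.count

    def mean(self):
        if self.recent:
            return sum(self.recent) / len(self.recent)
        return self.running_mean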

@RyanNavillus RyanNavillus (Owner) left a comment

I think you can simplify this code quite a bit, but it looks more efficient now!

writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", 0, step)
writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", 0, step)
writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", 0, step)
writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", 0, step)
RyanNavillus (Owner):

Please simplify this code; you shouldn't need all these if statements and repeated code. Also, I think you can ignore the task_names feature for now; it's sort of half implemented, and I'll need to fix it at some point.
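
For example, something along these lines would replace the repeated calls (a sketch; it assumes self.stats[task] keeps the four values keyed by name):

for task, stats in self.stats.items():
    for key, value in stats.items():  # e.g. episode_return_mean, episode_return_var, ...
        writer.add_scalar(f"stats_per_task/task_{task}_{key}", value, step)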

else:
N_past = len(self.records[episode_task])

self.stats[episode_task]['mean_r'] = round((self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1), 4)
RyanNavillus (Owner):

We probably shouldn't round the saved values; we should only round them when logging or printing.
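
i.e. keep the exact value in self.stats and only round or format at the point of display, e.g. (a sketch):

self.stats[episode_task]['mean_r'] = (self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1)
print(f"task {episode_task} mean return: {self.stats[episode_task]['mean_r']:.4f}")  # rounding happens only here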

"l": episode_length,
"env_id": env_id
})
self.records[episode_task] = self.records[episode_task][-self.calc_past_N:]
RyanNavillus (Owner):

It would be more efficient to implement this with a separate queue for returns, lengths, and ids: https://docs.python.org/3/library/collections.html#collections.deque
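
A sketch of what that could look like, reusing the names already in this file (maxlen handles the trimming for you):

from collections import deque

# In __init__: one bounded deque per quantity, per task.
self.episode_returns = {task: deque(maxlen=self.calc_past_N) for task in self.tasks}
self.episode_lengths = {task: deque(maxlen=self.calc_past_N) for task in self.tasks}
self.env_ids = {task: deque(maxlen=self.calc_past_N) for task in self.tasks}

# In record(): just append; no slicing or manual trimming needed.
self.episode_returns[episode_task].append(episode_return)
self.episode_lengths[episode_task].append(episode_length)
self.env_ids[episode_task].append(env_id)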

writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", 0, step)
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", 0, step)
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", 0, step)
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", 0, step)
RyanNavillus (Owner):

Simplify this if/else; there's a lot of repeated code. Also, why are we logging 0 if there are no stats? We can probably just skip logging in that case.
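
e.g. guard once and skip tasks with no data instead of writing zeros (a sketch, assuming self.stats maps each task to its dict of values):

for task, stats in self.stats.items():
    if not stats:
        continue  # nothing recorded for this task yet, so log nothing
    for key, value in stats.items():
        writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(task)}_{key}", value, step)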

self.stats = {task: {} for task in self.tasks}

def record(self, episode_return: float, episode_length: int, episode_task, env_id=None):
"""
RyanNavillus (Owner):

If you use defaultdicts for the stats you can cut out a lot of code. Change:
self.stats = {task: {} for task in self.tasks}
to:
self.stats = {task: defaultdict(float) for task in self.tasks}
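
With defaultdict(float), a task seen for the first time just reads back 0.0, so the "does this key exist yet?" branches go away, e.g.:

from collections import defaultdict

stats = defaultdict(float)   # missing keys read as 0.0
episode_return, n = 1.0, 0   # example values: first episode for this task
stats['mean_r'] = (stats['mean_r'] * n + episode_return) / (n + 1)  # no special case for the first episode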

RyanNavillus (Owner):

from collections import defaultdict

@RyanNavillus RyanNavillus (Owner) left a comment

Thanks for simplifying the code, looks much better!

self.episode_lengths[episode_task].append(episode_length)
self.env_ids[episode_task].append(env_id)

self.stats[episode_task]['mean_r'] = np.mean(list(self.episode_returns[episode_task])[-self.calc_past_N:])  # I am not sure whether there is a more efficient way to slice a deque; I temporarily convert it to a list and then slice it, which should cost O(n)
RyanNavillus (Owner):

You shouldn't need to slice a deque at all; it should automatically drop elements once it grows past keep_last_n. Check the documentation for it.
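
i.e. because the deque was created with maxlen, it already holds only the most recent calc_past_N entries, so record() can consume it directly (sketch):

# deque(maxlen=self.calc_past_N) in __init__ drops old entries automatically,
# so no list() conversion or slicing is needed here:
self.episode_returns[episode_task].append(episode_return)
self.stats[episode_task]['mean_r'] = np.mean(self.episode_returns[episode_task])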

@RyanNavillus RyanNavillus changed the base branch from main to stat-recorder April 26, 2024 07:27