StatRecorder Class #31
base: stat-recorder
Conversation
Great work! Left some suggestions
```python
idx = info.index(item)
# extract envs of class MultiProcessingSyncWrapper from envs of class
# VecNormalize, which has access to the task id
multiprocessing_sync_wrapper_envs = envs.venv.venv.envs
episode_task = multiprocessing_sync_wrapper_envs[idx]._latest_task
curriculum.update_on_episode(item["episode"]["r"], item["episode"]["l"], episode_task, args.env_id)
```
This isn't necessary, the curriculum sync wrapper will call this automatically with the correct data.
syllabus/core/stat_recorder.py
Outdated
```python
def __init__(self, task_space: TaskSpace):
    """Initialize the StatRecorder"""

    self.write_path = '/Users/allisonyang/Downloads'
```
Make this an initializer argument and let people configure it when they create their curriculum.
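A minimal sketch of that change (the `write_path` default, the `os.makedirs` call, and the `TaskSpace` import path are assumptions, not part of the PR):

```python
import os
from syllabus.task_space import TaskSpace  # import path assumed from the repo layout

class StatRecorder:
    def __init__(self, task_space: TaskSpace, write_path: str = "./stat_logs"):
        """Initialize the StatRecorder with a user-configurable output directory."""
        self.task_space = task_space
        # write_path is now an argument instead of a hard-coded local directory
        self.write_path = write_path
        os.makedirs(self.write_path, exist_ok=True)
```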
syllabus/core/stat_recorder.py
Outdated
```python
self.num_tasks = self.task_space.num_tasks

self.records = {task: [] for task in self.tasks}
self.stats = {task: {} for task in self.tasks}
```
Instead of tracking the full list, track these efficiently. If we train for 10M episodes, it would be impossible to take the average of these lists. You can look up the running mean formulas; I think it's new_mean = ((old_mean * N) + new_value) / (N + 1), or something like that. There might be one for variance too.
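For reference, the standard O(1)-memory formulation is Welford's algorithm, which updates both the running mean and variance per sample; this is an illustrative sketch, not code from the PR:

```python
class RunningMeanVar:
    """Tracks mean and variance incrementally, O(1) memory per task."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x: float) -> None:
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n          # running mean: mean += (x - mean) / n
        self.m2 += delta * (x - self.mean)   # Welford's variance update

    @property
    def var(self) -> float:
        return self.m2 / self.n if self.n > 0 else 0.0
```

Each task would own one of these instead of an ever-growing list.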
Alternatively, it would be good to provide an option for only saving the past N episodes, so that rather than taking the average over all of training, it's an average over the past N episodes. I think some normalization schemes prefer that method because returns change during training. It would be good to provide both options (let the user choose and configure each).
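One way to expose both options in a single class (the `window` argument name is hypothetical; `None` means "average over all of training"):

```python
from collections import deque
from typing import Optional

class TaskStats:
    """Per-task return stats: lifetime running average, or a sliding window."""

    def __init__(self, window: Optional[int] = None):
        self.window = window
        if window is None:
            self.n, self.mean = 0, 0.0           # O(1) lifetime running stats
        else:
            self.returns = deque(maxlen=window)  # only the last `window` episodes

    def update(self, episode_return: float) -> None:
        if self.window is None:
            self.n += 1
            self.mean += (episode_return - self.mean) / self.n
        else:
            self.returns.append(episode_return)  # old episodes fall out automatically

    @property
    def mean_return(self) -> float:
        if self.window is None:
            return self.mean
        return sum(self.returns) / len(self.returns) if self.returns else 0.0
```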
I think you can simplify this code quite a bit, but it looks more efficient now!
syllabus/core/stat_recorder.py
Outdated
writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", 0, step) |
Please simplify this code; you shouldn't need all these if statements and repeated code. Also, I think you can ignore the task_names feature for now. It's only half implemented, and I'll need to fix it at some point.
syllabus/core/stat_recorder.py
Outdated
```python
else:
    N_past = len(self.records[episode_task])

    self.stats[episode_task]['mean_r'] = round((self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1), 4)
```
We probably shouldn't round the saved values; we should only round them when logging or printing.
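In other words, keep full precision in storage and round only at the output boundary; a toy example:

```python
stats = {"mean_r": 3.14159265}           # store the raw float
print(f"mean_r: {stats['mean_r']:.4f}")  # round only at print/log time -> 3.1416
```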
syllabus/core/stat_recorder.py
Outdated
"l": episode_length, | ||
"env_id": env_id | ||
}) | ||
self.records[episode_task] = self.records[episode_task][-self.calc_past_N:] |
It would be more efficient to implement this with separate queues for returns, lengths, and env ids: https://docs.python.org/3/library/collections.html#collections.deque
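A sketch of that layout, one bounded deque per field and per task (attribute names mirror the PR's, but the exact signature is a guess):

```python
from collections import deque

class StatRecorder:
    def __init__(self, tasks, calc_past_N=None):
        # maxlen handles eviction; no manual slicing needed anywhere
        self.episode_returns = {t: deque(maxlen=calc_past_N) for t in tasks}
        self.episode_lengths = {t: deque(maxlen=calc_past_N) for t in tasks}
        self.env_ids = {t: deque(maxlen=calc_past_N) for t in tasks}

    def record(self, episode_return, episode_length, episode_task, env_id=None):
        self.episode_returns[episode_task].append(episode_return)
        self.episode_lengths[episode_task].append(episode_length)
        self.env_ids[episode_task].append(env_id)
```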
syllabus/core/stat_recorder.py
Outdated
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", 0, step) | ||
writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", 0, step) |
Simplify this if/else; there's a lot of repeated code. Also, why are we logging 0 if there are no stats? We can probably just skip logging in that case.
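Something along these lines removes the repetition and skips tasks with no data (the stat key names are assumed to match the dict used in the PR):

```python
def log_metrics(self, writer, step):
    for task, stats in self.stats.items():
        if not stats:
            continue  # no episodes recorded for this task yet; don't log zeros
        for name, value in stats.items():
            writer.add_scalar(f"stats_per_task/task_{task}_{name}", value, step)
```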
```python
self.stats = {task: {} for task in self.tasks}

def record(self, episode_return: float, episode_length: int, episode_task, env_id=None):
    """
```
If you use defaultdicts for the stats you can cut out a lot of code. Change:

```python
self.stats = {task: {} for task in self.tasks}
```

to:

```python
self.stats = {task: defaultdict(float) for task in self.tasks}
```
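The payoff is that a missing key reads as 0.0, so first-episode updates need no existence check; a toy example:

```python
from collections import defaultdict

stats = defaultdict(float)  # missing keys read as 0.0
n, episode_return = 0, 1.5  # toy values

# No `if 'mean_r' in stats` branch needed on the very first episode:
stats['mean_r'] = (stats['mean_r'] * n + episode_return) / (n + 1)
```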
```python
from collections import defaultdict
```
Thanks for simplifying the code, looks much better!
syllabus/core/stat_recorder.py
Outdated
```python
self.episode_lengths[episode_task].append(episode_length)
self.env_ids[episode_task].append(env_id)

# I am not sure whether there is a more efficient way to slice a deque.
# I temporarily convert it to a list and then slice it, which should cost O(n).
self.stats[episode_task]['mean_r'] = np.mean(list(self.episode_returns[episode_task])[-self.calc_past_N:])
```
You shouldn't need to slice a deque at all; it automatically drops old elements once it grows past keep_last_n. Check the documentation for it.
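For example:

```python
from collections import deque
import numpy as np

returns = deque(maxlen=3)         # keeps only the last 3 entries
for r in [1.0, 2.0, 3.0, 4.0]:
    returns.append(r)             # 1.0 falls out automatically

print(np.mean(returns))           # 3.0 -- no list() conversion or slicing needed
```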
Co-authored-by: Ryan Sullivan <[email protected]>
…episodes); implemented the log_metrics function of the StatRecorder class for visualization on Weights & Biases
… for the StatRecorder class
No description provided.