Refactor RL model wrapper into a trainer module #144
Conversation
I am strongly in favor of this refactor
Looks good! Will merge if there are no further changes
```diff
     config=config,
 )
-model.save_pretrained("dataset/trained_model")
+trainer.save_pretrained("dataset/trained_model")
```
Do the model types we use support `save_pretrained`?
Yes
Wait, I don't think they do, or at least PPO doesn't. The base PPO model is just an `nn.Module` (not a `PreTrainedModel`). It seems actually very annoying to save new model architectures in the Hugging Face format. We'll probably have to write a new config.
Ohhhh ok, wait, that's weird then. Can we just add a `save_pretrained` function to PPO haha
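For concreteness, a minimal sketch of what that could look like, assuming a hypothetical `PPOModel` wrapper that is a plain `nn.Module`; the class name and config handling here are illustrative, not trlx's actual API:

```python
import json
import os

import torch
from torch import nn


class PPOModel(nn.Module):
    """Hypothetical PPO wrapper around a base network (illustrative only)."""

    def __init__(self, base_model: nn.Module, config: dict):
        super().__init__()
        self.base_model = base_model
        self.config = config  # assumed to be a JSON-serializable dict

    def save_pretrained(self, save_directory: str) -> None:
        # Mimic the Hugging Face layout (weights + JSON config),
        # since a plain nn.Module has no save_pretrained of its own.
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.state_dict(),
            os.path.join(save_directory, "pytorch_model.bin"),
        )
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(self.config, f)
```

With something like this in place, the `trainer.save_pretrained("dataset/trained_model")` call in the diff above could work even though the base model never inherits from `PreTrainedModel`.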
But anyway, saving doesn't have anything to do with this PR, so I think it's fine for now.
This PR refactors the RL model wrappers into "trainer" wrappers. The term "model" is semantically overloaded throughout the codebase. A specific point of confusion is the `{Type}RLModel` classes, which contain not only models but also wrap optimizers, schedulers, and other auxiliary data structures required for RL training.

This refactor was briefly discussed in the CarperAI Discord with @cat-state. I am leaving it here as a reminder and for others to chime in.
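As a rough sketch of the shape this refactor suggests (the class name, attributes, and config keys below are hypothetical, not the PR's actual API), the trainer owns the optimizer, scheduler, and other training-time state, so "model" can again mean just the network:

```python
import torch
from torch import nn


class BaseRLTrainer:
    """Hypothetical trainer wrapper: bundles a model with the auxiliary
    objects RL training needs, instead of stuffing them into the model."""

    def __init__(self, model: nn.Module, config: dict):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config["total_steps"]
        )

    def step(self, loss: torch.Tensor) -> None:
        # One optimization step; the RL-specific loss (e.g. PPO's clipped
        # objective) would be computed by an algorithm-specific subclass.
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
```

Each `{Type}RLModel` would then presumably become a `{Type}RLTrainer` subclass implementing its algorithm-specific loss, leaving the wrapped model objects free of training machinery.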