Refactor RL model wrapper into a trainer module #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
This file was deleted.
@@ -0,0 +1,40 @@
+.. _trainers:
+
+RL Trainers
+*******************
+
+RL Trainers are what you train with trlX. Currently, we support PPO and ILQL.
+Note that new trainers must be registered with ``trlx.trainer.register_trainer``.
+
+**General**
+
+.. autoclass:: trlx.trainer.BaseRLTrainer
+   :members:
+
+.. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer
+   :members:
+
+**PPO**
+
+.. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer
+   :members:
+
+.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead
+   :members:
+
+.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch
+   :members:
+
+.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch
+   :members:
+
+.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead
+   :members:
+
+**ILQL**
+
+.. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer
+   :members:
+
+.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads
+   :members:
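As a side note on the registration requirement this file mentions, a minimal sketch of registering a custom trainer could look roughly as follows, assuming register_trainer can be used as a class decorator; the MyCustomTrainer name, its body, and the exact override surface are hypothetical, not code from this PR.

    # Sketch only: registering a new trainer, assuming register_trainer
    # acts as a class decorator. The subclass body is illustrative.
    from trlx.trainer import register_trainer
    from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer

    @register_trainer
    class MyCustomTrainer(AccelerateRLTrainer):  # hypothetical name
        def loss(self, batch):
            # A real trainer computes its RL objective here
            # (e.g. a PPO- or ILQL-style loss); omitted in this sketch.
            raise NotImplementedError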
@@ -59,12 +59,12 @@ def main(hparams={}):
     dataset = DSLDataset()
     train_prompts = list(dataset.load_datapoints(split="train"))[:1000]
 
-    model = trlx.train(
+    trainer = trlx.train(
         reward_fn=reward_fn,
         prompts=train_prompts,
         config=config,
     )
-    model.save_pretrained("dataset/trained_model")
+    trainer.save_pretrained("dataset/trained_model")
 
 
 if __name__ == "__main__":
Review thread on trainer.save_pretrained:

> Do the model types we use support save_pretrained?

> Yes

> Wait, I don't think they do, or at least PPO doesn't. The base PPO model is just an nn.Module (not pretrained). It seems actually very annoying to save new model architectures in a Hugging Face format. We'll probably have to write a new config.

> Ohhh, ok, wait, that's weird then. Can we just add a save_pretrained function to PPO? haha

> But anyway, saving doesn't have anything to do with this PR, so I think it's fine for now.
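The quick fix floated in the thread, giving the PPO wrapper its own save_pretrained, could look roughly like the sketch below. Everything here is an assumption for illustration: the wrapper's constructor, the base_model and v_head attribute names, and the v_head.pt filename are hypothetical, not trlX's actual CausalLMWithValueHead.

    # Hypothetical sketch of the save_pretrained the reviewers discuss.
    import os

    import torch
    from torch import nn
    from transformers import AutoModelForCausalLM

    class CausalLMWithValueHeadSketch(nn.Module):
        """Stand-in for a PPO wrapper: a pretrained LM plus a scalar
        value head. base_model / v_head are assumed attribute names."""

        def __init__(self, model_name: str):
            super().__init__()
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
            self.v_head = nn.Linear(self.base_model.config.hidden_size, 1)

        def save_pretrained(self, save_directory: str) -> None:
            os.makedirs(save_directory, exist_ok=True)
            # The wrapped LM is a transformers PreTrainedModel, so it can
            # save itself in the standard Hugging Face layout.
            self.base_model.save_pretrained(save_directory)
            # The value head is a plain nn.Module with no slot in the HF
            # config, so persist its weights separately alongside the LM.
            torch.save(self.v_head.state_dict(),
                       os.path.join(save_directory, "v_head.pt"))

A matching from_pretrained that restores v_head.pt would be needed to round-trip; the thread defers all of this since saving is out of scope for this PR.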