@Dahoas (Collaborator) commented Oct 13, 2022

Implemented the ModelBranch class to support multi-headed, hydra-type models. Also added an adaptive KL controller.

Achieves a 4x training speedup on GPT2-medium and a 10x training speedup on GPT-J, and halves the memory footprint.

I also added unit tests.
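
The controller code itself is not shown in this thread. As a rough illustration of what an adaptive KL controller usually does, here is a minimal sketch of the proportional scheme commonly attributed to Ziegler et al. (2019). The class name, argument names, and the numbers in the usage lines are assumptions for illustration, not taken from this PR.

```python
class AdaptiveKLController:
    """Nudges the KL penalty coefficient so the measured KL drifts toward a target.

    Hypothetical sketch: names and defaults are illustrative, not the PR's code.
    """

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef  # current KL penalty coefficient
        self.target = target       # desired KL between policy and reference
        self.horizon = horizon     # larger horizon means slower adaptation

    def update(self, current_kl: float, n_steps: int) -> float:
        # Relative error, clipped to [-0.2, 0.2] so one noisy batch
        # cannot swing the coefficient too far.
        error = max(-0.2, min(0.2, current_kl / self.target - 1.0))
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


# Usage: the coefficient grows when KL is above target and shrinks when below.
controller = AdaptiveKLController(init_kl_coef=0.2, target=6.0, horizon=10000)
after_high_kl = controller.update(current_kl=12.0, n_steps=256)
after_low_kl = controller.update(current_kl=3.0, n_steps=256)
```

The clipping bound of 0.2 is the value used in the original description of this scheme; a real implementation may expose it as a parameter.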

model_type : "AcceleratePPOModel" # Name of accelerate model type to load
device : "cuda" # Train device
- num_layers_unfrozen : -1 # Number of bottom layers to freeze during training
+ num_layers_unfrozen : 2 # Number of bottom layers to freeze during training
Contributor

Why are we changing this in the default config?

Collaborator Author

See comment below
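
For readers unfamiliar with the setting, the sketch below shows one plausible reading of `num_layers_unfrozen`. It assumes the name is literal (the top N transformer blocks stay trainable and everything below is frozen) and that -1 means nothing is frozen. Both conventions are assumptions, as is the helper name; the inline config comment describes the setting differently, and the thread does not settle which reading the code follows.

```python
from typing import List


def trainable_layer_mask(num_layers: int, num_layers_unfrozen: int) -> List[bool]:
    """Return one flag per transformer block, ordered bottom to top.

    True means the block is trained; False means it is frozen.
    Hypothetical helper: the conventions here are assumed, not confirmed.
    """
    if num_layers_unfrozen < 0:
        # Assumed convention: -1 leaves every block trainable.
        return [True] * num_layers
    n = min(num_layers_unfrozen, num_layers)
    # Freeze the bottom (num_layers - n) blocks, train the top n.
    return [False] * (num_layers - n) + [True] * n


# With a 4-block model, the old default trains everything,
# while the new default trains only the top two blocks.
old_default = trainable_layer_mask(4, -1)
new_default = trainable_layer_mask(4, 2)
```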

cliprange : 0.2 # clip range
cliprange_value : 0.2 # clip range
- vf_coef : 0.2 # value term weight
+ vf_coef : 2.3 # value term weight
Contributor

Likewise here.

Collaborator Author

I found these parameters work a lot better for quickly checking whether the reward is increasing.

Contributor

Got it. Sounds good.
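
To make the effect of the value-term weight concrete, here is a per-sample sketch of the standard clipped PPO loss showing where `cliprange`, `cliprange_value`, and `vf_coef` typically enter. This is the textbook form, not necessarily the exact loss used in this PR; the function and its other argument names are illustrative assumptions.

```python
import math


def ppo_loss(logprob, old_logprob, advantage, value, old_value, ret,
             cliprange=0.2, cliprange_value=0.2, vf_coef=2.3):
    """Clipped PPO loss for a single sample (illustrative, scalar inputs)."""
    # Policy term: clipped surrogate objective, written as a loss to minimize.
    ratio = math.exp(logprob - old_logprob)
    clipped_ratio = max(1.0 - cliprange, min(1.0 + cliprange, ratio))
    pg_loss = max(-advantage * ratio, -advantage * clipped_ratio)

    # Value term: the new value estimate may move at most cliprange_value
    # away from the old one before the clipped branch takes over.
    value_clipped = max(old_value - cliprange_value,
                        min(old_value + cliprange_value, value))
    vf_loss = 0.5 * max((value - ret) ** 2, (value_clipped - ret) ** 2)

    # vf_coef sets how strongly the value error counts against the policy term.
    return pg_loss + vf_coef * vf_loss


# Same sample under the old and new value-term weights.
sample = dict(logprob=0.0, old_logprob=0.0, advantage=1.0,
              value=0.0, old_value=0.0, ret=1.0)
loss_old_weight = ppo_loss(vf_coef=0.2, **sample)
loss_new_weight = ppo_loss(vf_coef=2.3, **sample)
```

With the larger weight, the same value error contributes more than ten times as much to the total, so the value head is pushed harder relative to the policy.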

@@ -0,0 +1,52 @@
model:
Contributor

This config would be great for CI.


# Cell

class ModelBranch(PreTrainedModel):
Contributor

This needs a high-level overview comment explaining how the class works.
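
As a rough illustration of the hydra idea named in the PR description (not the actual ModelBranch code, which is not shown here), the toy sketch below shares a frozen trunk between a trainable policy branch and a frozen reference copy of the same top layers. The trunk runs once per input and only the top layers are duplicated, which is one way such a design could save compute and memory. All class names are hypothetical, and plain multiplications stand in for transformer blocks.

```python
import copy


class ScaleLayer:
    """Toy stand-in for a transformer block: multiplies its input by a weight."""

    def __init__(self, weight: float):
        self.weight = weight

    def __call__(self, x: float) -> float:
        return x * self.weight


class HydraSketch:
    """Shared frozen trunk with a trainable branch and a frozen reference branch."""

    def __init__(self, layers, num_layers_unfrozen: int):
        # This sketch assumes 0 <= num_layers_unfrozen <= len(layers).
        split = len(layers) - num_layers_unfrozen
        self.trunk = layers[:split]                          # frozen, shared
        self.policy_top = layers[split:]                     # trainable branch
        self.reference_top = copy.deepcopy(self.policy_top)  # frozen snapshot

    def forward(self, x: float):
        for layer in self.trunk:
            x = layer(x)  # trunk is evaluated once for both branches
        policy_out, reference_out = x, x
        for layer in self.policy_top:
            policy_out = layer(policy_out)
        for layer in self.reference_top:
            reference_out = layer(reference_out)
        return policy_out, reference_out


# Before any update the two branches agree; after a policy update they diverge.
model = HydraSketch([ScaleLayer(2.0) for _ in range(4)], num_layers_unfrozen=2)
before = model.forward(1.0)
model.policy_top[0].weight = 3.0  # simulate a training step on the policy branch
after = model.forward(1.0)
```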

@@ -0,0 +1,52 @@
import unittest
Contributor

Now this is a useful class, but I think we should be handling unit tests in a separate PR.

Collaborator Author

I kept it in this merge because it tests the ModelBranch implementation.

Contributor

Got it. It needs a lot of work. Let's chat later.

@LouisCastricato merged commit d90dc88 into master on Oct 13, 2022